Autograd is the engine that makes PyTorch a deep learning framework rather than a fast tensor library. Every operation performed on a tensor that requires gradients is recorded in a dynamic computational graph; calling loss.backward() then walks that graph in reverse, applying the chain rule to compute gradients of the loss with respect to every leaf parameter. Without autograd, every gradient in every model would have to be derived and coded by hand; with autograd, the framework does that work automatically, every time.
This section explains how the engine works under the hood, when to use the various ways of controlling it, and how to extend it with custom operations and hooks. The mathematical justification for the chain rule itself is covered in Appendix A.3 (Calculus for Machine Learning); here the focus is on the mechanics PyTorch exposes.
The Computational Graph
A tensor that participates in a differentiable computation carries two pieces of extra state. The first is the boolean flag requires_grad; if true, autograd will track operations on this tensor. The second is the grad_fn attribute, which points to a node in the dynamic graph that records the operation that produced this tensor and holds references to the upstream tensors needed to compute the local gradient.
Leaf tensors (model parameters, typically) have requires_grad=True and grad_fn=None because no operation produced them. Intermediate tensors have requires_grad=True only if at least one of their inputs did, and their grad_fn records the producing op (for example AddmmBackward0 for the result of a linear layer). The graph is built on the fly during the forward pass and discarded after each backward pass; that is what "dynamic computation graph" means.
```python
import torch

w = torch.tensor([1.5], requires_grad=True)  # leaf parameter
b = torch.tensor([0.3], requires_grad=True)  # leaf parameter
x = torch.tensor([2.0])                      # leaf data, no grad

z = x * w + b              # intermediate
a = torch.sigmoid(z)       # intermediate
loss = (a - 1.0) ** 2      # scalar output

print("w.grad_fn:", w.grad_fn)        # None
print("z.grad_fn:", z.grad_fn)        # <AddBackward0 ...>
print("a.grad_fn:", a.grad_fn)        # <SigmoidBackward0 ...>
print("loss.grad_fn:", loss.grad_fn)  # <PowBackward0 ...>
```
The grad_fn chain reveals the operations that built each tensor. Leaf parameters have no grad_fn; intermediates record the producing op so the backward pass can replay it.
backward() and Gradient Accumulation
Calling loss.backward() traverses the graph from loss back to every leaf with requires_grad=True, computing partial derivatives by the chain rule and writing them into each leaf's .grad attribute. A scalar loss requires no arguments. A non-scalar output requires a gradient= argument with the same shape as the output. Autograd treats that tensor as the upstream gradient v and computes the vector-Jacobian product v^T J (not a Jacobian-vector product); passing torch.ones_like(output) is equivalent to calling backward() on output.sum().
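A minimal sketch of backward() on a non-scalar output, assuming a toy elementwise function y = x ** 2 chosen only for illustration. The vector v passed as gradient= acts as the upstream gradient, so x.grad ends up holding v * 2x, which matches calling backward() on the scalar (y * v).sum().

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
v = torch.tensor([1.0, 0.5, 0.0])  # upstream gradient, same shape as y

y = x ** 2                         # non-scalar output, shape (3,)
y.backward(gradient=v)             # computes v^T J with J = diag(2x)
vjp = x.grad.clone()               # v * 2x = [2., 2., 0.]

# Equivalent: reduce to a scalar first, then call backward() with no arguments.
x.grad = None
((x ** 2) * v).sum().backward()
scalar_version = x.grad.clone()

print(vjp, scalar_version)
```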
A subtle and important behavior: .grad accumulates. Each call to backward() adds the new gradient to whatever is already in .grad. This is by design; it enables gradient accumulation across multiple forward passes (useful for simulating a larger batch). It also means that in a normal training loop, optimizer.zero_grad() must be called before every backward pass to clear stale gradients. Forgetting this is one of the most common training-loop bugs.
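The batch-simulation use can be sketched as follows, assuming a toy nn.Linear model with MSELoss (mean reduction); the split into 4 micro-batches of 2 is arbitrary. Dividing each micro-batch loss by the number of micro-batches makes the accumulated .grad match the gradient of one full-batch backward pass.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
X, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()  # mean over elements

# Reference: gradient of the full batch of 8 in a single backward pass.
model.zero_grad()
loss_fn(model(X), y).backward()
full = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2; .grad sums across backward() calls.
model.zero_grad()
for Xb, yb in zip(X.split(2), y.split(2)):
    (loss_fn(model(Xb), yb) / 4).backward()  # scale by the number of micro-batches
accumulated = model.weight.grad.clone()

print(torch.allclose(full, accumulated, atol=1e-6))
```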
Since PyTorch 1.7, optimizer.zero_grad(set_to_none=True) sets each gradient to None instead of zeroing it in place; since PyTorch 2.0 this is the default. It is usually faster because it skips the memory write: the next backward pass assigns a freshly allocated gradient instead of adding to zeros. For any parameter that receives a gradient in that next backward pass, the result is identical. Two differences remain. Code that reads param.grad directly (e.g., custom logging) must handle the None case explicitly. And torch.optim optimizers skip parameters whose .grad is None, whereas a zero gradient still triggers an update, which matters for momentum and weight decay. If a third-party library depends on param.grad being a zero tensor, fall back to set_to_none=False.
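A short sketch contrasting the two settings, assuming a toy nn.Linear model and SGD optimizer. Passing set_to_none explicitly keeps the behavior independent of the installed version's default, and the list comprehension shows one way for logging code to tolerate a None gradient.

```python
import torch

model = torch.nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(5, 3)).sum().backward()
opt.zero_grad(set_to_none=True)    # explicit, independent of the version default
after_none = model.weight.grad     # None

model(torch.randn(5, 3)).sum().backward()
opt.zero_grad(set_to_none=False)   # keep zero-filled gradient tensors instead

# Code that reads .grad directly should tolerate None.
norms = [p.grad.norm().item() if p.grad is not None else 0.0
         for p in model.parameters()]
print(after_none, norms)
```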
The other way to compute gradients is torch.autograd.grad(outputs, inputs), which returns the gradients directly without writing them to .grad. This is useful when computing gradients with respect to specific tensors (rather than all parameters) or when implementing meta-learning, second-order methods, or implicit-differentiation tricks. By default grad() consumes the graph; pass retain_graph=True to keep it for additional gradient calls, and create_graph=True to enable differentiation of the gradient itself (for Hessians and other higher-order quantities).
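A minimal sketch of a second derivative with torch.autograd.grad, assuming the toy function y = x ** 3 at x = 3. The first call uses create_graph=True so that the returned gradient is itself part of a graph that the second call can differentiate; neither call writes to x.grad.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# create_graph=True records the gradient computation so it can be differentiated again.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)   # 3 * x**2 = 27
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)                 # 6 * x = 18

print(dy_dx.item(), d2y_dx2.item())  # 27.0 18.0
print(x.grad)                        # None: grad() does not populate .grad
```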
Controlling Gradient Tracking
Gradient tracking costs both memory (intermediate activations must be retained for the backward pass) and compute (every op records its backward function). Disabling tracking when it is not needed is a major optimization, especially during inference and validation. PyTorch offers three mechanisms for this, each with a different scope.
The first is tensor.detach(): a method that returns a new tensor sharing storage with the original but with requires_grad=False and no grad_fn. The original is unchanged. detach() is the surgical tool: it removes one specific tensor from the graph. The canonical use is for logging, where loss.detach() records the scalar value without retaining the backward graph.
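A small sketch of what detach() does and does not do, using toy tensors. The detached copy shares storage with the original, and inside a loss it behaves as a constant: the gradient below is 3 * y_det, whereas without the detach it would be 2 * 3 * y.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
y = w * 3                 # tracked: has a grad_fn
y_det = y.detach()        # same values, cut out of the graph

print(y_det.requires_grad, y_det.grad_fn)   # False None
print(y_det.data_ptr() == y.data_ptr())     # True: storage is shared, not copied

# y_det acts as a constant: d/dw sum(y * y_det) = 3 * y_det.
loss = (y * y_det).sum()
loss.backward()
print(w.grad)             # 9 and 18

logged = loss.detach()    # logging-style value that holds no graph
```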
The second is the context manager torch.no_grad(): every tensor produced inside the with block has requires_grad=False regardless of its inputs. This is the standard wrapper for inference and validation passes. It can cut memory use substantially, because intermediate activations are released as soon as the forward computation no longer needs them instead of being kept for a backward pass that will never happen.
```python
import torch

model = torch.nn.Linear(50, 3)
X = torch.rand(8, 50)

# Inference: do not build the graph.
model.eval()
with torch.no_grad():
    logits = model(X)
    probs = torch.softmax(logits, dim=1)

print(logits.requires_grad, logits.grad_fn)  # False None
```
torch.no_grad() for inference. The output tensor has no grad_fn and no autograd machinery was built behind it, freeing both memory and compute.
The third is torch.set_grad_enabled(mode): a context manager (also callable as a function) that turns tracking on or off inside its scope. The form with torch.set_grad_enabled(is_train): is useful inside a function that handles both training and evaluation, so the same loop body can be reused for both.
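A minimal sketch of that shared loop body, assuming a toy linear classifier with cross-entropy loss; run_batch is a hypothetical helper name, not a PyTorch API. The same code path builds a graph and steps the optimizer when is_train is True, and builds nothing when it is False.

```python
import torch

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

def run_batch(X, y, is_train):
    """Hypothetical helper: one loop body shared by training and evaluation."""
    model.train(is_train)                  # train() or eval() mode
    with torch.set_grad_enabled(is_train):
        loss = loss_fn(model(X), y)
        if is_train:
            opt.zero_grad()
            loss.backward()
            opt.step()
    return loss

X, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
train_loss = run_batch(X, y, is_train=True)
val_loss = run_batch(X, y, is_train=False)
print(train_loss.requires_grad, val_loss.requires_grad)  # True False
```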
A related tool is torch.inference_mode(), a newer and stricter version of no_grad() that also skips view tracking and version-counter updates for extra speed. Tensors created inside it are inference tensors and cannot later be used in operations recorded by autograd. Use it for pure inference workloads where no tensor produced inside the block will ever participate in training; otherwise stick with no_grad().
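A sketch of that restriction, using toy tensors. A tensor created under inference_mode cannot be saved for a later backward pass, so using it next to a parameter that requires gradients raises a RuntimeError; cloning it outside inference mode yields a normal tensor that autograd accepts.

```python
import torch

w = torch.ones(3, requires_grad=True)

with torch.inference_mode():
    t = torch.ones(3) * 2            # an inference tensor

print(t.requires_grad, t.is_inference())   # False True

# Using the inference tensor in a recorded op fails (it would need to be saved for backward).
try:
    (w * t).sum()
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True

# A clone made outside inference mode is a normal tensor and works.
loss = (w * t.clone()).sum()
loss.backward()
print(w.grad)  # all 2s
```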
Custom Autograd Functions
The full power of autograd is exposed through torch.autograd.Function, the base class that defines what it means to be a differentiable operation. Subclassing it requires implementing two static methods: forward(ctx, *inputs) computes the forward output and stashes anything needed for the backward pass on ctx; backward(ctx, *grad_outputs) retrieves the saved tensors and computes the gradients of each input.
```python
import torch

class ScaledReLU(torch.autograd.Function):
    """ReLU multiplied by a scalar `alpha`, with custom backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x)   # tensors go through save_for_backward
        ctx.alpha = alpha          # non-tensor values can be stored on ctx directly
        return alpha * torch.relu(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_input = grad_output * ctx.alpha * (x > 0).to(grad_output.dtype)
        # One return value per forward input (x, alpha); alpha is not a Tensor, so return None.
        return grad_input, None

x = torch.randn(4, requires_grad=True)
y = ScaledReLU.apply(x, 2.0).sum()
y.backward()
print(x.grad)  # 2.0 where x > 0, 0.0 elsewhere
```
The two situations that justify a custom Function are: (1) the forward op involves a non-PyTorch call (a C++ extension, a black-box library, a custom CUDA kernel) whose gradient PyTorch cannot infer; and (2) the closed-form gradient is numerically much more stable than what autograd obtains by composing the derivatives of the primitive ops. The classic example is log(1 + exp(x)): its gradient simplifies to sigmoid(x), whereas the composed chain overflows in exp(x) for large x and produces NaN. Outside these cases, let PyTorch derive the gradient; it is almost always faster to develop and rarely slower to run.
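A sketch of case (2) for softplus, where StableSoftplus is a hypothetical class name and F.softplus and torch.autograd.gradcheck are standard PyTorch functions. At x = 100 in float32 the naive composition yields a NaN gradient, while the hand-written backward returns sigmoid(x); gradcheck then compares the analytic backward against finite differences in double precision.

```python
import torch
import torch.nn.functional as F

class StableSoftplus(torch.autograd.Function):
    """log(1 + exp(x)) whose backward uses the closed form sigmoid(x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return F.softplus(x)                 # numerically stable forward

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * torch.sigmoid(x)

x = torch.tensor([100.0], requires_grad=True)

# Naive composition: exp(100) overflows to inf; backward computes 0 * inf = nan.
torch.log(1 + torch.exp(x)).sum().backward()
naive_grad = x.grad.clone()                  # nan

x.grad = None
StableSoftplus.apply(x).sum().backward()
stable_grad = x.grad.clone()                 # 1.0

# Finite-difference check of the custom backward in double precision.
x64 = torch.randn(5, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(StableSoftplus.apply, (x64,))
print(naive_grad, stable_grad, ok)
```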
When the goal is per-sample gradients, Jacobians, or vmapped function transforms, skip the hand-rolled autograd.grad calls and reach for torch.func (the in-tree successor to functorch). The transforms grad, jacrev, jacfwd, hessian, and vmap compose freely and avoid the bookkeeping of create_graph=True recursion.
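```python
import torch
from torch.func import grad, vmap

def loss_fn(params, x, y):
    return ((params @ x - y) ** 2).sum()

params = torch.randn(5)                   # weights of a toy linear model
X, y = torch.randn(8, 5), torch.randn(8)  # a batch of 8 samples
# Per-sample gradients: vmap the single-sample grad over the batch axis.
per_sample_grad = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, X, y)
print(per_sample_grad.shape)              # torch.Size([8, 5])
```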
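The same transforms handle higher-order derivatives without create_graph=True bookkeeping. A minimal sketch, assuming the toy function f(x) = sum(x ** 3), whose gradient is 3x^2 and whose Hessian is the diagonal matrix diag(6x):

```python
import torch
from torch.func import hessian, jacrev

def f(x):
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0, 3.0])  # no requires_grad needed with torch.func
g = jacrev(f)(x)                   # gradient of a scalar function: 3 * x**2
H = hessian(f)(x)                  # second derivatives: diag(6 * x)
print(g)
print(torch.allclose(H, torch.diag(6 * x)))  # True
```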
Hooks: Inspecting and Modifying the Graph
Hooks are callbacks that fire during the forward or backward pass without changing the model's source code. They are the diagnostic tool of choice for inspecting activations, gradients, and parameter updates in a running model. The three most commonly used hooks on nn.Module are register_forward_pre_hook (fires before forward, sees the inputs), register_forward_hook (fires after forward, sees inputs and outputs), and register_full_backward_hook (fires during backward, sees the gradients with respect to the module's inputs and outputs). Each registration returns a handle whose .remove() method detaches the hook.
Tensor-level hooks are even finer-grained. tensor.register_hook(fn) calls fn(grad) every time a gradient flows through that tensor and lets fn return a modified gradient (or None to leave it untouched). This is the right level for gradient clipping or noise injection on a specific intermediate.
```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
activations = {}

def save_act(name):
    def hook(module, inputs, output):
        activations[name] = output.detach()  # detach so the graph is not kept alive
    return hook

handles = [
    model[0].register_forward_hook(save_act("linear_0")),
    model[2].register_forward_hook(save_act("linear_1")),
]

_ = model(torch.randn(2, 8))
for k, v in activations.items():
    print(k, tuple(v.shape))  # linear_0 (2, 16), then linear_1 (2, 4)

for h in handles:
    h.remove()  # remove hooks once they are no longer needed
```
Storing output in a long-lived container also keeps alive the entire computation graph that produced it. If the stored value is not detached (or the forward pass is not wrapped in no_grad()), every saved activation pins its graph in memory; when the hook appends to a list across iterations, memory grows with every step until the training loop runs out of memory. Always store output.detach() or output.detach().cpu() inside a hook. Similarly, register and remove hooks within the same logical scope; forgotten hooks accumulate.
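The tensor-level tensor.register_hook mentioned earlier can be sketched with toy values: the hook clamps the gradient flowing into an intermediate h, and because it returns a tensor, that tensor replaces the gradient before it continues to x. The clamp range of 0.5 is an arbitrary illustration, not a recommended setting.

```python
import torch

x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
h = x * 10                                   # intermediate tensor

# Returning a tensor from the hook replaces the gradient flowing into h.
handle = h.register_hook(lambda grad: grad.clamp(-0.5, 0.5))

loss = (h ** 2).sum()    # d(loss)/dh = 2h = [20, -40, 60] before clamping
loss.backward()
print(x.grad)            # clamped [0.5, -0.5, 0.5] times dh/dx = 10

handle.remove()          # detach the hook once it is no longer needed
```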
Every batch, the graph is rebuilt from scratch and demolished after backward. PyTorch is, in effect, a model with severe short-term memory loss: it remembers only the current forward pass, and the moment you call backward() it cheerfully forgets everything. This is a feature (dynamic shapes! Python control flow!), not a bug, but it does explain why retain_graph=True feels like asking the framework to take notes.
Anomaly Detection
When a backward pass produces NaN or Inf gradients, the most useful debugging aid is torch.autograd.set_detect_anomaly(True). With anomaly detection on, PyTorch records the Python stack trace of every forward op. If a backward function then returns NaN values, it raises an error and prints the stack trace of the forward op that created that function, so the offending operation can be identified precisely. The cost is a significant slowdown, so this is a debugging tool rather than a production setting. Section E.9 walks through the full NaN-hunting workflow.
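A minimal sketch of the difference anomaly detection makes, using the context-manager form torch.autograd.detect_anomaly() and a deliberately broken toy loss (compute_loss is a hypothetical name). Without anomaly detection the NaN gradient appears silently; with it, backward raises a RuntimeError and the traceback of the offending forward op is emitted as a warning.

```python
import torch

def compute_loss(x):
    # exp(1000) overflows float32 to inf, and inf * 0 is nan in the forward pass.
    # In backward, the exp node multiplies the incoming 0 by inf, producing nan.
    return (torch.exp(x) * 0).sum()

x = torch.tensor([1000.0], requires_grad=True)

# Without anomaly detection the NaN gradient appears silently.
compute_loss(x).backward()
silent_grad = x.grad.clone()

# With anomaly detection, backward raises and reports the forward op responsible.
x.grad = None
message = ""
try:
    with torch.autograd.detect_anomaly():
        compute_loss(x).backward()
except RuntimeError as err:
    message = str(err)
print(silent_grad, message)
```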
Autograd is a record-and-replay engine: every operation on a tensor with requires_grad=True appends a node to a dynamic graph, and backward() replays the graph in reverse using the chain rule. The three tools for controlling the graph (detach, no_grad, set_grad_enabled) are essential for keeping memory in check during inference and logging. Custom Function subclasses handle the rare cases where autograd cannot derive a gradient or derives a numerically unstable one, and hooks are the universal mechanism for inspecting what is flowing through a model without modifying its source.
Objective. Confirm that autograd matches a closed-form derivative.
Task. Define f(x) = (x ** 3 - 2 * x).sum(). For x = torch.randn(8, requires_grad=True), compute the gradient via backward(), then compare against the analytic gradient 3 * x.detach() ** 2 - 2 using torch.allclose. Repeat for f(x) = torch.sigmoid(x).sum() (clear x.grad first, since gradients accumulate); the analytic gradient is sig * (1 - sig), where sig = torch.sigmoid(x.detach()).
Stretch. Use torch.autograd.gradcheck with double precision to validate a small custom function against numerical differences. Read the docstring for the eps and atol arguments.
Objective. Practice the three primary tools for controlling the autograd graph.
Task. Build a 3-layer MLP and a tiny training step. Then: (a) wrap the validation forward pass in torch.no_grad() and confirm with torch.cuda.memory_allocated() that activation memory does not grow; (b) register a backward hook on the first linear layer that prints the gradient's L2 norm; (c) detach() the output of layer two before passing it to layer three and verify that the layer-one and layer-two parameters receive no gradient (their .grad stays None, or zero if gradients were cleared with set_to_none=False).
Expected outcome. In the normal case the hook fires once per backward call and prints a positive scalar. With the detach in place, no gradient reaches layer one, so the hook does not fire at all and the starved parameters are left without a gradient.
Further Reading
Autograd Documentation and Theory
autograd.Function subclasses, including the subtle rules about save_for_backward and double-backward support.