nn.Module is the base class for every model, layer, and sub-block in PyTorch. A module is a Python class that owns its parameters, owns its sub-modules, knows how to move itself to a device, knows how to serialize and deserialize itself, and provides a call interface (model(x)) that runs the forward pass. Understanding the conventions that nn.Module enforces, and the registration mechanisms it relies on, is what separates a working PyTorch developer from one who fights the framework.
The Canonical Pattern
A custom module subclasses nn.Module, calls super().__init__() as the first statement in its constructor, declares its layers as attributes, and implements a forward(self, x) method that defines the computation. Calling the instance invokes forward through Python's __call__ protocol. Along the way, the call also runs any registered forward pre-hooks and forward hooks, plus some internal bookkeeping, and a direct model.forward(x) call bypasses all of that. Always use model(x), never model.forward(x).
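The hook-bypass claim can be checked directly. The sketch below registers a forward hook on a throwaway nn.Linear (the hook and variable names are illustrative) and calls the module both ways. Only the call that goes through __call__ triggers the hook.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []

# A forward hook runs after forward() whenever the module is *called*.
# list.append returns None, so the hook leaves the output unchanged.
handle = layer.register_forward_hook(
    lambda module, inputs, output: calls.append("hook")
)

x = torch.randn(1, 4)
layer(x)          # __call__: runs forward() and then the hook
layer.forward(x)  # direct call: runs forward() only, hook skipped

print(calls)      # ['hook'], one entry, from the first call only
handle.remove()
```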
```python
import torch
import torch.nn as nn


class NeuralNetwork(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(num_inputs, 30),
            nn.ReLU(),
            nn.Linear(30, 20),
            nn.ReLU(),
            nn.Linear(20, num_outputs),
        )

    def forward(self, x):
        # Outputs of the last linear layer are called logits.
        return self.layers(x)


model = NeuralNetwork(50, 3)
print(model)

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params}")
```
print(model) produces a free, structured summary of the architecture, and the parameter-count idiom is a one-liner that belongs in every training script.
Parameters vs Buffers
nn.Module distinguishes between two kinds of tensor-valued state. Parameters are tensors that participate in training: they are returned by model.parameters(), are what gets passed to the optimizer, have requires_grad=True by default, and are included in state_dict() for checkpointing. Buffers are tensors that belong to the module but should not be trained: running mean and variance in batch norm, position embeddings in some transformer variants, lookup tables, learned-and-then-frozen statistics. Buffers do not appear in model.parameters(), but they do appear in state_dict() unless they were registered with persistent=False.
The registration mechanism for each is explicit. A tensor assigned to self becomes a parameter only if wrapped in nn.Parameter(). A tensor that should be a buffer is registered via self.register_buffer(name, tensor); thereafter self.name resolves to the buffer and it is automatically moved with the module on .to(device). Raw self.foo = tensor without either wrapper creates a regular Python attribute that is not moved with the module and is not included in the state dict; this is almost always a bug.
```python
import torch
import torch.nn as nn


class RunningMeanLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Learnable scale, registered as a parameter.
        self.scale = nn.Parameter(torch.ones(dim))
        # Frozen statistic, registered as a buffer.
        self.register_buffer("running_mean", torch.zeros(dim))

    def forward(self, x):
        return self.scale * (x - self.running_mean)


layer = RunningMeanLayer(4).to("cuda" if torch.cuda.is_available() else "cpu")
print("Parameters:", list(n for n, _ in layer.named_parameters()))
print("Buffers:   ", list(n for n, _ in layer.named_buffers()))
print("State dict keys:", list(layer.state_dict().keys()))
```
Both scale and running_mean appear in the state_dict, but only the parameter is updated by the optimizer.
Assigning a tensor directly, as in self.foo = torch.zeros(...), creates a regular Python attribute. The tensor will not move to the GPU when model.to("cuda") is called, will not be saved in state_dict(), and will not be restored when loading a checkpoint. The first sign of trouble is usually a cryptic "expected all tensors to be on the same device" error inside the forward pass. The fix is always one of three: wrap it in nn.Parameter (if it should be trained), register it as a buffer (if it should be saved but not trained), or compute it on the fly from device-aware inputs.
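The kinds of tensor state can be compared side by side. The hypothetical module below holds one parameter, one ordinary buffer, one buffer registered with persistent=False (moved with the module but kept out of the state dict), and one plain attribute. The sketch calls .double() instead of .to("cuda") so it runs on a CPU-only machine. Both conversions act only on registered parameters and buffers, so in either case the plain attribute is left behind.

```python
import torch
import torch.nn as nn


class FourKinds(nn.Module):
    """Hypothetical module holding one tensor of each kind."""

    def __init__(self, dim: int = 3):
        super().__init__()
        # Trained, saved in state_dict, converted/moved with the module.
        self.weight = nn.Parameter(torch.ones(dim))
        # Saved and moved, but not trained.
        self.register_buffer("stats", torch.zeros(dim))
        # Moved, but neither trained nor saved.
        self.register_buffer("cache", torch.zeros(dim), persistent=False)
        # Plain attribute: none of the above (the usual bug).
        self.plain = torch.zeros(dim)


m = FourKinds().double()  # converts registered tensors, like .to(device) moves them

for name in ["weight", "stats", "cache", "plain"]:
    print(f"{name:>6}: {getattr(m, name).dtype}")
print("state_dict keys:", list(m.state_dict().keys()))
print("parameters:     ", [n for n, _ in m.named_parameters()])
```

Only the plain attribute stays float32, and only weight and stats are checkpointed.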
state_dict and Checkpointing
A module's state_dict() is an ordered dictionary mapping qualified names (like "layers.0.weight") to tensors. It contains every parameter and every persistent buffer in the module and its descendants. Saving and loading a model is the two-line idiom torch.save(model.state_dict(), path) followed by model.load_state_dict(torch.load(path)). The architecture must be rebuilt before loading, because the state dict carries only values, not topology. This is the portable convention, and it is preferable to pickling the entire model object, which couples the saved file to the exact class definition and import path.
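A round trip can be sketched without touching the filesystem by saving into an in-memory io.BytesIO object instead of a path. The make_model factory is a hypothetical stand-in for whatever code rebuilds the architecture. weights_only=True asks torch.load to unpickle only tensors and plain containers, which is all a state dict needs.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical factory: the architecture is rebuilt in code,
    # because the state dict carries only named tensors.
    return nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.Linear(4, 2))


model = make_model()
buffer = io.BytesIO()                     # in-memory stand-in for a file path
torch.save(model.state_dict(), buffer)

buffer.seek(0)
restored = make_model()                   # fresh, differently initialized weights
restored.load_state_dict(torch.load(buffer, weights_only=True))

keys = list(model.state_dict().keys())
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(keys)   # includes BatchNorm buffers such as '1.running_mean'
print("identical after load:", same)
```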
load_state_dict() has a strict= argument that defaults to True. With strict loading, every key in the file must match a key in the model, and vice versa; missing or extra keys raise an error. Setting strict=False is useful when loading a pretrained backbone into a model with new task-specific heads: the heads are reported as missing, but loading proceeds and leaves them at their freshly initialized values. Keys that do match must still have matching shapes; a shape mismatch raises an error even with strict=False. The method returns a named tuple with missing_keys and unexpected_keys fields so the difference can be inspected.
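The partial-loading case can be sketched with two small ModuleDicts. One stands in for a pretrained model and the other for a new model with a different head. The names backbone, head, and task_head are illustrative. The last lines show that strict=False tolerates missing and unexpected keys but still raises a RuntimeError when a key's tensor shape differs.

```python
import torch
import torch.nn as nn

# Hypothetical "pretrained" model: a backbone plus a 10-way head.
pretrained = nn.ModuleDict({
    "backbone": nn.Linear(16, 16),
    "head": nn.Linear(16, 10),
})

# New model: same backbone, new task head under a different name.
model = nn.ModuleDict({
    "backbone": nn.Linear(16, 16),
    "task_head": nn.Linear(16, 3),
})

result = model.load_state_dict(pretrained.state_dict(), strict=False)
print("missing:   ", result.missing_keys)     # the new head, left as initialized
print("unexpected:", result.unexpected_keys)  # the old head, ignored
backbone_copied = torch.equal(model["backbone"].weight, pretrained["backbone"].weight)
print("backbone copied:", backbone_copied)

# strict=False does not relax shape checks: same key, different shape.
try:
    nn.Linear(16, 3).load_state_dict(nn.Linear(16, 10).state_dict(), strict=False)
    shape_error = False
except RuntimeError:
    shape_error = True
print("shape mismatch raised:", shape_error)
```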
Freezing Layers
To freeze a layer (exclude it from optimization), set requires_grad=False on each of its parameters and pass only the still-trainable parameters to the optimizer. The first step prevents gradient computation; the second prevents updates. Doing only one of them invites confusion or outright bugs. Parameters with requires_grad=False that remain in the optimizer never receive a gradient, so the built-in optimizers skip them and the optimizer's parameter list misrepresents what is actually training. Parameters with requires_grad=True that are left out of the optimizer accumulate gradients on every backward pass but never move.
```python
import torch.nn as nn
import torch.optim as optim

backbone = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768))
head = nn.Linear(768, 10)
model = nn.Sequential(backbone, head)

# Freeze backbone.
for p in backbone.parameters():
    p.requires_grad = False

# Optimizer only sees parameters that will actually move.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable, lr=1e-4)

print(f"Total params: {sum(p.numel() for p in model.parameters())}")
print(f"Trainable: {sum(p.numel() for p in trainable)}")
```
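To confirm that a freeze actually worked, a single optimization step on toy data should leave the frozen weights bit-identical while moving the head. The following minimal sketch uses smaller layers and random data, both chosen only for illustration, so it runs instantly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
backbone = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
head = nn.Linear(32, 4)
model = nn.Sequential(backbone, head)

for p in backbone.parameters():
    p.requires_grad = False
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-2)

# Snapshot a frozen weight and the head weight before one step.
frozen_before = backbone[0].weight.detach().clone()
head_before = head.weight.detach().clone()

x, y = torch.randn(8, 32), torch.randint(0, 4, (8,))
loss = F.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("backbone grad:", backbone[0].weight.grad)  # None: never computed
print("backbone unchanged:", torch.equal(backbone[0].weight, frozen_before))
print("head changed:", not torch.equal(head.weight, head_before))
```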
For parameter-efficient fine-tuning (LoRA, adapters, prefix tuning), there is usually no need to freeze layers by hand. Hugging Face's peft library wraps the base model, freezes its weights, and injects trainable adapter parameters in one call. In the resulting model, model.print_trainable_parameters() reports the split, and only the adapter weights require gradients.
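```python
from peft import LoraConfig, get_peft_model

# base_model: a pretrained transformers model loaded earlier (assumed here),
# e.g. a ~7B decoder whose attention projections are named q_proj and v_proj.
cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(base_model, cfg)
model.print_trainable_parameters()
# Illustrative, rounded output for a ~7B model (exact counts depend on the model):
# trainable params: 4,194,304 || all params: ~7,000,000,000 || trainable%: ~0.06
```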
Sequential, ModuleList, and ModuleDict
Three built-in containers organize sub-modules in ways that register them correctly. nn.Sequential is the simplest: it stores modules in order and its forward chains them. Use it for straight-line stacks like the MLP above. nn.ModuleList stores modules in a Python-list-like container without defining a forward; use it when the loop or branching logic in the forward pass is non-trivial. nn.ModuleDict is the analogous dictionary container; use it when modules are keyed by names that matter at runtime (for example, multiple task heads keyed by task name).
The reason these containers exist is registration. A plain Python list of modules (self.layers = [nn.Linear(...), nn.Linear(...)]) is not registered: the parameters do not appear in model.parameters(), the modules are not moved with .to(device), and the state dict misses them. ModuleList and ModuleDict are drop-in replacements that solve this.
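The registration difference is easy to measure. In this sketch, two hypothetical stacks of three linear layers differ only in whether they hold their blocks in a plain list or in an nn.ModuleList. The ModuleList version also writes its own loop in forward (here with a residual add), the kind of logic a plain Sequential chain does not express.

```python
import torch
import torch.nn as nn


class PlainListStack(nn.Module):
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.blocks = [nn.Linear(dim, dim) for _ in range(depth)]  # NOT registered


class ResidualStack(nn.Module):
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

    def forward(self, x):
        # ModuleList defines no forward; the loop (with a residual add) is ours.
        for block in self.blocks:
            x = x + torch.relu(block(x))
        return x


plain, good = PlainListStack(8, 3), ResidualStack(8, 3)
plain_count = sum(p.numel() for p in plain.parameters())
good_count = sum(p.numel() for p in good.parameters())
print("plain list params:", plain_count)   # 0: the list is invisible to nn.Module
print("ModuleList params:", good_count)    # 3 * (8*8 + 8) = 216
print(good(torch.randn(2, 8)).shape)       # torch.Size([2, 8])
```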
```python
import torch.nn as nn


class MultiTaskHead(nn.Module):
    def __init__(self, dim, task_dims):
        super().__init__()
        # ModuleDict because heads are addressed by task name at runtime.
        self.heads = nn.ModuleDict({
            name: nn.Linear(dim, out) for name, out in task_dims.items()
        })

    def forward(self, features, task_name):
        return self.heads[task_name](features)


model = MultiTaskHead(128, {"sentiment": 3, "topic": 10, "stance": 4})
for name, p in model.named_parameters():
    print(name, tuple(p.shape))
```
Using ModuleDict for runtime-addressable sub-modules: all three task heads are correctly registered and would move with a single model.to(device) call.
Custom Weight Initialization
Layer constructors apply sensible default initialization (for nn.Linear, Kaiming-uniform weights and small uniform biases scaled by fan-in; for nn.LayerNorm, ones and zeros), but research often calls for something different: Xavier normal for tanh networks, scaled normal for residual blocks, fan-in scaling for attention. The pattern is to write a one-argument function that takes a module, inspects its type, and reinitializes its parameters, then call model.apply(init_fn). The apply method walks the entire module tree and invokes the function on every sub-module, leaves included, and finally on the root module itself.
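Two details of this mechanism can be checked directly. First, apply visits children before their parent and finishes with the root. Second, the default nn.Linear bias is not zero: it is drawn uniformly from roughly [-1/sqrt(fan_in), 1/sqrt(fan_in)]. A minimal sketch, with a small tolerance on the bound check to allow for float32 rounding:

```python
import torch.nn as nn

visited = []
model = nn.Sequential(nn.Linear(10, 20), nn.LayerNorm(20), nn.Linear(20, 5))
model.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Linear', 'LayerNorm', 'Linear', 'Sequential']

# Default nn.Linear bias: small uniform values scaled by fan-in, not zeros.
lin = nn.Linear(10, 20)
bound = 1 / 10 ** 0.5
all_zero = bool((lin.bias == 0).all())
within = bool(lin.bias.abs().max() <= bound + 1e-6)
print("all zero:", all_zero, "| within bound:", within)
```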
```python
import torch.nn as nn


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        # weight/bias are None when LayerNorm is built without affine params.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


model = nn.Sequential(nn.Linear(10, 20), nn.LayerNorm(20), nn.Linear(20, 5))
model.apply(init_weights)
```
Custom initialization through model.apply. The function is responsible for branching on module type; modules it does not recognize are left alone.
Weight Tying
Weight tying makes two modules share the same parameter tensor. The classic case is a language model where the input token embedding and the output projection are tied: both are matrices of shape $(V, D)$, and tying halves the combined parameter count of these two layers (typically the largest in the model) while often improving perplexity. The implementation is a single assignment: after both modules are constructed, do output_projection.weight = embedding.weight. The assignment registers the very same nn.Parameter object in both modules, so they share one underlying storage and both see every update the optimizer makes.
Tied weights are the same tensor. There is no synchronization step, no need to copy, and no risk of drift. The optimizer sees the parameter once, because model.parameters() deduplicates by object identity. The state dict, however, stores the tensor under both names; loading then writes the same data through both keys, which is harmless. This trick generalizes beyond embeddings: any time two parts of a model should be mathematically identical, tie them rather than copying.
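A toy language model makes these claims concrete. TinyLM below is hypothetical, with a vocabulary of 100 and a width of 16. It ties the output projection to the embedding with one assignment, then checks object identity, the deduplicated parameter count, and the state-dict keys.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical toy language model: embedding -> tied output projection."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)             # weight: (V, D)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)  # weight: (V, D)
        self.lm_head.weight = self.embed.weight                # tie: same Parameter object

    def forward(self, tokens):
        return self.lm_head(self.embed(tokens))  # logits: (..., V)


lm = TinyLM(vocab_size=100, dim=16)
print(lm.lm_head.weight is lm.embed.weight)     # True
print(sum(p.numel() for p in lm.parameters()))  # 1600: counted once, not 3200
print(list(lm.state_dict().keys()))             # ['embed.weight', 'lm_head.weight']
```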
An nn.Module is a registered container. Tensors that should be trained go in nn.Parameter; tensors that should travel with the module but not be trained go in buffers; sub-modules go in Sequential, ModuleList, or ModuleDict. The single decision that most often goes wrong is using a plain Python attribute (or list) where one of these registration mechanisms is needed; the resulting bugs are silent until something tries to .to(device) or save a checkpoint. The discipline that pays off is to ask, for every self.foo = ... in a constructor, which registration category it belongs to.
Objective. Recognize which of the three registration categories each tensor in a module belongs to.
Task. Build a BatchNorm-Lite module with the following state: a learnable scale gamma, a learnable shift beta, a running mean and running variance updated each forward pass with momentum 0.1, and a constant eps = 1e-5. Classify each as nn.Parameter, register_buffer, or plain Python attribute. Implement it, then call .to("cuda") and check each tensor's .device to confirm that the right tensors moved.
Hint. Running statistics must travel with the module across .to() and across state_dict save/load, but they are not updated by the optimizer.
Objective. Practice partial training, the foundation of any fine-tuning workflow.
Task. Load a pretrained torchvision.models.resnet18(weights="DEFAULT"). Replace model.fc with a new nn.Linear(512, 10). Freeze every parameter outside the new head, then verify three things: (1) sum(p.numel() for p in model.parameters() if p.requires_grad) equals the head parameter count; (2) the optimizer receives only the head parameters; (3) a single training step changes the head weights but leaves a sampled encoder weight bit-identical.
Stretch. Switch every BatchNorm2d in the encoder to eval mode and explain in one comment why this matters for fine-tuning on a small dataset.