Why have one head when you can have eight, each looking at the sentence from a slightly different existential angle?
Attn, Multi-Headed AI Agent
From seq2seq attention to the Transformer's attention. In Section 2.2, we used attention to let a decoder peek at encoder states. The Transformer (Vaswani et al., 2017) takes this much further. It introduces the query/key/value (Q/K/V) abstraction, scales the dot products by $\sqrt{d_{k}}$ for numerical stability, runs multiple attention "heads" in parallel, and applies attention not just between encoder and decoder but also within a single sequence (self-attention). These building blocks are the heart of GPT, BERT, and every modern LLM. By the end of this section, you will have implemented multi-head self-attention from scratch and understood every piece of the mechanism that makes Transformers work.
Prerequisites
This section continues from Section 2.3. You should be comfortable with the RNN and LSTM framing from Section 2.1, the seq2seq architecture from Section 2.2, and the PyTorch tensor primitives from Section 0.3.
This continuation of Section 2.3 picks up after the single-head attention math is solid. It scales attention up to the multi-head version that real Transformers use, walks through a from-scratch PyTorch implementation, examines the quadratic complexity that drives every efficient-attention variant in later chapters, and closes with a complete worked example you can run end to end.
See Figure 2.4.1 in Section 2.3 for the multi-head detectives illustration.
2.4.1 Multi-Head Attention
Multi-head attention implements a form of "ensemble of subspaces" that has deep parallels in signal processing and neuroscience. In signal processing, a filter bank decomposes a complex signal into frequency bands, each capturing a different scale of structure. Multi-head attention does something analogous: each head projects the input into a different learned subspace and computes attention within that subspace, capturing a different relational pattern (syntactic, semantic, positional). Neuroscience research on the visual cortex reveals a strikingly similar architecture: populations of neurons in area V1 are organized into orientation columns, with different groups selectively responding to different edge orientations in the visual field. The brain does not attempt to represent all visual features with a single neural population; it allocates separate computational channels to separate feature types and integrates them downstream. Multi-head attention embodies an analogous principle: parallel, specialized channels whose outputs are combined downstream.
A single attention head can only capture one type of relationship at a time. If a word needs to attend to its syntactic head, its semantic role, and a coreferent pronoun simultaneously, a single attention distribution cannot represent all three patterns.
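Multi-head attention solves this by running multiple attention operations in parallel, each with its own learned projections:

$$\text{head}_{i} = \text{Attention}(Q W_{i}^{Q}, K W_{i}^{K}, V W_{i}^{V}), \quad i = 1, \dots, h$$

where $W_{i}^{Q}, W_{i}^{K}, W_{i}^{V} \in \mathbb{R}^{d_{model} \times d_{k}}$. The individual head outputs are concatenated and projected back to the model dimension:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_{1}, \dots, \text{head}_{h}) W^{O}, \quad W^{O} \in \mathbb{R}^{h d_{k} \times d_{model}}$$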
Each head operates in a lower-dimensional subspace. If the model dimension is $d_{model}$ and there are $h$ heads, each head works with dimension $d_{k} = d_{model} / h$. Variants like Grouped Query Attention (GQA) reduce the number of key/value heads to save memory; we cover these optimizations in Section 3.5 and their inference impact in Section 9.3. The outputs of all heads are concatenated and projected back to the full model dimension through $W^{O}$.
Objective
Let us now build a complete, production-style multi-head self-attention module. This is the exact computation at the heart of every Transformer layer.
Steps
```python
# Multi-head self-attention: project input into h separate (Q, K, V) triples,
# run scaled dot-product attention on each head, concatenate, and project back.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with optional causal masking."""

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head
        # Combined Q, K, V projection (more efficient than three separate ones)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, causal=False):
        """
        x: (batch, seq_len, d_model)
        causal: if True, apply a causal mask
        Returns: output (batch, seq_len, d_model) and
                 attention weights (batch, n_heads, seq_len, seq_len)
        """
        B, T, C = x.shape

        # Step 1: Project to Q, K, V
        qkv = self.qkv_proj(x)          # (B, T, 3*C)
        Q, K, V = qkv.chunk(3, dim=-1)  # each: (B, T, C)

        # Step 2: Reshape for multi-head: (B, T, C) -> (B, h, T, d_k)
        Q = Q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Step 3: Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (B, h, T, T)
        scores = scores / math.sqrt(self.d_k)

        # Step 4: Causal mask (optional): block every key after the query position
        if causal:
            mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        # Step 5: Softmax + dropout
        weights = F.softmax(scores, dim=-1)  # (B, h, T, T)
        weights = self.dropout(weights)

        # Step 6: Weighted sum of values
        out = torch.matmul(weights, V)  # (B, h, T, d_k)

        # Step 7: Concatenate heads: (B, h, T, d_k) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # Step 8: Output projection
        out = self.out_proj(out)  # (B, T, C)
        return out, weights


# Create and test the module
mha = MultiHeadSelfAttention(d_model=128, n_heads=4)
x = torch.randn(2, 10, 128)  # batch=2, seq_len=10, d_model=128

# Bidirectional (BERT-style)
out_bi, wts_bi = mha(x, causal=False)
print(f"Bidirectional output: {out_bi.shape}")
print(f"Weights shape: {wts_bi.shape}")

# Causal (GPT-style)
out_ca, wts_ca = mha(x, causal=True)
print(f"Causal output: {out_ca.shape}")

# Verify causal mask works: position 0 should have zero weight on all future positions
print(f"\nCausal weights for head 0, position 0:")
print(f"  {wts_ca[0, 0, 0].detach().numpy().round(4)}")
print(f"  (Only first entry is non-zero: position 0 attends only to itself)")

# Count parameters
params = sum(p.numel() for p in mha.parameters())
print(f"\nTotal parameters: {params:,}")
print(f"  QKV projection: {128 * 3 * 128 + 3 * 128:,}")
print(f"  Output projection: {128 * 128 + 128:,}")
```
In our implementation, we use a single linear layer (qkv_proj) to compute Q, K, and V simultaneously, then split the output into three parts. This is mathematically equivalent to using three separate linear layers but is usually faster, because one large matrix multiplication makes better use of the hardware than three smaller ones. Many production implementations use this fused approach, including PyTorch's nn.MultiheadAttention (whose in_proj_weight packs all three projections) and GPT-2's c_attn layer; others, such as the Llama-family models in Hugging Face Transformers, keep separate q_proj, k_proj, and v_proj layers.
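To see why the fused projection is equivalent, here is a minimal NumPy sketch (the names `W_qkv` and `b_qkv` are hypothetical, not taken from any library). A single (d_model, 3·d_model) weight matrix is just the three separate Q, K, V matrices placed side by side, so one matrix multiplication followed by a three-way split reproduces three separate multiplications exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, seq_len = 8, 5
x = rng.standard_normal((seq_len, d_model))

# Three separate projections (weights and biases), as in the Q/K/V formulation
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))
b_q, b_k, b_v = (rng.standard_normal(d_model) for _ in range(3))

# Fused projection: stack the three weight matrices column-wise
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)  # (d_model, 3*d_model)
b_qkv = np.concatenate([b_q, b_k, b_v])          # (3*d_model,)

# One matmul, then split the last axis into three equal chunks
Q_f, K_f, V_f = np.split(x @ W_qkv + b_qkv, 3, axis=-1)

# Three separate matmuls for comparison
Q_s, K_s, V_s = x @ W_q + b_q, x @ W_k + b_k, x @ W_v + b_v

fused_matches = all(np.allclose(a, b) for a, b in [(Q_f, Q_s), (K_f, K_s), (V_f, V_s)])
print("Fused projection matches separate projections:", fused_matches)
```

In PyTorch, nn.Linear stores its weight as (out_features, in_features) and computes x @ W.T + b, so in qkv_proj the three blocks are stacked along the first axis of qkv_proj.weight, which has shape (3 * d_model, d_model).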
The self.dropout(weights) line in the forward pass does something more specific than the generic "regularize the network" dropout applied to embeddings. It zeroes entries of the post-softmax attention matrix and rescales the surviving entries by 1/(1 - p), so each row still sums to one in expectation. The intended effect is to keep the model from over-relying on a small number of frequent key positions: if "the" usually wins the softmax, attention dropout periodically forces the model to attend elsewhere and to learn a more distributed routing. The original "Attention Is All You Need" paper lists residual dropout and label smoothing as its regularizers; attention-weight dropout is a separate mechanism, and common implementations expose it as its own setting, such as the dropout argument of nn.MultiheadAttention and the dropout_p argument of F.scaled_dot_product_attention.
The 50-line from-scratch implementation above is valuable for understanding the internals. In production, PyTorch provides a built-in module that handles the Q/K/V projections, head splitting, masking, and output projection in a single optimized call:
```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(32, 128, 512)       # (batch, seq, dim)
out, attn_weights = attn(x, x, x)   # self-attention: Q=K=V=x
print(out.shape)                    # torch.Size([32, 128, 512])
```
PyTorch nn.MultiheadAttention. When called with need_weights=False, recent PyTorch versions route this module's computation through F.scaled_dot_product_attention, which can dispatch to FlashAttention kernels on compatible GPUs.
2.4.2 Complexity Analysis: The O(n²) Problem
The quadratic cost of self-attention is a major reason your favorite chatbot has a context window limit. Doubling the sequence length quadruples the memory needed to store the attention matrix, which is why "just make the context window bigger" is roughly as simple as "just make the ocean smaller."
Self-attention has a fundamental computational cost: the score matrix $QK^{T}$ has shape $(n, n)$ where $n$ is the sequence length. This means both the computation and memory required grow quadratically with sequence length.
| Operation | Time Complexity | Space Complexity |
|---|---|---|
| Q, K, V projections | O(n · d²) | O(n · d) |
| $QK^{T}$ computation | O(n² · d) | O(n²) |
| Softmax | O(n²) | O(n²) |
| Attention × V | O(n² · d) | O(n · d) |
| Total | O(n² · d) | O(n² + n · d) |
For typical LLM settings, $n$ can be 2048, 8192, or even 128,000 tokens. Materializing the attention matrix for a 128K-token sequence in float32 would take 128,000 × 128,000 × 4 bytes ≈ 65.5 GB (about 61 GiB) per head. This quadratic scaling is the primary bottleneck that limits context lengths in standard Transformer implementations.
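The arithmetic is easy to check in a few lines of plain Python. This is a minimal sketch; the helper name `attn_matrix_bytes` is hypothetical, and it counts only the score matrix, not Q, K, V, activations, or framework overhead.

```python
def attn_matrix_bytes(seq_len, n_heads=1, bytes_per_elem=4, batch=1):
    """Bytes needed to materialize the (seq_len x seq_len) score matrix for every head."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


n = 128_000
b = attn_matrix_bytes(n)  # one head, float32
print(f"{b:,} bytes = {b / 1e9:.1f} GB = {b / 2**30:.1f} GiB")
# 65,536,000,000 bytes = 65.5 GB = 61.0 GiB

# Doubling the sequence length quadruples the memory
print(attn_matrix_bytes(2 * n) // b)  # 4
```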
```python
# Benchmark attention wall-clock time as sequence length doubles.
# Quadratic O(n^2) scaling becomes visible at longer sequences.
# Assumes MultiHeadSelfAttention from the earlier code fragment is defined.
import time

import torch

# Measure how attention scales with sequence length
d_model, n_heads = 128, 4
mha = MultiHeadSelfAttention(d_model, n_heads)
mha.eval()

print(f"{'seq_len':>10} {'time (ms)':>10} {'mem (MB)':>10} {'ratio':>8}")
prev_time = None
for seq_len in [64, 128, 256, 512, 1024, 2048]:
    x = torch.randn(1, seq_len, d_model)

    # Warm up
    with torch.no_grad():
        _ = mha(x)

    # Time 20 forward passes and report the mean in milliseconds
    t0 = time.perf_counter()
    with torch.no_grad():
        for _ in range(20):
            _ = mha(x)
    elapsed = (time.perf_counter() - t0) / 20 * 1000

    # Memory for the attention-weight tensor alone (float32, all heads)
    attn_mem = seq_len * seq_len * n_heads * 4 / 1e6

    ratio = f"{elapsed / prev_time:.1f}x" if prev_time else ""
    print(f"{seq_len:>10} {elapsed:>10.2f} {attn_mem:>10.2f} {ratio:>8}")
    prev_time = elapsed
```
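Each doubling of sequence length increases computation time by a growing factor. At short lengths, the linear-in-n projection cost (O(n · d²)) and fixed per-call overheads dominate, so the ratio starts well below 4x; as n grows, the O(n² · d) score and weighted-sum terms take over and the ratio approaches the theoretical 4x. The memory for attention matrices grows quadratically at every length. This is why modern LLM research invests heavily in techniques like FlashAttention, sparse attention, and linear attention approximations to tame this O(n²) cost.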
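A back-of-the-envelope FLOP count shows why the measured ratio climbs toward 4x instead of starting there. This is a rough model, assuming 2 FLOPs per multiply-add and ignoring softmax, masking, biases, and memory traffic (the helper name `attention_flops` is hypothetical): the four projections cost about 8·n·d² FLOPs, while $QK^{T}$ and the weighted sum of values cost about 4·n²·d, and only the second term quadruples when n doubles.

```python
def attention_flops(n, d):
    """Rough FLOP count for one multi-head self-attention layer (batch of 1).

    Counts 2 FLOPs per multiply-add; ignores softmax, masking, and biases.
    """
    projections = 4 * (2 * n * d * d)     # Q, K, V, and output projections
    scores_and_mix = 2 * (2 * n * n * d)  # Q @ K^T and weights @ V, summed over heads
    return projections + scores_and_mix


d = 128
ratios = []
for n in [64, 128, 256, 512, 1024, 2048]:
    r = attention_flops(2 * n, d) / attention_flops(n, d)
    ratios.append(r)
    print(f"n={n:>5} -> {2 * n:>5}: {r:.2f}x more FLOPs")
```

Wall-clock ratios also depend on memory bandwidth, BLAS efficiency, and per-call overhead, so a real benchmark will not match these numbers exactly; the trend toward 4x is the point.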
```python
# Simulate Q, K, V tensors for a single attention block
# Shapes follow the canonical convention (batch, heads, seq_len, head_dim)
import torch

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 1, 8, 128, 64
Q = torch.randn(batch, heads, seq_len, head_dim)
K = torch.randn(batch, heads, seq_len, head_dim)
V = torch.randn(batch, heads, seq_len, head_dim)

# Scaled dot-product attention scores
scores = (Q @ K.transpose(-2, -1)) / (head_dim ** 0.5)  # (1, 8, 128, 128)
attn = torch.softmax(scores, dim=-1)
output = attn @ V                                        # (1, 8, 128, 64)

print(f"Q,K,V shape : {Q.shape}")
print(f"scores shape: {scores.shape}")
print(f"output shape: {output.shape}")
# scores grow as O(seq_len^2); this is the quadratic-attention bottleneck.
```
PyTorch 2.0+ provides F.scaled_dot_product_attention, a fused attention function that automatically selects a FlashAttention kernel when the hardware and inputs support it.
```python
import torch
import torch.nn.functional as F

# Simulated Q, K, V: batch=1, heads=8, seq_len=128, head_dim=64
Q = torch.randn(1, 8, 128, 64)
K = torch.randn(1, 8, 128, 64)
V = torch.randn(1, 8, 128, 64)

# PyTorch's fused kernel selects the FlashAttention, memory-efficient,
# or math backend automatically based on hardware, dtype, and inputs
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print("Output shape:", out.shape)  # torch.Size([1, 8, 128, 64])

# The FlashAttention backend is built into PyTorch itself; the separate
# flash-attn package is not required for it. The GPU kernel is only
# eligible on supported GPUs with fp16/bf16 inputs.
```
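A useful sanity check is that the fused call computes exactly the formula from this section. The sketch below, with arbitrary small shapes chosen for illustration, builds the causal attention explicitly (scores, boolean mask, softmax, weighted sum) and compares it with F.scaled_dot_product_attention(..., is_causal=True); the two should agree up to floating-point rounding.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 16, 32
Q, K, V = (torch.randn(B, H, T, D) for _ in range(3))

# Reference: explicit scores, boolean causal mask (True = blocked), softmax, weighted sum
scores = Q @ K.transpose(-2, -1) / math.sqrt(D)
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(future, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ V

# Fused kernel with the built-in causal flag
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

max_diff = (ref - fused).abs().max().item()
print(f"max |manual - fused| = {max_diff:.2e}")
```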
Despite the quadratic cost, self-attention has massive advantages over RNNs: (1) all positions are processed in parallel (no sequential bottleneck), (2) any two positions are connected by a single attention operation (constant path length, vs. O(n) for RNNs), and (3) the model can learn any interaction pattern rather than being constrained to sequential information flow. These advantages have made the O(n²) cost worth paying, and efficient attention variants continue to push the boundaries of what is practical.
The O(n²) cost of self-attention is not a design flaw; it is the fundamental price of allowing every token to interact with every other token. This is structurally identical to the N-body problem in physics, where computing gravitational or electromagnetic interactions between N particles requires O(N²) pairwise force calculations. Just as physicists developed approximation methods that group distant particles (Barnes-Hut tree codes at O(N log N), fast multipole methods at O(N)), researchers have developed sparse attention, linear attention, and locality-sensitive hashing to approximate the full attention matrix. The parallel extends further: in both domains, nearby interactions often matter more than distant ones, which is exactly the intuition behind local attention windows. Whether O(n) attention can truly match O(n²) quality is closely tied to whether long-range token interactions carry essential information, a question whose answer varies by task, just as gravitational and electromagnetic forces have different effective ranges.
In a sentiment classifier built on multi-head self-attention, different heads learn to capture different linguistic features. For the sentence "The food was great but the service was terrible," one head might learn to attend from "great" and "terrible" to the nearby nouns ("food" and "service"), capturing what each adjective modifies. Another head might attend from "terrible" backward to "but," learning that the conjunction signals a contrast. A third head might attend broadly across the entire sentence, computing a summary representation. This division of labor can emerge automatically during training, and it is a common explanation for why multi-head attention tends to outperform a single attention head of the same total width: the model can represent multiple relationship types simultaneously rather than averaging them into a single attention pattern.
2.4.3 Putting It All Together: Complete Example
Let us combine everything into a demonstration that runs multi-head self-attention on token embeddings and prints what each head attends to:
```python
# End-to-end demo: embed a 5-word sentence, run multi-head attention,
# and inspect the output shape to verify the residual-ready dimensions.
# Assumes MultiHeadSelfAttention from the earlier code fragment is defined.
# The embedding and attention weights are random and untrained, so the
# printed attention patterns are arbitrary.
import torch
import torch.nn as nn

# Simulate a small vocabulary and sentence
vocab = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}
sentence = ["the", "cat", "sat", "on", "mat"]
token_ids = torch.tensor([[vocab[w] for w in sentence]])

# Embedding + self-attention
d_model, n_heads = 64, 4
embedding = nn.Embedding(len(vocab), d_model)
attn = MultiHeadSelfAttention(d_model, n_heads)

# Forward pass
x = embedding(token_ids)                # (1, 5, 64)
output, weights = attn(x, causal=True)  # causal for GPT-style
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape} (batch, heads, queries, keys)")

# Show what each head attends to for the word "mat" (last position)
print(f"\nAttention to generate representation of 'mat' (position 4):")
for h in range(n_heads):
    w = weights[0, h, 4].detach().numpy()
    top_pos = w.argmax()
    print(f"  Head {h}: {' '.join(f'{sentence[i]}:{w[i]:.2f}' for i in range(5))}"
          f"  (peak: '{sentence[top_pos]}')")

# Compare output with input (attention has mixed information across positions)
cos_sim = nn.functional.cosine_similarity(x[0], output[0], dim=-1)
print(f"\nCosine similarity (input vs output) per position:")
for i, word in enumerate(sentence):
    print(f"  '{word}': {cos_sim[i]:.4f}")
```
Because the embedding table and attention weights in this demo are randomly initialized and untrained, the printed attention patterns are arbitrary and change from run to run. In a trained model, heads often specialize: one head might connect "mat" primarily to "cat" (the subject), another to "sat" (the verb), and a third to "on" (the preposition). That kind of diversity is why multiple heads are valuable: they let the model capture syntactic, semantic, and positional relationships simultaneously.
Expect low cosine similarity between input and output, but do not read too much into it here: with untrained weights it mostly reflects the random output projection rather than useful contextual mixing. In a full Transformer block, a residual connection adds the input back to the attention output, and training is what makes the added context meaningful.
When debugging sequence models, use a batch of 2 to 4 short sequences (under 20 tokens) where you can manually verify the expected output. Scale up only after the model produces correct results on these micro-examples.
Efficient attention variants remain one of the most active research areas. Grouped-query attention (GQA), used in Llama 2 (70B), Llama 3, and Mistral, reduces the KV cache by sharing key-value heads across query heads (see Section 3.3 for implementation). Multi-query attention (MQA) takes this further with a single shared KV head. FlashAttention (Dao et al., 2022, 2023) rewrites the attention computation to be IO-aware, computing exact attention with memory that grows linearly rather than quadratically in sequence length, although the compute remains O(n²). Ring attention (Liu et al., 2023) distributes attention across devices to handle million-token contexts. Differential Transformer (Ye et al., 2024) computes attention as the difference of two softmax attention maps to cancel attention noise. These developments have largely removed quadratic memory as a practical barrier, but the quadratic compute cost of exact attention remains.
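A minimal NumPy sketch of the GQA idea, assuming 8 query heads that share 2 key/value heads (so each KV head serves a group of 4 query heads); the function name `grouped_query_attention` is hypothetical. Only the 2 KV heads need to be stored in the KV cache, a 4x reduction. With as many KV heads as query heads this reduces to standard multi-head attention, and with a single KV head it becomes MQA.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(Q, K, V):
    """Q: (n_q_heads, T, d); K, V: (n_kv_heads, T, d), n_q_heads divisible by n_kv_heads."""
    n_q, n_kv = Q.shape[0], K.shape[0]
    assert n_q % n_kv == 0
    group = n_q // n_kv
    # Repeat each KV head for the `group` consecutive query heads that share it
    K = np.repeat(K, group, axis=0)  # (n_q_heads, T, d)
    V = np.repeat(V, group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V       # (n_q_heads, T, d)


rng = np.random.default_rng(0)
T, d = 6, 16
Q = rng.standard_normal((8, T, d))
K = rng.standard_normal((2, T, d))  # only 2 KV heads are stored (and cached)
V = rng.standard_normal((2, T, d))
out = grouped_query_attention(Q, K, V)
print(out.shape)  # (8, 6, 16)
```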
Objective
Implement scaled dot-product attention and multi-head attention from scratch using only NumPy, then verify your implementation against PyTorch's built-in attention. This lab turns the mathematical formulas from this section into working code.
Skills Practiced
- Implementing the Q/K/V attention formula: softmax(QK^T / sqrt(d_k)) V
- Building multi-head attention with parallel heads and concatenation
- Applying causal masking for autoregressive generation
- Visualizing attention weight heatmaps
Setup
Install the required packages for this lab.
```bash
pip install numpy matplotlib torch
```
Steps
Step 1: Implement scaled dot-product attention
Write the core attention function from the "Attention Is All You Need" paper. The scaling factor prevents the dot products from growing too large, which would push softmax into regions with vanishing gradients.
```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: arrays of shape (seq_len, d_k)
    mask: optional boolean array; True means "mask this position"
    Returns: attention output and attention weights
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # (seq_len, seq_len)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    weights = softmax(scores, axis=-1)
    output = weights @ V
    return output, weights


# Test with a small example: 4 tokens, dimension 8
np.random.seed(42)
seq_len, d_k = 4, 8
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_k)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Attention output shape: {output.shape}")
print(f"Attention weights (each row sums to 1):")
print(np.round(weights, 3))
```
Step 2: Visualize attention weights as a heatmap
Plot the attention matrix to see which tokens attend to which. Each row shows the attention distribution for one query token across all key tokens.
```python
# Reuses Q, K, V and scaled_dot_product_attention from Step 1.
import numpy as np
import matplotlib.pyplot as plt

tokens = ["The", "cat", "sat", "down"]
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Bidirectional attention (no mask)
_, weights_bidi = scaled_dot_product_attention(Q, K, V)
im1 = axes[0].imshow(weights_bidi, cmap="Blues", vmin=0, vmax=1)
axes[0].set_xticks(range(4))
axes[0].set_xticklabels(tokens)
axes[0].set_yticks(range(4))
axes[0].set_yticklabels(tokens)
axes[0].set_title("Bidirectional Attention")
axes[0].set_xlabel("Key")
axes[0].set_ylabel("Query")
for i in range(4):
    for j in range(4):
        axes[0].text(j, i, f"{weights_bidi[i,j]:.2f}",
                     ha="center", va="center", fontsize=9)

# Causal attention (masked): True above the diagonal blocks future keys
causal_mask = np.triu(np.ones((4, 4), dtype=bool), k=1)
_, weights_causal = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
im2 = axes[1].imshow(weights_causal, cmap="Oranges", vmin=0, vmax=1)
axes[1].set_xticks(range(4))
axes[1].set_xticklabels(tokens)
axes[1].set_yticks(range(4))
axes[1].set_yticklabels(tokens)
axes[1].set_title("Causal Attention (Masked)")
axes[1].set_xlabel("Key")
for i in range(4):
    for j in range(4):
        axes[1].text(j, i, f"{weights_causal[i,j]:.2f}",
                     ha="center", va="center", fontsize=9)

plt.suptitle("Attention Weight Heatmaps", fontsize=13, fontweight="bold")
plt.tight_layout()
plt.savefig("attention_heatmaps.png", dpi=150)
plt.show()
```
Step 3: Implement multi-head attention
Split the input into multiple heads, run attention on each head independently, then concatenate and project. This lets the model attend to different types of relationships simultaneously.
```python
import numpy as np


def multi_head_attention(X, n_heads, d_model, W_q, W_k, W_v, W_o, mask=None):
    """
    X: input of shape (seq_len, d_model)
    W_q, W_k, W_v: projection matrices (d_model, d_model)
    W_o: output projection (d_model, d_model)
    """
    seq_len = X.shape[0]
    d_head = d_model // n_heads

    Q = X @ W_q  # (seq_len, d_model)
    K = X @ W_k
    V = X @ W_v

    # Reshape into heads: (n_heads, seq_len, d_head)
    Q = Q.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    K = K.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    V = V.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    # Run attention on each head
    all_outputs = []
    all_weights = []
    for h in range(n_heads):
        out_h, w_h = scaled_dot_product_attention(Q[h], K[h], V[h], mask)
        all_outputs.append(out_h)
        all_weights.append(w_h)

    # Concatenate heads and project
    concat = np.concatenate(all_outputs, axis=-1)  # (seq_len, d_model)
    output = concat @ W_o
    return output, all_weights


# Initialize random projections
d_model, n_heads = 16, 4
np.random.seed(0)
X = np.random.randn(4, d_model)
W_q = np.random.randn(d_model, d_model) * 0.1
W_k = np.random.randn(d_model, d_model) * 0.1
W_v = np.random.randn(d_model, d_model) * 0.1
W_o = np.random.randn(d_model, d_model) * 0.1

output, head_weights = multi_head_attention(X, n_heads, d_model, W_q, W_k, W_v, W_o)
print(f"Multi-head output shape: {output.shape}")
print(f"Number of attention heads: {len(head_weights)}")
```
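The per-head Python loop above mirrors the math directly. As a sketch of how the same computation is usually batched, the version below (hypothetical names `mha_loop` and `mha_batched`) runs all heads in one batched matrix multiplication over a leading head axis. To stay self-contained, it redefines a compact loop-based reference and checks that both agree.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def split_heads(M, n_heads):
    """(seq_len, d_model) -> (n_heads, seq_len, d_head)"""
    seq_len, d_model = M.shape
    return M.reshape(seq_len, n_heads, d_model // n_heads).transpose(1, 0, 2)


def mha_loop(X, n_heads, W_q, W_k, W_v, W_o):
    Q, K, V = (split_heads(X @ W, n_heads) for W in (W_q, W_k, W_v))
    outs = []
    for h in range(n_heads):
        w = softmax(Q[h] @ K[h].T / np.sqrt(Q.shape[-1]))
        outs.append(w @ V[h])
    return np.concatenate(outs, axis=-1) @ W_o


def mha_batched(X, n_heads, W_q, W_k, W_v, W_o):
    Q, K, V = (split_heads(X @ W, n_heads) for W in (W_q, W_k, W_v))
    # One batched matmul over the head axis: (h, T, d) @ (h, d, T) -> (h, T, T)
    weights = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1]))
    heads = weights @ V  # (h, T, d_head)
    # Merge heads: (h, T, d_head) -> (T, h, d_head) -> (T, d_model)
    concat = heads.transpose(1, 0, 2).reshape(X.shape[0], -1)
    return concat @ W_o


rng = np.random.default_rng(0)
d_model, n_heads, seq_len = 16, 4, 5
X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) * 0.1 for _ in range(4))

loop_out = mha_loop(X, n_heads, W_q, W_k, W_v, W_o)
batched_out = mha_batched(X, n_heads, W_q, W_k, W_v, W_o)
print("max abs difference:", np.abs(loop_out - batched_out).max())
```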
Step 4: Verify against PyTorch
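Use PyTorch's nn.MultiheadAttention to confirm that the library follows the same computational pattern: the same output shape and the same causal structure in the attention weights. Because the module has its own random weights, this step does not compare numbers directly. This is the "right tool" comparison: understand the internals, then use the library.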
```python
import torch
import torch.nn as nn

d_model, n_heads, seq_len = 16, 4, 4
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
x = torch.randn(1, seq_len, d_model)  # (batch, seq, d_model)

# Causal mask for autoregressive attention
causal = nn.Transformer.generate_square_subsequent_mask(seq_len)
output, attn_weights = mha(x, x, x, attn_mask=causal)

print(f"PyTorch MHA output shape: {output.shape}")
print(f"PyTorch attention weights shape: {attn_weights.shape}")
print(f"Weights (averaged over heads):")
print(attn_weights[0].detach().numpy().round(3))
print("\nBoth implementations follow the same formula:")
print("  Attention(Q,K,V) = softmax(QK^T / sqrt(d_k)) V")
```
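As an optional extra check that compares numbers rather than shapes, the sketch below copies the weights of an nn.MultiheadAttention built with bias=False into a compact NumPy re-implementation of Step 3 (the names `numpy_mha` and `softmax` are defined here, not library functions). It relies on two PyTorch conventions: in_proj_weight stacks the Q, K, V weights row-wise as a (3·d_model, d_model) matrix, and the projections compute x @ W.T, hence the transposes. The module is left in training mode with its default dropout of 0, so the forward pass is deterministic.

```python
import numpy as np
import torch
import torch.nn as nn


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def numpy_mha(X, n_heads, W_q, W_k, W_v, W_o, mask=None):
    """Compact version of the Step 3 function; mask=True means 'blocked'."""
    T, d_model = X.shape
    d_head = d_model // n_heads
    Q, K, V = ((X @ W).reshape(T, n_heads, d_head).transpose(1, 0, 2)
               for W in (W_q, W_k, W_v))
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    heads = softmax(scores) @ V
    return heads.transpose(1, 0, 2).reshape(T, d_model) @ W_o


torch.manual_seed(0)
d_model, n_heads, seq_len = 16, 4, 4
# bias=False matches the bias-free projections used in Step 3
mha = nn.MultiheadAttention(d_model, n_heads, bias=False, batch_first=True)
x = torch.randn(1, seq_len, d_model)
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # True = blocked
torch_out, _ = mha(x, x, x, attn_mask=causal)
torch_out = torch_out.detach()[0].numpy()

# Convert PyTorch's (out, in) weights into the (in, out) layout used in Step 3
W = mha.in_proj_weight.detach().numpy()
W_q, W_k, W_v = (w.T for w in np.split(W, 3, axis=0))
W_o = mha.out_proj.weight.detach().numpy().T

np_out = numpy_mha(x[0].numpy(), n_heads, W_q, W_k, W_v, W_o, mask=causal.numpy())
print(f"max |NumPy - PyTorch| = {np.abs(np_out - torch_out).max():.2e}")
```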
Stretch Goals
- Implement grouped-query attention (GQA) where multiple query heads share the same key-value head, as used in Llama-2 and Mistral.
- Measure the memory usage of your attention implementation as sequence length grows from 32 to 4096 tokens, confirming the O(n^2) scaling.
- Visualize how each of the 4 attention heads attends to different parts of a sentence by plotting all head heatmaps side by side.
- Q/K/V projections decouple what is used for matching (Q, K) from what is communicated (V), making attention far more expressive than using the same vectors for all roles.
- Scaling by $\sqrt{d_{k}}$ keeps the variance of the dot products from growing with dimension, holding softmax in a gradient-friendly regime. Without scaling, attention distributions tend to saturate toward one-hot, gradients through the softmax vanish, and training becomes unstable.
- Multi-head attention runs h independent attention operations in parallel, each in a lower-dimensional subspace. This allows the model to capture multiple types of relationships simultaneously while using the same number of projection parameters as a single head of width $d_{model}$.
- Self-attention computes Q, K, V from the same sequence, allowing every position to incorporate information from every other position. Cross-attention takes Q from one sequence and K, V from another.
- Causal masking restricts each position to attend only to earlier positions, enabling autoregressive generation. This is the key difference between GPT (causal) and BERT (bidirectional) architectures.
- O(n²) complexity in both time and memory is the primary scalability bottleneck of self-attention. Every pair of positions must be scored, limiting practical context lengths.
- What comes next: In Chapter 3, we will combine multi-head self-attention with feedforward layers, layer normalization, and the components referenced in Section 4.1 to build the complete Transformer architecture.
Exercises
Suppose you replace the standard $\sqrt{d_{k}}$ scaling with (a) $d_{k}$ (linear scaling, no square root) and (b) no scaling at all. For $d_{k} = 64$ with random Q and K vectors of unit-variance entries, predict the standard deviation of the resulting attention scores in each case. Then describe what happens to the softmax distribution at each setting.
Answer Sketch
Raw dot product variance is $d_{k} = 64$, so std is $\sqrt{64} = 8$. (a) With linear scaling by 64, std becomes $8/64 = 0.125$, very small. The softmax is nearly uniform: weights all near 1/n. The model cannot focus, and training collapses to averaging. (b) No scaling: std = 8. The softmax saturates to nearly one-hot. Gradients vanish through softmax, making the attention pattern unable to update. The standard $\sqrt{d_{k}}$ gives std ≈ 1, which is the sweet spot: enough variation to allow focus, small enough to keep softmax in its high-gradient regime. This explains why neither alternative works in practice.
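The prediction can be checked empirically. A minimal NumPy sketch, assuming 2000 random queries scored against 16 random keys with unit-variance entries (sizes chosen for illustration): it reports the score standard deviation under each scaling and the average softmax weight placed on the single highest-scoring key, which is about 1/16 for a near-uniform distribution and close to 1 for a near-one-hot one.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_keys, n_queries = 64, 16, 2000
Q = rng.standard_normal((n_queries, d_k))  # unit-variance entries
K = rng.standard_normal((n_keys, d_k))
raw = Q @ K.T                              # (n_queries, n_keys)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


results = {}
for name, scale in [("divide by d_k", d_k), ("divide by sqrt(d_k)", np.sqrt(d_k)), ("no scaling", 1.0)]:
    s = raw / scale
    top = softmax(s).max(axis=-1).mean()   # average weight on the single largest key
    results[name] = (s.std(), top)
    print(f"{name:<20} score std = {s.std():6.3f}   mean max weight = {top:.3f}")
```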
Modify the scaled_dot_product_attention function from Code Fragment 2.4.9a to handle the causal mask internally given a single boolean flag causal=True, removing the need to pass an explicit mask. Sketch the code (4-6 lines) and explain why you should construct the mask on the same device as the input tensors.
Answer Sketch
Inside the function, after computing scores: if causal: n = scores.size(-1); causal_mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=scores.device), diagonal=1); scores = scores.masked_fill(causal_mask, float('-inf')). The torch.triu(..., diagonal=1) produces an upper triangular mask above the diagonal, marking exactly the future positions. Building the mask on scores.device avoids a CPU-to-GPU copy on every forward pass, which would otherwise become a major bottleneck during training (forward pass is called millions of times). For maximum efficiency, production code caches a single causal mask up to the maximum sequence length and slices it as needed.
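The answer sketch above is written in PyTorch. For the NumPy version from the lab, a minimal sketch of the same idea might look like this; the `causal` keyword is the flag the exercise asks for, and it combines with any explicit mask using the lab's convention (True means "blocked"). NumPy arrays have no device, so the device concern applies only to the PyTorch version.

```python
import numpy as np


def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V, mask=None, causal=False):
    """Q, K, V: (seq_len, d_k). mask: optional boolean, True = blocked."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    if causal:
        n_q, n_k = scores.shape
        # Block every key that lies strictly after the query position
        future = np.triu(np.ones((n_q, n_k), dtype=bool), k=1)
        mask = future if mask is None else (mask | future)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    weights = softmax(scores, axis=-1)
    return weights @ V, weights


rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((5, 8)) for _ in range(3))
out_flag, w_flag = scaled_dot_product_attention(Q, K, V, causal=True)
explicit = np.triu(np.ones((5, 5), dtype=bool), k=1)
out_mask, w_mask = scaled_dot_product_attention(Q, K, V, mask=explicit)
print(np.round(w_flag, 2))
```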
For a multi-head self-attention layer with $d_{model} = 768$ and $n_{heads} = 12$, compute (a) the per-head dimension $d_{k}$, (b) the total parameter count of the four projection matrices ($W^{Q}, W^{K}, W^{V}, W^{O}$), ignoring biases, and (c) compare against a single-head attention with $d_{k} = 768$. Why does the parameter count not depend on $n_{heads}$?
Answer Sketch
(a) $d_{k} = 768/12 = 64$. (b) Each of $W^{Q}, W^{K}, W^{V}$ is a 768x768 matrix because it concatenates the 12 per-head 768x64 sub-projections; $W^{O}$ is also 768x768. Total: $4 \times 768^{2} = 2{,}359{,}296$ parameters. (c) Single-head attention with $d_{k} = 768$: same four 768x768 matrices, also ~2.36M parameters. The count is identical because multi-head splits the same total dimension across heads instead of adding new parameters. The benefit is representational: 12 independent 64-dim attention computations capture 12 different relationship types simultaneously, while 1 single 768-dim attention computation must average them.
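A quick NumPy check of part (b) and of the "no dependence on $n_{heads}$" argument: twelve per-head (768 × 64) query projections, concatenated column-wise, form exactly one (768 × 768) matrix, so projecting once and slicing gives the same per-head queries. The random weights are placeholders, and `projection_params` is a hypothetical helper.

```python
import numpy as np

d_model, n_heads = 768, 12
d_k = d_model // n_heads  # 64

rng = np.random.default_rng(0)
per_head_Wq = [rng.standard_normal((d_model, d_k)) for _ in range(n_heads)]
W_q = np.concatenate(per_head_Wq, axis=1)  # (768, 768)

x = rng.standard_normal((3, d_model))      # 3 tokens
full = x @ W_q
same = all(np.allclose(full[:, h * d_k:(h + 1) * d_k], x @ per_head_Wq[h])
           for h in range(n_heads))


def projection_params(d_model, n_heads):
    """W^Q, W^K, W^V (n_heads blocks of d_model x d_k each) plus W^O; biases ignored."""
    d_k = d_model // n_heads
    return 3 * n_heads * d_model * d_k + (n_heads * d_k) * d_model


print("per-head slices match:", same)
print("12 heads:", f"{projection_params(768, 12):,}")  # 2,359,296
print(" 1 head: ", f"{projection_params(768, 1):,}")   # 2,359,296
```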
An encoder-decoder Transformer for translation has source length 25 (German) and target length 30 (English). All hidden dims are 512 with 8 heads. For one decoder layer, give the shapes of (a) the self-attention score matrix, (b) the cross-attention score matrix, and (c) which mask types apply to each. (d) If you double the source length to 50, what happens to memory and compute for each attention block?
Answer Sketch
(a) Decoder self-attention score: per head shape $(30, 30)$. Causal mask required to prevent peeking at future target tokens. (b) Cross-attention score: queries from decoder (length 30), keys/values from encoder (length 25), so per head shape $(30, 25)$. No causal mask (the entire source is available); a padding mask is used to ignore padding positions in the source. (c) Encoder self-attention (not asked but instructive): $(25, 25)$ with optional padding mask. (d) Doubling source length to 50: encoder self-attention quadruples in cost (50²/25² = 4x). Cross-attention only doubles (the keys/values dimension grows but queries stay the same). Decoder self-attention is unaffected. This asymmetry is why cross-attention scales better than encoder self-attention when source becomes long.
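A shape-only NumPy sketch of part (b), assuming a per-head dimension of 64 (512 / 8) and, hypothetically, that the last 3 of the 25 source positions are padding. Queries come from the decoder and keys/values from the encoder, so each head's score matrix is (30, 25), and the padding mask blocks key columns rather than query rows.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, d_head = 8, 512 // 8
tgt_len, src_len = 30, 25

Q = rng.standard_normal((n_heads, tgt_len, d_head))  # from the decoder
K = rng.standard_normal((n_heads, src_len, d_head))  # from the encoder
V = rng.standard_normal((n_heads, src_len, d_head))

scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (8, 30, 25)

# Hypothetical padding: the last 3 source tokens are <pad>
src_is_pad = np.zeros(src_len, dtype=bool)
src_is_pad[-3:] = True
scores = np.where(src_is_pad[None, None, :], -1e9, scores)  # block padded keys

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V  # (8, 30, 64)
print(scores.shape, out.shape)
```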
You train a Transformer with 12 heads and after training observe that 11 heads produce nearly identical attention patterns; only one head behaves distinctly. Diagnose two plausible causes and propose a fix for each.
Answer Sketch
Cause 1: Heads were initialized with too-similar random weights and never broke symmetry during training. Although unlikely with standard PyTorch initialization (which uses different random draws per parameter), it can happen when developers reuse a manual seed or copy weights between heads. Fix: re-initialize from scratch with independent seeds, or add a small head-specific perturbation to break symmetry. Cause 2: The model was over-parameterized for the task; 11 heads were redundant and converged to encode the same dominant pattern (perhaps copying the previous token). Head redundancy of this kind is documented in Voita et al. (2019), "Analyzing Multi-Head Self-Attention," which found that a few specialized heads do most of the work and many others can be pruned. Fix: prune heads (Voita et al. learn per-head gates with a differentiable L0 penalty) or fine-tune with structured dropout (DropHead) to encourage diversity. Note that head redundancy is not always a bug; some redundancy provides robustness if individual heads degrade.
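Detecting the symptom in the first place can be as simple as comparing heads' attention maps pairwise. A minimal NumPy sketch, assuming you have already extracted one layer's attention weights as an array of shape (n_heads, seq_len, seq_len); cosine similarity of flattened maps is one simple choice among several, the 0.95 threshold is arbitrary, and the synthetic maps below stand in for real ones.

```python
import numpy as np


def head_similarity(weights):
    """Cosine similarity between every pair of heads' flattened attention maps.

    weights: (n_heads, seq_len, seq_len), each row a probability distribution.
    Returns an (n_heads, n_heads) matrix with ones on the diagonal.
    """
    flat = weights.reshape(weights.shape[0], -1)
    flat = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    return flat @ flat.T


# Synthetic example: 11 heads share a "previous token" pattern, 1 head is uniform
seq_len, n_heads = 10, 12
prev_token = np.eye(seq_len, k=-1)
prev_token[0, 0] = 1.0  # the first token attends to itself
maps = np.stack([prev_token] * 11 + [np.full((seq_len, seq_len), 1.0 / seq_len)])

sim = head_similarity(maps)
redundant = [(i, j) for i in range(n_heads) for j in range(i + 1, n_heads) if sim[i, j] > 0.95]
print(f"{len(redundant)} near-duplicate head pairs")  # 55 = 11 * 10 / 2
```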
You want to run inference on a single Transformer layer with $d_{model} = 4096$ and $n_{heads} = 32$ on a 24 GB GPU. Assume FP16 (2 bytes per element). (a) Compute the memory needed for the attention score matrix at sequence lengths 8K, 32K, and 128K (single batch, all heads in parallel). (b) Which of these fit in 24 GB if attention scores must coexist with model weights (~2 GB) and KV cache (~3 GB at 32K)? (c) What does this calculation imply about the necessity of memory-efficient attention algorithms like FlashAttention?
Answer Sketch
Score matrix per head is $n \times n$; with 32 heads in FP16 the total is $32 \times n^{2} \times 2$ bytes. (a) 8K: $32 \times 8192^{2} \times 2 = 4.29$ GB. 32K: $32 \times 32768^{2} \times 2 = 68.7$ GB. 128K: 1099 GB (over 1 TB!). (b) Only 8K fits comfortably: 4.29 GB scores + 2 GB weights + ~0.75 GB KV cache (scaling the 32K figure linearly) ≈ 7 GB, leaving headroom. 32K cannot fit: the score matrix alone is nearly three times the GPU's memory. (c) Under these assumptions, standard attention runs out of memory at roughly 17-18K tokens on a 24 GB GPU, and by about 19K tokens the score matrix alone exceeds 24 GB. FlashAttention avoids materializing the full score matrix by computing softmax incrementally in tiles that fit in on-chip SRAM, reducing the extra memory for attention from O(n²) to O(n). Without such algorithms, a 24 GB GPU could not process even a 32K-token prompt through a single layer of this size.
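The arithmetic in (a) and (b) is easy to script. A minimal sketch under the exercise's assumptions (32 heads, FP16, ~2 GB of weights, and a KV cache that scales linearly from ~3 GB at 32K tokens); the helper names are hypothetical, and the model ignores activations and allocator overhead, so treat the break-even length as an optimistic upper bound.

```python
def score_bytes(n, n_heads=32, bytes_per_elem=2):
    """Memory for the full (n x n) score matrix across all heads."""
    return n_heads * n * n * bytes_per_elem


def total_gb(n, weights_gb=2.0, kv_gb_at_32k=3.0):
    kv_gb = kv_gb_at_32k * n / 32_768  # assumed linear KV-cache growth
    return score_bytes(n) / 1e9 + weights_gb + kv_gb


for n in [8_192, 32_768, 131_072]:
    print(f"n={n:>7,}: scores {score_bytes(n) / 1e9:8.2f} GB, total {total_gb(n):8.2f} GB")

# Largest sequence length that fits in a 24 GB budget (step of 256 tokens)
max_fit = max(n for n in range(256, 65_536, 256) if total_gb(n) <= 24.0)
print("largest n that fits:", max_fit)
```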
What's Next?
In the next section, Section 3.1: How a Transformer Computes One Token, we continue building on the topics covered here.