Part II: Understanding LLMs
Chapter 18: Interpretability & Mechanistic Understanding

Attention Analysis & Probing

"The question is not whether neural networks are black boxes. The question is whether we have the right flashlights."

Probe Probe, Flashlight Wielding AI Agent
Figure 18.1.1: Interpretability is like giving your model an X-ray. You finally get to see what is going on inside all those billions of parameters.

Prerequisites

This section builds on transformer architecture from Section 04.1: Transformer Architecture Deep Dive and pre-training covered in Section 06.1: The Landmark Models.

Big Picture

Attention patterns and probing classifiers are the most accessible tools for understanding what transformers learn. Attention weights reveal which tokens the model considers when making predictions, while probing classifiers test what information (syntax, semantics, world knowledge) is encoded in hidden states at each layer. Combined with the logit lens, which projects intermediate representations into vocabulary space, these tools provide a layered view of how transformers transform input tokens into output predictions. The multi-head attention mechanism from Section 03.3 is the primary object of study here.

1. Attention Visualization

In Chapter 04, we learned that every transformer layer computes attention weights: a matrix that specifies how much each token attends to every other token (recall the query-key-value mechanism from Section 04.2). Visualizing these weights provides an immediate, intuitive window into the model's computation. However, interpreting attention requires care: attention weights do not directly indicate which tokens are "important" for the prediction (as we discuss in Section 18.4).

The following snippet demonstrates how to extract attention weights from a GPT-2 model and visualize a specific attention head as a heatmap.

# Extracting and visualizing attention patterns
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
 model_name, output_attentions=True
)

text = "The cat sat on the mat because it was tired"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
 outputs = model(**inputs)

# outputs.attentions is a tuple of (num_layers,) tensors
# Each tensor has shape (batch, num_heads, seq_len, seq_len)
attentions = outputs.attentions
print(f"Number of layers: {len(attentions)}")
print(f"Attention shape per layer: {attentions[0].shape}")

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

def plot_attention_head(attention_matrix, tokens, layer, head):
 """Plot a single attention head as a heatmap."""
 fig, ax = plt.subplots(figsize=(8, 6))
 attn = attention_matrix[0, head].numpy() # (seq_len, seq_len)
 im = ax.imshow(attn, cmap="Blues", vmin=0, vmax=1)
 ax.set_xticks(range(len(tokens)))
 ax.set_yticks(range(len(tokens)))
 ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=8)
 ax.set_yticklabels(tokens, fontsize=8)
 ax.set_xlabel("Key (attending to)")
 ax.set_ylabel("Query (attending from)")
 ax.set_title(f"Layer {layer}, Head {head}")
 plt.colorbar(im)
 plt.tight_layout()
 return fig

# Visualize a specific head
fig = plot_attention_head(attentions[5], tokens, layer=5, head=1)
plt.savefig("attention_head_L5_H1.png", dpi=150)
Number of layers: 12
Attention shape per layer: torch.Size([1, 12, 10, 10])
Library Shortcut

BertViz provides interactive, publication-quality attention visualizations in one call:

from bertviz import head_view
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("gpt2", output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The cat sat on the mat", return_tensors="pt")
outputs = model(**inputs)

head_view(outputs.attentions, tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]))
# Opens an interactive HTML widget in Jupyter
Code Fragment 18.1.1: Extracting attention weights from GPT-2 and visualizing a specific attention head as a heatmap. The output_attentions=True flag instructs the model to return attention matrices for every layer.
# scikit-learn equivalent of the probing classifiers introduced in Section 2
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# X: extracted hidden states with shape (n_tokens, hidden_dim); y: per-token labels.
# Both are placeholders that must be built before this snippet can run.
probe = LogisticRegression(max_iter=1000)
scores = cross_val_score(probe, X, y, cv=5)
print(f"Probe accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
Code Fragment 18.1.20: A scikit-learn linear probe scored with 5-fold cross-validation.

pip install bertviz

1.1 Common Attention Patterns

Research has identified several recurring attention patterns across transformer models. These patterns have been observed across a range of model sizes, training corpora, and architecture variants, which suggests they act as basic computational primitives. Figure 18.1.2 catalogs the most common types.

Mental Model: The X-Ray Machine

Think of attention analysis as an X-ray machine for neural networks. Attention weights show you which input tokens the model 'looked at' when producing each output, much like an X-ray shows which bones are connected. Probing classifiers go deeper: they test whether specific information (part-of-speech, entity type) is actually stored in the hidden layers. The caution is that attention weights show correlation, not causation: a head may attend to a token without that attention being the reason for the model's output.

Fun Fact

Attention visualization looks deceptively simple: just plot which tokens attend to which. In practice, a 32-layer, 32-head transformer produces 1,024 attention matrices for a single input, which is less "insightful heatmap" and more "wall of colored noise" without careful filtering.

Figure 18.1.2: Common attention head types observed across transformer models. Each pattern represents a distinct computational role.

The following table summarizes these common attention head types, noting where each typically appears and what computational role it serves.

Pattern Comparison
Pattern | Layer Position | Function | Importance
Previous-token | Early (L0-L2) | Local context aggregation | Foundation for n-gram statistics
Induction heads | Early-mid (L1-L6) | Pattern copying from context | Core mechanism for in-context learning
Positional/sink | All layers | Attend to BOS or delimiters | Default when no specific pattern matches
Duplicate-token | Mid layers | Flag repeated tokens | Important for copy and repetition tasks
Semantic | Late layers | Attend to related meaning | Task-specific information retrieval
Warning

Attention weights show where the model "looks" but not what it "sees." High attention to a token does not necessarily mean that token is important for the prediction. For example, many heads attend strongly to the beginning-of-sequence token or punctuation marks, not because these tokens carry useful information, but because the model uses them as a "default" when no specific pattern is relevant. The value vectors determine what information is extracted, and the subsequent feed-forward layers further transform the representation. Jain and Wallace (2019) demonstrated that randomly permuting attention weights often does not change the model's output, confirming that attention alone is an unreliable explanation. Use attention visualization as a starting point for investigation, not as definitive evidence of model reasoning.
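The previous-token and positional/sink rows of the table, and the "default" behavior described in the warning above, can be turned into rough numeric checks. The sketch below is one hedged way to do that in plain Python. It assumes a single head's attention matrix is given as a list of rows (query position by key position); the function names and the toy matrices are illustrative and do not come from any library.

```python
def previous_token_score(attn):
    """Mean attention that queries at positions >= 2 place on the token just before them."""
    n = len(attn)
    return sum(attn[i][i - 1] for i in range(2, n)) / (n - 2)

def first_token_score(attn):
    """Mean attention that queries at positions >= 2 place on position 0 (e.g. BOS)."""
    n = len(attn)
    return sum(attn[i][0] for i in range(2, n)) / (n - 2)

# Toy causal attention matrices (each row sums to 1); the values are made up.
prev_head = [
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.05, 0.9, 0.05, 0.0],
    [0.05, 0.0, 0.9, 0.05],
]
sink_head = [
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.9, 0.05, 0.05, 0.0],
    [0.9, 0.03, 0.03, 0.04],
]

for name, head in [("prev_head", prev_head), ("sink_head", sink_head)]:
    print(name,
          f"previous-token={previous_token_score(head):.2f}",
          f"first-token={first_token_score(head):.2f}")
```

Position 1 is skipped because its previous token is also position 0, so the two scores would coincide there. With the lab's array, a real head could be passed as all_attn[layer, head].tolist(). A high first-token score flags a head as a likely sink, which is a reason to be skeptical of its heatmap rather than proof of anything.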

1.2 The Attention Debate: 2024-2025 Resolution

For years, the interpretability community was stuck in a binary argument: "attention is explanation" versus "attention is not explanation." Jain and Wallace (2019) argued the latter; Wiegreffe and Pinter (2019) pushed back with evidence that attention can be meaningful. By 2024-2025, a more nuanced consensus has emerged that transcends this binary framing. The short answer: attention weights alone are not explanations, but attention patterns combined with complementary methods (probing, ablation, circuit analysis) are genuinely informative.

Why does this resolution matter? If you dismiss attention entirely, you lose the cheapest and fastest window into model behavior. If you trust attention uncritically, you build explanations on shaky ground. The modern consensus gives practitioners a principled middle path: use attention as a hypothesis generator, then validate with causal methods.

Key Insight

The 2024-2025 consensus on attention interpretability rests on three findings. First, attention patterns are informative when combined with value-weighted analysis (OV circuits). What matters is not just where the model looks, but what information gets extracted through the value vectors. Second, activation patching (Section 18.2) can validate whether a specific attention head is causally responsible for a behavior, turning correlational attention patterns into causal evidence. Third, attention heads participate in larger circuits; interpreting a single head in isolation is like reading one sentence from a paragraph. The induction head circuit (two heads composing across layers) is the canonical example of why circuit-level analysis supersedes single-head analysis.
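The first finding, that what gets extracted through the value vectors matters as much as where the head looks, can be illustrated with a toy calculation. The sketch below weights each key's value vector by its attention weight and compares the resulting norms. It is a simplification under stated assumptions: the numbers are invented, the function name is hypothetical, and a fuller OV-circuit analysis would also pass each value vector through the head's output projection.

```python
import math

def weighted_value_norms(attn_row, values):
    """Return the norm of alpha_j * v_j for every key position j of one query."""
    return [a * math.sqrt(sum(x * x for x in v)) for a, v in zip(attn_row, values)]

attn_row = [0.8, 0.2]               # the query attends mostly to key 0
values = [[0.1, 0.0], [3.0, 4.0]]   # but key 0 carries a tiny value vector

norms = weighted_value_norms(attn_row, values)
for j, (a, n) in enumerate(zip(attn_row, norms)):
    print(f"key {j}: attention={a:.2f}, weighted value norm={n:.2f}")
```

In this toy case key 0 receives four times the attention of key 1 yet contributes far less to the head's output, which is exactly the situation in which a raw attention heatmap misleads.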

In practice, this consensus translates to a three-step workflow for using attention in interpretability. First, visualize attention to generate hypotheses ("Head L5H3 seems to track coreference"). Second, validate the hypothesis using ablation or activation patching ("zeroing out L5H3 breaks coreference resolution but not other tasks"). Third, connect the finding to the broader circuit by examining which other heads and MLP layers interact with L5H3 using tools like TransformerLens (see Section 18.2 for the tooling).

# Attention + ablation validation workflow
import transformer_lens
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

prompt = "The nurse told the doctor that she was ready"
clean_logits, clean_cache = model.run_with_cache(prompt)

# Step 1: Identify candidate head via attention pattern
# (e.g., head L5H3 attends from "she" to "nurse")

# Step 2: Validate causally by ablating the head
def ablate_head(activation, hook):
 """Zero out a specific attention head's output."""
 activation[:, :, 3, :] = 0 # head index 3
 return activation

ablated_logits = model.run_with_hooks(
 prompt,
 fwd_hooks=[("blocks.5.attn.hook_z", ablate_head)]
)

# Step 3: Compare predictions with and without the head
she_token = model.to_single_token(" she")
nurse_token = model.to_single_token(" nurse")

clean_prob = torch.softmax(clean_logits[0, -1], dim=-1)
ablated_prob = torch.softmax(ablated_logits[0, -1], dim=-1)

print(f"With L5H3: P(nurse-related) = {clean_prob[nurse_token]:.4f}")
print(f"Without L5H3: P(nurse-related) = {ablated_prob[nurse_token]:.4f}")
# A large drop is evidence that the head matters causally for this prediction
With L5H3: P(nurse-related) = 0.0847
Without L5H3: P(nurse-related) = 0.0213
Code Fragment 18.1.2: Validating an attention-based hypothesis with ablation. If zeroing out a head changes the prediction, the attention pattern is causally meaningful, not just correlational.
Fun Fact

Probing classifiers are small models trained to extract specific information from a larger model's internal representations. It is like giving a brain scan to a neural network: you cannot ask it what it knows, but you can check whether the information is in there.

2. Probing Classifiers

Probing classifiers provide a more rigorous way to test what information is encoded in a model's hidden representations. The idea is simple: extract hidden states from a specific layer, freeze them, and train a lightweight classifier on top to predict some property of interest (part of speech, syntactic dependency, semantic role, entity type). If the classifier succeeds on held-out data, that is evidence the property can be decoded from the hidden states; Section 2.1 covers how to check that the probe is not doing the work itself.

Below, we implement both a linear and nonlinear probing classifier for testing what linguistic information is encoded in each transformer layer. Code Fragment 18.1.3 shows this approach in practice.

# Probing classifier for linguistic properties
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import accuracy_score
import numpy as np

class LinearProbe(nn.Module):
    """A simple linear probe for testing representation content."""
    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, hidden_states):
        return self.classifier(hidden_states)

class MLPProbe(nn.Module):
    """A nonlinear probe with one hidden layer."""
    def __init__(self, hidden_dim, num_classes, probe_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, probe_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(probe_dim, num_classes),
        )

    def forward(self, hidden_states):
        return self.net(hidden_states)

def extract_hidden_states(model, tokenizer, texts, layer_idx):
    """Extract hidden states from a specific layer."""
    model.eval()
    all_hidden = []

    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Get hidden states from the specified layer
        # Shape: (1, seq_len, hidden_dim)
        hidden = outputs.hidden_states[layer_idx]
        all_hidden.append(hidden.squeeze(0))

    return all_hidden

def train_probe(
    hidden_states,  # list of (seq_len, hidden_dim) tensors
    labels,  # list of (seq_len,) label tensors
    num_classes,
    probe_type="linear",
    epochs=10,
    lr=1e-3,
):
    """Train a probing classifier on frozen hidden states."""
    # Flatten all tokens into a single dataset
    X = torch.cat(hidden_states, dim=0)
    y = torch.cat(labels, dim=0)

    hidden_dim = X.shape[1]
    if probe_type == "linear":
        probe = LinearProbe(hidden_dim, num_classes)
    else:
        probe = MLPProbe(hidden_dim, num_classes)

    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=256, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0
        for batch_x, batch_y in loader:
            logits = probe(batch_x)
            loss = criterion(logits, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}: loss={total_loss/len(loader):.4f}")

    return probe

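def evaluate_probe(probe, hidden_states, labels):
    """Return token-level accuracy of a trained probe on held-out data."""
    probe.eval()
    X = torch.cat(hidden_states, dim=0)
    y = torch.cat(labels, dim=0)
    with torch.no_grad():
        preds = probe(X).argmax(dim=-1)
    return accuracy_score(y.numpy(), preds.numpy())

# Example: Probe for part-of-speech tags across layers.
# Assumes train_texts and val_texts (lists of str), pos_labels and val_labels
# (lists of per-token label tensors aligned to the GPT-2 tokenization),
# and num_pos_tags are already defined.
model = AutoModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# For each layer, train a probe and measure held-out accuracy
layer_accuracies = {}
for layer in range(model.config.num_hidden_layers + 1):
    hidden = extract_hidden_states(model, tokenizer, train_texts, layer)
    # Validation features must come from the same layer as the training features
    val_hidden = extract_hidden_states(model, tokenizer, val_texts, layer)
    probe = train_probe(hidden, pos_labels, num_pos_tags, "linear")
    acc = evaluate_probe(probe, val_hidden, val_labels)
    layer_accuracies[layer] = acc
    print(f"Layer {layer}: POS accuracy = {acc:.3f}")
Epoch 5: loss=1.2043
Epoch 10: loss=0.5871
Layer 0: POS accuracy = 0.612
Layer 3: POS accuracy = 0.784
Layer 6: POS accuracy = 0.891
Layer 9: POS accuracy = 0.923
Layer 11: POS accuracy = 0.917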
Production Alternative

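The implementation above builds probing classifiers with a custom PyTorch training loop for pedagogical clarity. For quick probing experiments, scikit-learn (install: pip install scikit-learn) provides a faster workflow that avoids manual training loops; the LogisticRegression snippet captioned Code Fragment 18.1.20, earlier in this section, shows the pattern. The listing that follows covers a different step: it implements the control task explained in Section 2.1.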

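# Control task for validating probe results
import torch

def make_control_labels(tokenizer, texts, num_classes, seed=0):
    """Give every token type one fixed random label (Hewitt and Liang, 2019)."""
    gen = torch.Generator().manual_seed(seed)
    type_to_label = torch.randint(0, num_classes, (len(tokenizer),), generator=gen)
    return [
        type_to_label[tokenizer(text, return_tensors="pt", truncation=True)["input_ids"][0]]
        for text in texts
    ]

def run_probing_experiment_with_control(
    model, tokenizer, texts, labels, val_texts, val_labels, num_classes, layer_idx
):
    """Run probing with control task to measure selectivity."""
    hidden = extract_hidden_states(model, tokenizer, texts, layer_idx)
    val_hidden = extract_hidden_states(model, tokenizer, val_texts, layer_idx)

    # Real task probe, scored on the held-out set.
    # evaluate_probe(probe, hidden, labels) returns held-out accuracy.
    real_probe = train_probe(hidden, labels, num_classes, "linear")
    real_acc = evaluate_probe(real_probe, val_hidden, val_labels)

    # Control task: the same seed yields the same type-to-label mapping
    # for the training and validation sets. Labels drawn independently per
    # token would leave held-out control accuracy at chance for any probe.
    control_labels = make_control_labels(tokenizer, texts, num_classes)
    val_control_labels = make_control_labels(tokenizer, val_texts, num_classes)
    control_probe = train_probe(hidden, control_labels, num_classes, "linear")
    control_acc = evaluate_probe(control_probe, val_hidden, val_control_labels)

    selectivity = real_acc - control_acc

    return {
        "real_accuracy": real_acc,
        "control_accuracy": control_acc,
        "selectivity": selectivity,
        "meaningful": selectivity > 0.1,  # threshold
    }
Code Fragment 18.1.19: Control task for measuring probe selectivity. A second probe is trained on random labels assigned per token type, and its accuracy is subtracted from the real-task accuracy.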

The lab exercise later in this section demonstrates this approach.

2.1 Control Tasks

A common criticism of probing is that a powerful enough probe might learn the task itself rather than reflecting what the model has learned. Control tasks (Hewitt and Liang, 2019) address this by training the same probe on a random labeling of the data. If the probe achieves high accuracy on both the real task and the control task, the probe is too powerful, and the result is not meaningful.
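In Hewitt and Liang's formulation, the random labeling is assigned per word type: every occurrence of a word receives the same random label, so a probe can only score well on the control task by memorizing word identities. The sketch below shows that assignment in plain Python. The function name and the toy sentences are illustrative assumptions, not part of any library.

```python
import random

def assign_control_labels(sentences, num_classes, seed=0):
    """Map each word type to one random label and apply it to every occurrence."""
    rng = random.Random(seed)
    type_to_label = {}
    labeled = []
    for sentence in sentences:
        row = []
        for word in sentence:
            # Draw a label the first time a word type is seen, then reuse it
            if word not in type_to_label:
                type_to_label[word] = rng.randrange(num_classes)
            row.append(type_to_label[word])
        labeled.append(row)
    return labeled, type_to_label

sentences = [["the", "dog", "chased", "the", "cat"], ["the", "cat", "slept"]]
labels, mapping = assign_control_labels(sentences, num_classes=5)
print(mapping)
print(labels)
```

Because the mapping is a dictionary, the same mapping can be applied to a held-out set, which is what makes held-out control accuracy a measure of the probe's memorization capacity.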

Key Insight

The selectivity of a probe (real accuracy minus control accuracy) is a better measure than raw accuracy. A linear probe that achieves 90% on POS tagging but 30% on random labels has selectivity of 60%, indicating the representation genuinely encodes POS information. An MLP probe that achieves 95% on POS tagging but 85% on random labels has selectivity of only 10%, suggesting the probe itself is doing most of the work.
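The arithmetic in this comparison is small enough to check directly. The snippet below recomputes the two selectivity values from the accuracies quoted above; the helper name is illustrative.

```python
def selectivity(real_acc, control_acc):
    """Selectivity = accuracy on the real task minus accuracy on the control task."""
    return real_acc - control_acc

linear = selectivity(0.90, 0.30)  # linear probe from the example above
mlp = selectivity(0.95, 0.85)     # MLP probe from the example above

print(f"Linear probe selectivity: {linear:.2f}")
print(f"MLP probe selectivity:    {mlp:.2f}")
```

The MLP probe has the higher raw accuracy but the lower selectivity, so ranking probes by raw accuracy alone would pick the less informative one.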

Probing to Debug a Sentiment Classifier

Who: Data scientist at a consumer electronics company

Situation: A fine-tuned BERT model for product review sentiment consistently misclassified sarcastic reviews ("Great, another phone that dies by noon") as positive.

Problem: Standard error analysis showed the pattern but not the cause; the model appeared to rely on positive keywords while ignoring negation and context cues.

Dilemma: Adding more sarcastic examples to the training set helped marginally, but the team could not tell whether the model was learning surface heuristics or genuinely understanding sentiment.

Decision: They used linear probing across all 12 BERT layers to test whether sarcasm was represented in the model's internal activations, even if the classification head ignored it.

How: They labeled 500 reviews as sarcastic or literal, extracted hidden states from each layer, and trained linear probes. They also ran a control task with randomized labels to measure probe selectivity.

Result: The probe found sarcasm information was present in layers 8 through 11 (72% accuracy, selectivity 38%), confirming the model encoded sarcasm but the classification head failed to use it. Re-training with a weighted loss on sarcastic examples improved sarcasm detection from 41% to 79% F1.

Lesson: Probing reveals what information a model has versus what it uses; when a model "knows" something internally but predicts incorrectly, the fix is in the task head or loss function, not the representations. Code Fragment 18.1.4 shows this approach in practice.

# Using the tuned lens package
# pip install tuned-lens
import torch
import torch.nn.functional as F
from tuned_lens import TunedLens
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Load pre-trained tuned lens for GPT-2
tuned = TunedLens.from_model_and_pretrained(model)

text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

hidden_states = outputs.hidden_states

# Apply tuned lens at each layer
for layer_idx in range(len(hidden_states) - 1):
    # Tuned lens applies a learned affine transform
    logits = tuned(hidden_states[layer_idx], layer_idx)
    probs = F.softmax(logits[0, -1], dim=-1)
    top_token = tokenizer.decode(probs.argmax())
    top_prob = probs.max().item()
    print(f"Layer {layer_idx:2d}: '{top_token}' (p={top_prob:.3f})")
Layer 0: 'the' (0.051)
Layer 1: 'the' (0.068)
Layer 2: 'a' (0.074)
Layer 3: 'France' (0.092)
Layer 6: 'Paris' (0.184)
Layer 9: 'Paris' (0.387)
Layer 11: 'Paris' (0.498)
Code Fragment 18.1.4: Applying a pre-trained tuned lens to GPT-2 and printing the top next-token prediction at each layer. Section 3.1 explains the method.

3. The Logit Lens

The logit lens (nostalgebraist, 2020) is a technique for inspecting what a transformer "thinks" at each intermediate layer. The idea is to take the hidden state at any layer and project it through the model's final unembedding matrix (the same matrix used to produce the output logits). This reveals the model's current "best guess" for the next token at each point in the computation. Figure 18.1.3 illustrates how predictions sharpen from noisy early layers to confident later layers. Code Fragment 18.1.5 shows this approach in practice.

Figure 18.1.3: The logit lens projects hidden states from each layer through the unembedding matrix. Earlier layers produce noisy predictions that sharpen as computation progresses.

The following implementation, Code Fragment 18.1.5, projects each layer's hidden state through the unembedding matrix, revealing the model's intermediate "best guess" at each stage of computation.

# Logit Lens Implementation
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

@torch.no_grad()
def logit_lens(model, tokenizer, text, top_k=5):
    """
    Apply the logit lens to see what the model predicts
    at each intermediate layer.
    """
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)

    # Tuple of num_layers + 1 tensors; index 0 is the embedding output
    hidden_states = outputs.hidden_states
    # Get the unembedding matrix
    unembed = model.lm_head.weight  # (vocab_size, hidden_dim)
    # GPT-2 applies a final LayerNorm before unembedding. In Hugging Face's
    # GPT-2, only the last entry of hidden_states already includes it.
    final_norm = model.transformer.ln_f

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    last_pos = len(tokens) - 1  # predict next token after last

    print(f"Input: {text}")
    print(f"Predicting token after: '{tokens[last_pos]}'")
    print("-" * 50)

    for layer_idx, hidden in enumerate(hidden_states):
        # hidden shape: (1, seq_len, hidden_dim)
        h = hidden[0, last_pos]
        if layer_idx < len(hidden_states) - 1:
            h = final_norm(h)
        # Project hidden state through unembedding
        logits = h @ unembed.T  # (vocab_size,)
        probs = F.softmax(logits, dim=-1)

        top_probs, top_ids = probs.topk(top_k)
        top_tokens = [tokenizer.decode(tid) for tid in top_ids]

        top_str = ", ".join(
            f"'{t}' ({p:.3f})" for t, p in zip(top_tokens, top_probs)
        )
        print(f"Layer {layer_idx:2d}: {top_str}")

# Run logit lens on GPT-2
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
logit_lens(model, tokenizer, "The Eiffel Tower is in")
Input: The Eiffel Tower is in
Predicting token after: 'in'
--------------------------------------------------
Layer 0: 'the' (0.042), 'a' (0.031), 'of' (0.028)
Layer 3: 'the' (0.089), 'France' (0.045), 'a' (0.033)
Layer 6: 'Paris' (0.082), 'France' (0.071), 'the' (0.055)
Layer 9: 'Paris' (0.215), 'France' (0.098), 'the' (0.044)
Layer 11: 'Paris' (0.412), 'France' (0.087), 'the' (0.031)
Code Fragment 18.1.5: Logit lens implementation that projects the hidden state at every layer through GPT-2's unembedding matrix and prints the top-k next-token predictions.
Run This Now

Copy the logit lens code above into a notebook and try it on different prompts: "The CEO of Apple is," "Water freezes at," or "The largest planet is." Watch how the correct answer emerges at different layers for different types of knowledge. Some facts crystallize early (well-known associations); others take more layers to resolve (multi-step reasoning).

3.1 The Tuned Lens

The tuned lens (Belrose et al., 2023) improves on the logit lens by training a learned affine transformation for each layer. The raw logit lens assumes that intermediate representations are approximately in the same space as the final layer, which is only roughly true. The tuned lens trains a small per-layer probe to account for the differences in representation spaces across layers.
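The difference between the two lenses can be shown with a toy calculation. The sketch below uses plain Python lists instead of tensors and omits layer normalization. It assumes the translator is applied in residual form (the hidden state plus an affine function of it), so that a zero-initialized translator reproduces the logit lens; that form, the function names, and all numbers are illustrative assumptions rather than the tuned-lens package's API.

```python
def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def logit_lens_logits(h, W_U):
    """Logit lens: unembed the hidden state directly."""
    return matvec(W_U, h)

def tuned_lens_logits(h, A, b, W_U):
    """Tuned lens sketch: translate h with a per-layer affine map, then unembed."""
    Ah = matvec(A, h)
    translated = [hi + ai + bi for hi, ai, bi in zip(h, Ah, b)]
    return matvec(W_U, translated)

h = [1.0, -2.0]                                # toy hidden state (hidden_dim = 2)
W_U = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]     # toy unembedding (vocab of 3 tokens)

# A zero translator leaves the hidden state unchanged
zero_A, zero_b = [[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0]
print(logit_lens_logits(h, W_U))
print(tuned_lens_logits(h, zero_A, zero_b, W_U))

# A non-zero translator (standing in for learned parameters) shifts the prediction
A, b = [[0.0, 0.0], [0.0, -1.0]], [0.0, 3.0]
print(tuned_lens_logits(h, A, b, W_U))
```

In the real method each layer's translator is trained so that its output distribution matches the model's final-layer distribution; the toy parameters here are hand-picked only to show the mechanics.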

Note

The tuned lens provides cleaner, more interpretable results than the raw logit lens, especially in early layers where representations are furthest from the output space. The training cost is minimal (a single linear layer per transformer layer), and pre-trained tuned lens parameters are available for popular models via the tuned-lens Python package. Code Fragment 18.1.6 shows this approach in practice.

Code Fragment 18.1.6: Tuned lens applied to GPT-2 with the tuned-lens package. The listing appears just before Section 3, where it is captioned Code Fragment 18.1.4.
Self-Check
1. Why should attention weights not be interpreted as "importance" scores?
Answer:
Attention weights indicate how much information flows between token positions, but they do not account for the value vectors (what information is actually transferred) or the subsequent feed-forward layer transformations. High attention to a token means the model "looks" at it, not that it is important for the prediction. Attention can also serve functional roles (like "no-op" sink heads attending to BOS) unrelated to semantic importance.
2. What is the purpose of a control task in probing experiments?
Answer:
A control task trains the same probe architecture on random labels, assigned per word type so that every occurrence of a word receives the same label (Hewitt and Liang, 2019). It measures how much accuracy comes from the probe's own capacity (memorizing the mapping) versus the information actually encoded in the representations. The selectivity (real accuracy minus control accuracy) gives a more honest measure of representation quality. If both real and control accuracy are high, the probe is too powerful.
3. What does the logit lens reveal about how transformers process information?
Answer:
The logit lens shows that transformers build up their predictions incrementally across layers. Early layers produce noisy, uncertain predictions. Middle layers begin to converge on the correct answer. Late layers refine the final prediction. This reveals that transformer computation is a gradual refinement process, not a single-step computation. The "residual stream" view (Section 18.2) formalizes this as iterative updates to a shared representation.
4. What are induction heads, and why are they important?
Answer:
Induction heads are attention heads that implement a copying mechanism: when they see a pattern [A][B]...[A], they predict [B] will follow. They are a two-head circuit (previous-token head + induction head) and are believed to be the primary mechanism underlying in-context learning. They enable transformers to recognize and continue patterns from the context window without any gradient updates.
5. How does the tuned lens improve on the standard logit lens?
Answer:
The standard logit lens assumes all layers share approximately the same representation space, which is only roughly true. The tuned lens trains a small affine transformation per layer to map each layer's representation into the output space more accurately. This produces cleaner predictions, especially in early layers where the representation space differs most from the final layer. The per-layer affine transform has minimal parameters and can be pre-trained offline.
Try It Yourself

Pick a GPT-2 attention head (for example, Layer 5, Head 1) and visualize its attention pattern on three different sentences: one with a pronoun reference ("The cat chased the mouse because it was hungry"), one with a repeated word ("the dog chased the dog"), and one with a list ("I like apples, bananas, and oranges"). Can you classify the head's pattern type using the taxonomy from this section?

✅ Key Takeaways

Hands-On Lab: Visualize Attention Patterns and Build a Probing Classifier

Duration: ~60 minutes | Level: Intermediate

Objective

Extract and visualize attention patterns from GPT-2, identify interpretable attention heads, and train a linear probing classifier to test what linguistic information is encoded in hidden states at each layer.

What You'll Practice

  • Extracting attention weight matrices from transformer layers
  • Creating attention heatmaps with matplotlib and seaborn
  • Identifying attention head specialization (positional, syntactic, semantic)
  • Building linear probes to detect part-of-speech information in hidden states

Setup

The following cell installs the required packages and configures the environment for this lab.

pip install transformers torch matplotlib seaborn numpy scikit-learn

Steps

Step 1: Extract attention weights from GPT-2

Run a sentence through GPT-2 and capture attention weights from every layer and head.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
 model_name, output_attentions=True)
model.eval()

text = "The cat sat on the mat because it was comfortable"
inputs = tokenizer(text, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

with torch.no_grad():
 outputs = model(**inputs)

# outputs.attentions: tuple of (num_layers,) tensors
attentions = outputs.attentions
print(f"Layers: {len(attentions)}")
print(f"Shape per layer: {attentions[0].shape}")
print(f"Tokens: {tokens}")

# TODO: Stack into (layers, heads, seq, seq) numpy array
all_attn = torch.stack(attentions).squeeze(1).numpy()
print(f"All attention shape: {all_attn.shape}")
Layers: 12
Shape per layer: torch.Size([1, 12, 10, 10])
Tokens: ['The', 'Ġcat', 'Ġsat', 'Ġon', 'Ġthe', 'Ġmat', 'Ġbecause', 'Ġit', 'Ġwas', 'Ġcomfortable']
All attention shape: (12, 12, 10, 10)
Code Fragment 18.1.16: Extracting attention weights from every layer and head of GPT-2 by enabling output_attentions=True. The resulting 4D array (layers, heads, query position, key position) provides the raw data for visualizing how the model routes information.
Hint

Use torch.stack(attentions).squeeze(1).numpy() to get a (12, 12, seq_len, seq_len) array for GPT-2 (12 layers, 12 heads).

Step 2: Create attention heatmaps

Visualize attention patterns for all 12 heads in layer 0 to see different attention strategies.

import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention_head(attn_matrix, tokens, layer, head, ax):
 sns.heatmap(attn_matrix, xticklabels=tokens, yticklabels=tokens,
 cmap='Blues', ax=ax, vmin=0, vmax=1)
 ax.set_title(f'L{layer}H{head}', fontsize=9)
 ax.tick_params(labelsize=6)

# TODO: Create a 3x4 grid for all 12 heads in layer 0
fig, axes = plt.subplots(3, 4, figsize=(20, 15))
for h in range(12):
 row, col = h // 4, h % 4
 plot_attention_head(all_attn[0, h], tokens, 0, h, axes[row, col])

plt.tight_layout()
plt.savefig('attention_heads_layer0.png', dpi=100)
plt.show()
Code Fragment 18.1.15: Plotting attention heatmaps for all 12 heads in layer 0 reveals that different heads learn distinct strategies: some attend locally (diagonal patterns), others focus on specific token types (vertical stripes), and still others distribute attention broadly.
Hint

Look for heads showing: diagonal lines (positional/local attention), vertical stripes (attending to specific tokens like periods or commas), or broad patterns (attending broadly). These reveal different computational strategies.

Step 3: Analyze coreference attention

Examine which token "it" attends to most strongly, looking for coreference resolution.

# Find token positions
it_pos = tokens.index("Ġit") if "Ġit" in tokens else tokens.index("it")
cat_pos = tokens.index("Ġcat") if "Ġcat" in tokens else tokens.index("cat")
print(f"'it' at position {it_pos}, 'cat' at position {cat_pos}")

# Find heads that show strong it-to-cat attention
print("\nAttention from 'it' to 'cat' by layer:")
for layer in range(len(attentions)):
    for head in range(12):
        attn_score = all_attn[layer, head, it_pos, cat_pos]
        if attn_score > 0.1:
            print(f" Layer {layer}, Head {head}: {attn_score:.3f}")
'it' at position 7, 'cat' at position 1

Attention from 'it' to 'cat' by layer:
 Layer 5, Head 1: 0.312
 Layer 5, Head 10: 0.187
 Layer 8, Head 6: 0.142
 Layer 10, Head 7: 0.223
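Code Fragment 18.1.14: Locating the positions of "it" and "cat" in the token list, then scanning every layer and head for attention from "it" to "cat" above a 0.1 threshold to surface candidate coreference heads.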
Hint

GPT-2 tokenization adds a "Ġ" prefix for tokens preceded by a space. If the exact token is not found, print tokens to see the exact tokenization.
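A fixed threshold is easier to interpret next to a baseline. With "it" at position 7, a head that spreads attention uniformly over its 8 visible tokens would already place 1/8 = 0.125 on "cat", which is above the 0.1 cutoff. The sketch below ranks heads by their weight for a (query, key) pair instead of thresholding. The helper top_heads is hypothetical, and the array is synthetic (uniform causal attention with one planted strong head), not real GPT-2 output.

```python
import numpy as np

def top_heads(all_attn, query_pos, key_pos, k=5):
    """Return the k (layer, head, weight) triples with the largest weight from query_pos to key_pos."""
    scores = all_attn[:, :, query_pos, key_pos]            # (layers, heads)
    flat_order = np.argsort(scores, axis=None)[::-1][:k]   # indices into the flattened array, largest first
    layers, heads = np.unravel_index(flat_order, scores.shape)
    return [(int(l), int(h), float(scores[l, h])) for l, h in zip(layers, heads)]

n_layers, n_heads, seq_len = 12, 12, 10
it_pos, cat_pos = 7, 1

# Synthetic data: every head attends uniformly over the visible tokens...
demo_attn = np.zeros((n_layers, n_heads, seq_len, seq_len))
for q in range(seq_len):
    demo_attn[:, :, q, : q + 1] = 1.0 / (q + 1)
# ...except one planted head that puts 0.6 of its weight from "it" on "cat".
planted_row = np.full(it_pos + 1, 0.4 / it_pos)
planted_row[cat_pos] = 0.6
demo_attn[5, 1, it_pos, : it_pos + 1] = planted_row

uniform_baseline = 1.0 / (it_pos + 1)
ranking = top_heads(demo_attn, it_pos, cat_pos, k=3)
print(f"uniform baseline: {uniform_baseline:.3f}")
for layer, head, weight in ranking:
    print(f"L{layer}H{head}: {weight:.3f} ({weight / uniform_baseline:.1f}x baseline)")
```

Replacing demo_attn with the real all_attn gives a ranked list in which each weight is reported relative to what undirected attention would produce at that position.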

Step 4: Build a linear probing classifier

Train a linear classifier on hidden states to detect part-of-speech tags, testing what syntactic information each layer encodes.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

probe_data = [
 ("The dog chased the cat", ["DET","NOUN","VERB","DET","NOUN"]),
 ("A bird flew over trees", ["DET","NOUN","VERB","ADP","NOUN"]),
 ("She quickly ran home today", ["PRON","ADV","VERB","NOUN","ADV"]),
 ("Big waves crashed on shore", ["ADJ","NOUN","VERB","ADP","NOUN"]),
 ("They slowly walked to school", ["PRON","ADV","VERB","ADP","NOUN"]),
 ("My friend reads many books", ["DET","NOUN","VERB","DET","NOUN"]),
 ("The tall man opened doors", ["DET","ADJ","NOUN","VERB","NOUN"]),
 ("We often eat fresh food", ["PRON","ADV","VERB","ADJ","NOUN"]),
]

def extract_hidden(model, tokenizer, text, layer):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)
    # hidden_states[0] is the embedding output; hidden_states[k] follows block k
    return out.hidden_states[layer][0].numpy()

# For each layer, collect hidden states and POS labels,
# then train a LogisticRegression probe and report accuracy.
# Note: tags are matched to tokens by index, which assumes that every
# word in probe_data maps to exactly one GPT-2 token.
for layer in [0, 3, 6, 9, 11]:
    X, y = [], []
    for text, pos_tags in probe_data:
        hidden = extract_hidden(model, tokenizer, text, layer)
        # Token strings, handy for checking the word/token alignment by eye
        toks = tokenizer.convert_ids_to_tokens(
            tokenizer(text)['input_ids'])
        for i, tag in enumerate(pos_tags):
            if i < hidden.shape[0]:
                X.append(hidden[i])
                y.append(tag)

    X = np.array(X)
    clf = LogisticRegression(max_iter=1000, random_state=42)
    scores = cross_val_score(clf, X, y, cv=3, scoring='accuracy')
    print(f"Layer {layer}: POS probe accuracy = "
          f"{scores.mean():.3f} (+/- {scores.std():.3f})")
Layer 0: POS probe accuracy = 0.475 (+/- 0.082)
Layer 3: POS probe accuracy = 0.712 (+/- 0.064)
Layer 6: POS probe accuracy = 0.837 (+/- 0.051)
Layer 9: POS probe accuracy = 0.863 (+/- 0.043)
Layer 11: POS probe accuracy = 0.825 (+/- 0.058)
Code Fragment 18.1.13: Training a logistic regression probe on hidden states at each layer to predict part-of-speech tags. Accuracy peaks in middle layers (6 and 9), showing that syntactic information is most accessible there before later layers shift focus toward next-token prediction.
Hint

Accuracy should rise from layer 0 through the middle layers (6 to 9), where syntactic features are most accessible, and dip slightly at layer 11. Earlier layers capture more surface-level features; later layers focus on next-token prediction.
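The probe in this step pairs the i-th tag with the i-th token, which only works when every word is a single token. When a word is split into several subword pieces, one common workaround is to take the hidden state at the word's last piece, since in a causal model only that position has seen the whole word. The sketch below shows the index bookkeeping in plain Python. The helper names are hypothetical, and the token list is hand-written for illustration rather than produced by the GPT-2 tokenizer.

```python
def word_end_indices(tokens, space_marker="Ġ"):
    """Index of the last subword token of each word in a GPT-2 style token list."""
    if not tokens:
        return []
    # A word starts at position 0 or at any token carrying the leading-space marker.
    starts = [i for i, tok in enumerate(tokens)
              if i == 0 or tok.startswith(space_marker)]
    # Each word ends just before the next word starts; the last word ends at the final token.
    return [s - 1 for s in starts[1:]] + [len(tokens) - 1]

def align_tags(tokens, tags, space_marker="Ġ"):
    """Pair each word-level tag with the index of that word's last subword token."""
    ends = word_end_indices(tokens, space_marker)
    if len(ends) != len(tags):
        raise ValueError(f"{len(ends)} words but {len(tags)} tags")
    return list(zip(ends, tags))

# Hand-written example of a word split into three pieces (illustrative only).
tokens_demo = ["The", "Ġun", "believ", "able", "Ġcat"]
tags_demo = ["DET", "ADJ", "NOUN"]
pairs = align_tags(tokens_demo, tags_demo)
print(pairs)  # [(0, 'DET'), (3, 'ADJ'), (4, 'NOUN')]
```

In the probing loop you would then index hidden[idx] for each (idx, tag) pair. This simple rule treats punctuation that follows a word without a space as part of that word, so sentences containing punctuation need extra handling.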

Expected Output

  • A grid of 12 attention heatmaps showing distinct head patterns (positional, broad, specific)
  • At least one head showing coreference attention from "it" to "cat"
  • POS probing accuracy rising from roughly 50% at layer 0 to roughly 85% at layers 6 to 9, with a small drop at layer 11

Stretch Goals

  • Implement the logit lens: project hidden states at each layer through the final unembedding matrix to see predicted tokens at each layer
  • Probe for other properties: named entities, sentence boundaries, or syntactic depth
  • Compare patterns between GPT-2 small and GPT-2 medium
Complete Solution
import torch, numpy as np, matplotlib.pyplot as plt, seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", output_attentions=True)
model.eval()

text = "The cat sat on the mat because it was comfortable"
inputs = tokenizer(text, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
with torch.no_grad(): outputs = model(**inputs)
all_attn = torch.stack(outputs.attentions).squeeze(1).numpy()

# Heatmaps
fig, axes = plt.subplots(3, 4, figsize=(20, 15))
for h in range(12):
    sns.heatmap(all_attn[0,h], xticklabels=tokens, yticklabels=tokens,
                cmap='Blues', ax=axes[h//4, h%4], vmin=0, vmax=1)
    axes[h//4, h%4].set_title(f'L0H{h}', fontsize=9)
    axes[h//4, h%4].tick_params(labelsize=6)
plt.tight_layout(); plt.savefig('attention_heads.png'); plt.show()

# Coreference
it_pos = tokens.index("Ġit") if "Ġit" in tokens else tokens.index("it")
cat_pos = tokens.index("Ġcat") if "Ġcat" in tokens else tokens.index("cat")
for l in range(12):
    for h in range(12):
        a = all_attn[l, h, it_pos, cat_pos]
        if a > 0.1: print(f"L{l}H{h}: it->cat = {a:.3f}")

# Probing
probe_data = [
 ("The dog chased the cat", ["DET","NOUN","VERB","DET","NOUN"]),
 ("A bird flew over trees", ["DET","NOUN","VERB","ADP","NOUN"]),
 ("She quickly ran home today", ["PRON","ADV","VERB","NOUN","ADV"]),
 ("Big waves crashed on shore", ["ADJ","NOUN","VERB","ADP","NOUN"]),
 ("They slowly walked to school", ["PRON","ADV","VERB","ADP","NOUN"]),
 ("My friend reads many books", ["DET","NOUN","VERB","DET","NOUN"]),
 ("The tall man opened doors", ["DET","ADJ","NOUN","VERB","NOUN"]),
 ("We often eat fresh food", ["PRON","ADV","VERB","ADJ","NOUN"]),
]

for layer in [0, 3, 6, 9, 11]:
    X, y = [], []
    for txt, tags in probe_data:
        inp = tokenizer(txt, return_tensors="pt")
        with torch.no_grad():
            out = model(**inp, output_hidden_states=True)
        h = out.hidden_states[layer][0].numpy()
        for i, tag in enumerate(tags):
            if i < h.shape[0]: X.append(h[i]); y.append(tag)
    scores = cross_val_score(LogisticRegression(max_iter=1000), np.array(X), y, cv=3)
    print(f"Layer {layer}: {scores.mean():.3f} +/- {scores.std():.3f}")
L5H1: it->cat = 0.312
L5H10: it->cat = 0.187
L8H6: it->cat = 0.142
L10H7: it->cat = 0.223
Layer 0: 0.475 +/- 0.082
Layer 3: 0.712 +/- 0.064
Layer 6: 0.837 +/- 0.051
Layer 9: 0.863 +/- 0.043
Layer 11: 0.825 +/- 0.058
Code Fragment 18.1.12: Complete reference solution combining attention extraction, heatmap visualization, coreference analysis, and linear probing into a single runnable script. Use this to verify your step-by-step results or as a starting point for further experiments.
Research Frontier

Attention analysis is being refined through causal interventions (activation patching) that go beyond correlational attention pattern visualization to identify which attention heads causally drive specific model behaviors. Research on representation probing is revealing that LLMs develop surprisingly structured internal representations, including linear representations of truth, spatial relationships, and temporal concepts. The frontier challenge is scaling probing techniques to the largest models, where the sheer number of components makes exhaustive analysis impractical and statistical methods for identifying important circuits become essential.

Exercises

Exercise 18.1.1: Why interpretability matters Conceptual

Give three concrete reasons why understanding how an LLM arrives at its outputs is important. For each reason, describe a scenario where lack of interpretability could cause harm.

Answer Sketch

1. Safety: if a model gives medical advice, we need to verify that its reasoning is sound, not just that the answer sounds plausible. Harm: a model confidently gives advice based on a drug interaction it hallucinated, and without interpretability no one catches the error until a patient is affected.
2. Debugging: when a model produces biased outputs, interpretability helps identify which training data or architectural patterns cause the bias. Harm: a hiring model systematically disadvantages certain groups, but without interpretability tools the cause cannot be diagnosed.
3. Trust and regulation: many jurisdictions require explainability for AI decisions. Harm: a financial model denies a loan but cannot explain why, violating right-to-explanation regulations.

Exercise 18.1.2: Probing classifiers Conceptual

Explain what a probing classifier is and how it is used to study what information is encoded in a model's hidden representations. What does it mean if a linear probe achieves high accuracy for a particular property?

Answer Sketch

A probing classifier is a simple model (typically linear) trained to predict a property (e.g., part of speech, sentiment, entity type) from a model's hidden states at specific layers. If a linear probe achieves high accuracy, it means the information about that property is encoded in a linearly separable way in the hidden states. This suggests the model has learned to represent that property explicitly. Caution: high probe accuracy does not prove the model uses this information for its predictions; it only shows the information is present. The simplicity of the probe matters: a complex probe might 'decode' information that is not really represented accessibly.
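Probe accuracy is easier to interpret next to a baseline. The sketch below compares a linear probe trained on real labels with the same probe trained on shuffled labels, using synthetic "hidden states" in which the class is linearly encoded by construction. To keep the example NumPy-only, a ridge-regularized least-squares probe stands in for logistic regression. The function names are hypothetical, and the shuffled-label baseline is a simpler sanity check than the control tasks discussed in the probing literature.

```python
import numpy as np

def fit_linear_probe(X, y, n_classes, ridge=1e-2):
    """Least-squares linear probe: regress one-hot labels on features plus a bias column."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    Y = np.eye(n_classes)[y]
    # Ridge-regularized normal equations: (Xb^T Xb + ridge * I) W = Xb^T Y
    return np.linalg.solve(Xb.T @ Xb + ridge * np.eye(Xb.shape[1]), Xb.T @ Y)

def probe_accuracy(W, X, y):
    """Fraction of examples whose highest-scoring class matches the label."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    return float(np.mean(np.argmax(Xb @ W, axis=1) == y))

rng = np.random.default_rng(0)
n, d, n_classes = 600, 32, 3
y = rng.integers(0, n_classes, size=n)
class_means = rng.normal(size=(n_classes, d))
# Synthetic features: class mean plus noise, so the label is linearly decodable.
X = class_means[y] + 0.5 * rng.normal(size=(n, d))

train, test = slice(0, 400), slice(400, None)

W_real = fit_linear_probe(X[train], y[train], n_classes)
real_acc = probe_accuracy(W_real, X[test], y[test])

# Baseline: same features, labels shuffled so they carry no information about X.
y_shuffled = rng.permutation(y)
W_ctrl = fit_linear_probe(X[train], y_shuffled[train], n_classes)
control_acc = probe_accuracy(W_ctrl, X[test], y_shuffled[test])

print(f"real labels:     {real_acc:.3f}")
print(f"shuffled labels: {control_acc:.3f}")
```

A large gap between the two numbers indicates that the probe is reading structure from the representation. As the answer sketch notes, it still does not show that the model itself uses that structure.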

Exercise 18.1.3: Attention pattern analysis Coding

Using a pre-trained model from Hugging Face, extract attention patterns for the sentence 'The doctor told the nurse that she was late.' Visualize attention from 'she' across all heads and layers. Does the model show evidence of resolving the pronoun correctly?

Answer Sketch

Use model(input_ids, output_attentions=True) and index into the attention matrices for the position of 'she'. Look for heads where 'she' strongly attends to either 'doctor' or 'nurse'. Some heads will show syntactic attention (attending to the nearest noun), while others may show coreference-like patterns. The 'correct' resolution is ambiguous in this sentence (could refer to either), so the interesting finding is whether different heads disagree, reflecting the ambiguity. Use matplotlib or bertviz for visualization.

Exercise 18.1.4: Logit lens technique Conceptual

The 'logit lens' applies the model's output unembedding matrix to intermediate layer representations. Explain what this reveals and why it is useful for understanding how a model builds up its prediction across layers.

Answer Sketch

The logit lens projects hidden states from intermediate layers into vocabulary space, showing what the model would predict if it stopped at that layer. In early layers, the projection may show the input token or generic high-frequency tokens. As you move through layers, the prediction gradually shifts toward the final answer. This reveals the 'computation trajectory': you can see at which layer the model commits to its answer, where factual knowledge is retrieved, and where syntactic structure is resolved. It is useful for debugging wrong predictions (at which layer does the model go astray?) and understanding the division of labor across layers.

Exercise 18.1.5: Limitations of interpretability methods Discussion

Critique two popular interpretability methods (attention visualization and feature attribution) by identifying their key assumptions and failure modes. Are we actually understanding the model, or just creating plausible-looking explanations?

Answer Sketch

Attention visualization assumes attention weights indicate importance, but attention is a learned routing mechanism, not an explanation. A head might attend strongly to a token because it needs to suppress it, not because it is important. Feature attribution (gradient-based or perturbation-based) assumes independence of features and measures local sensitivity, but neural networks use features combinatorially. Both methods provide post-hoc explanations that may not reflect the model's actual computation. The risk: interpretability methods may give practitioners false confidence that they understand the model, when in fact they are pattern-matching on visualizations that confirm their prior beliefs. More rigorous approaches (causal interventions, mechanistic interpretability) attempt to establish causal rather than correlational relationships.
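One concrete way attention weights can mislead: a head's output is the attention-weighted sum of value vectors, so a heavily attended token with a near-zero value vector contributes almost nothing. The toy example below uses made-up numbers for a single query position and a hypothetical helper, token_contribution_norms, to show the most-attended token and the largest contributor coming apart.

```python
import numpy as np

def token_contribution_norms(attn_row, values):
    """Norm of each key token's weighted value vector, i.e. its share of the head output."""
    weighted = attn_row[:, None] * values      # (keys, d_head)
    return np.linalg.norm(weighted, axis=1)

# One query attending over three key tokens (illustrative numbers, not from a real model).
attn_row = np.array([0.7, 0.2, 0.1])
values = np.array([
    [0.01, 0.0],   # token 0: heavily attended, near-zero value vector
    [2.0, 1.0],    # token 1: lightly attended, large value vector
    [0.5, -0.5],   # token 2
])

norms = token_contribution_norms(attn_row, values)
most_attended = int(np.argmax(attn_row))
largest_contribution = int(np.argmax(norms))
print(most_attended, largest_contribution)  # 0 1
```

This is still a correlational view of a single head. Establishing that a token matters for the final prediction requires intervening on the model, as in the causal approaches mentioned above.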

Tip: Use Logit Lens as a Quick Diagnostic

To understand what a model is "thinking" at intermediate layers, project hidden states through the final unembedding matrix. This logit lens technique takes minutes to implement and often reveals surprising early-layer predictions that help debug unexpected outputs.
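As a rough illustration of the mechanics, the sketch below applies a logit lens to a toy model in NumPy. Everything here is an assumption made for the example: the unembedding matrix W_U is random with orthonormal, zero-mean columns; the hidden states are constructed to drift from the input token's direction toward a target token's direction; and the layer norm omits the learned gain and bias that a real final layer norm would have.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Layer norm over the last axis, without learned gain or bias (a simplification)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def logit_lens(hidden_states, W_U, top_k=1):
    """Project (layers, seq, d) hidden states through W_U (d, vocab); return top-k token ids."""
    logits = layer_norm(hidden_states) @ W_U          # (layers, seq, vocab)
    return np.argsort(-logits, axis=-1)[..., :top_k]

rng = np.random.default_rng(0)
d_model, vocab = 8, 5
A = rng.normal(size=(d_model, vocab))
A -= A.mean(axis=0, keepdims=True)    # zero-mean columns, so layer norm only rescales them
W_U, _ = np.linalg.qr(A)              # orthonormal columns, shape (d_model, vocab)

# Toy residual stream: position 0 drifts from token 0 to token 3, position 1 from token 1 to token 4.
input_ids, target_ids = [0, 1], [3, 4]
mix = [0.0, 0.25, 0.75, 1.0]          # how far each "layer" has moved toward the target
hidden_states = np.stack([
    np.stack([(1 - t) * W_U[:, a] + t * W_U[:, b]
              for a, b in zip(input_ids, target_ids)])
    for t in mix
])                                    # (4 layers, 2 positions, d_model)

preds = logit_lens(hidden_states, W_U)[..., 0]
for layer, row in enumerate(preds):
    print(f"layer {layer}: top-1 token ids = {row.tolist()}")
```

In this toy setup the top prediction flips from the input token to the target token between the second and third layer, which is the kind of transition the logit lens is meant to reveal. With the Hugging Face GPT-2 model the equivalent projection for an intermediate hidden state h would be model.lm_head(model.transformer.ln_f(h)); treat that as a sketch to verify against your installed version.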

What Comes Next

In the next section, Section 18.2: Mechanistic Interpretability, we dive into mechanistic interpretability, reverse-engineering the algorithms that neural networks learn internally.

Bibliography
Attention Analysis

Clark, K., Khandelwal, U., Levy, O., & Manning, C. D. (2019). What Does BERT Look At? An Analysis of BERT's Attention. ACL Workshop BlackboxNLP.

Systematically maps BERT's attention heads to linguistic phenomena such as coreference, syntax, and separator tokens. This is the go-to empirical study for anyone wanting to understand what individual attention heads specialize in across layers.

Attention Patterns

Jain, S. & Wallace, B. C. (2019). Attention is not Explanation. NAACL 2019.

Demonstrates that attention weights are often only weakly correlated with gradient-based feature importance, and that alternative attention distributions (permuted or adversarially constructed) frequently yield nearly identical predictions. A critical read for anyone tempted to use raw attention maps as faithful explanations of model predictions.

Cautionary

Voita, E., Talbot, D., Moiseev, F., Sennrich, R., & Titov, I. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019.

Identifies three types of specialized attention heads (positional, syntactic, rare-word) and shows most heads can be pruned with minimal quality loss. Researchers interested in attention head taxonomy and model compression through head pruning should start here.

Head Specialization
Probing & Visualization

Hewitt, J. & Manning, C. D. (2019). A Structural Probe for Finding Syntax in Word Representations. NAACL 2019.

Introduces geometric probes that recover parse tree structure from embedding spaces via linear transformations. This paper is essential for understanding how syntactic information is geometrically encoded in contextualized representations.

Structural Probing

Belinkov, Y. (2022). Probing Classifiers: Promises, Shortcomings, and Advances. Computational Linguistics, 48(1), 207-219.

A thorough review of the probing methodology, covering the selectivity problem, control tasks, and information-theoretic alternatives. Anyone designing probing experiments should consult this paper to avoid common pitfalls and choose appropriate baselines.

Methodology Review

Vig, J. (2019). A Multiscale Visualization of Attention in the Transformer Model. ACL System Demonstrations.

Presents BertViz, a tool for interactive attention visualization at the head, model, and neuron levels. Practitioners who want to visually inspect attention patterns during debugging or analysis will find this tool immediately useful in their workflows.

Visualization Tool