Part 4: Training and Adapting
Chapter 16: Distillation and Merging

Knowledge Distillation for LLMs

The goal is not to build a smaller model, but to build a smaller model that remembers what the bigger one learned.

Distill, Memory-Retentive AI Agent
Big Picture

Knowledge distillation is the art of making small models behave like large ones. A 70B-parameter teacher model contains vast knowledge but is expensive to serve. By training a smaller student model to mimic the teacher's output distribution (not just its final answers), the student can inherit much of the teacher's capability at a fraction of the inference cost. This technique has produced some of the most remarkable results in the LLM space: Microsoft's Phi-3 models distilled from GPT-4 demonstrate that a 3.8B model can match models 10x its size. DeepSeek distilled its R1 reasoning model into compact variants that retain strong chain-of-thought abilities. The pre-training objectives from Section 06.2 explain how these teacher models acquired their knowledge in the first place.

Fun Note

The term "distillation" comes from chemistry: extracting the essence of a substance by heating it and collecting what evaporates. Hinton, who coined the term for neural networks, liked the metaphor because the teacher's soft probability distribution contains richer information than hard labels alone. When a teacher says "this is 70% cat, 25% tiger, 5% dog," the student learns that cats and tigers are related. Hard labels ("cat") throw away that nuance entirely.

Prerequisites

This section builds on fine-tuning from Section 14.1: When and Why to Fine-Tune and pre-training covered in Section 06.1: The Landmark Models.

A master chef (teacher model) guiding a junior chef (student model) to recreate a complex dish in a simpler way
Figure 16.1.1: Knowledge distillation: the master chef teaches the apprentice not just the recipe, but the subtle intuitions behind every step. The student's version might be smaller, but it captures the essence.

1. Classical Distillation Framework

1.1 The Teacher-Student Paradigm

Key Insight

Knowledge distillation mirrors a longstanding principle in pedagogy and cognitive apprenticeship theory (Collins, Brown, and Newman, 1989): experts transfer knowledge not just through explicit facts but through the structure of their reasoning. When a teacher model produces a probability distribution that assigns 0.6 to "happy" and 0.2 to "glad," the student learns not just the correct answer but the semantic neighborhood around it. This is analogous to how a master craftsperson teaches an apprentice not by listing rules but by demonstrating the full nuance of their decision-making process. In information-theoretic terms, the soft distribution is a much richer target than a hard label: a one-hot label identifies one token out of V and so carries at most log2(V) bits (where V is the vocabulary size), while a soft distribution additionally specifies how the teacher weights every alternative. Hinton's temperature parameter directly controls how much of this "dark knowledge" is revealed, by smoothing the distribution to expose inter-class relationships that would otherwise be hidden by the winner-take-all nature of low-temperature softmax.

Knowledge distillation (Hinton et al., 2015) trains a smaller "student" model to match the output probability distribution of a larger "teacher" model, rather than training the student solely on hard ground-truth labels. Recall from Chapter 06 that a language model produces logits (raw, unnormalized scores for each token in the vocabulary) which are then converted to probabilities via the softmax function. The key insight of distillation is that the teacher's probability distribution over all possible tokens contains far richer information than a single correct answer. When the teacher assigns 0.6 probability to "happy," 0.2 to "glad," and 0.1 to "joyful," these "soft" probabilities encode semantic relationships that hard labels cannot convey. Figure 16.1.2 shows this teacher-student training setup.

[Diagram: a training input feeds both a teacher (70B; large, expensive, accurate) and a student (7B; small, fast, efficient). The teacher supplies soft targets (T > 1). Distillation loss: L = α · KL(teacher_soft || student_soft) + (1 - α) · CE(labels, student_hard), where the first term is the soft target loss (T² scaled) and the second is the hard label loss (standard CE).]
Figure 16.1.2: The student learns from both the teacher's soft probability distribution and the ground-truth hard labels.
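The combined objective in Figure 16.1.2 can be sketched for a single token position in plain Python. This is a minimal illustration rather than a training-ready implementation: the function names, the three-token vocabulary, and the logit values are assumptions chosen for the example, and the T² factor is applied explicitly to the KL term.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label, T=2.0, alpha=0.5):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE(label, student)."""
    p = softmax(teacher_logits, T)  # teacher soft targets
    q = softmax(student_logits, T)  # student soft predictions
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    ce = -math.log(softmax(student_logits)[label])  # hard-label term at T=1
    return alpha * (T ** 2) * kl + (1 - alpha) * ce

# Hypothetical logits for a three-token vocabulary; token 0 is the true label
teacher = [5.0, 2.0, 0.5]
student = [3.0, 2.5, 1.0]
loss = distillation_loss(student, teacher, label=0)
# When the student already matches the teacher, only the hard-label term remains
matched = distillation_loss(teacher, teacher, label=0)
print(f"mismatched student: {loss:.4f}, matched student: {matched:.4f}")
```

Setting alpha to 1.0 trains on the teacher's soft targets alone, while alpha at 0.0 reduces the objective to ordinary supervised cross-entropy.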

1.2 Temperature and Soft Targets

The temperature parameter T controls how "soft" the teacher's output distribution becomes. At T=1 (normal softmax), the teacher's distribution is peaked on the most likely token. As T increases, the distribution becomes smoother, revealing the relative probabilities of less likely tokens. This "dark knowledge" in the non-top predictions encodes the teacher's understanding of semantic similarity and uncertainty.

Visualization of how temperature scaling transforms a sharp softmax distribution into a softer one, revealing dark knowledge in non-top predictions
Figure 16.1.5: Temperature scaling transforms the teacher's sharp output distribution into a softer one. At low temperatures the top prediction dominates; at higher temperatures the relative probabilities of alternative tokens become visible, transferring richer information to the student.

As Figure 16.1.5 illustrates, increasing the temperature gradually exposes the teacher's uncertainty across the full vocabulary, letting the student learn from the relationships between alternatives rather than just the top prediction.

Key Insight: The Master Chef and the Apprentice

Think of knowledge distillation as a master chef teaching an apprentice. The master (teacher model) does not just tell the apprentice the correct dish; they share the probability distribution over all possible dishes ('this is 70% likely a risotto, 20% a pilaf, 10% a paella'). These soft labels carry richer information than hard labels ('this is a risotto') because they reveal the master's uncertainty and the relationships between options. The apprentice (student model) learns not just the answers but the teacher's reasoning patterns, often achieving surprisingly close performance at a fraction of the size.

Fun Fact

Knowledge distillation in its modern form was introduced by Hinton, Vinyals, and Dean in 2015, building on earlier model-compression work by Buciluǎ, Caruana, and Niculescu-Mizil (2006). The core insight (that a teacher model's soft probability distribution carries more information than hard labels) is still the foundation of most distillation methods used today.

The softmax with temperature is computed as:

$$p_{i} = \frac{\exp(z_{i} / T)}{\sum_{j} \exp(z_{j} / T)}$$

where $z_{i}$ are the logits (pre-softmax values). Common temperature values range from 1.5 to 4.0. To see why temperature matters, consider logits [5.0, 2.0, 0.5] for three tokens. At T=1, softmax produces [0.94, 0.05, 0.01], a sharply peaked distribution in which the non-top tokens are barely visible. At T=2, the distribution softens to [0.75, 0.17, 0.08], making it clear that the second token is more plausible than the third. At T=4, we get [0.56, 0.26, 0.18], exposing even more of the teacher's uncertainty. Higher temperatures expose more of the teacher's knowledge but also introduce more noise. The distillation loss is scaled by T² to compensate for the gradient magnitude reduction caused by softening. Code Fragment 16.1.1 shows this approach in practice.
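The effect of temperature on the example logits can be recomputed with a few lines of standard-library Python. This is a small sketch; the function name softmax_with_temperature is chosen for illustration and is not part of any library.

```python
import math

def softmax_with_temperature(logits, T):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [5.0, 2.0, 0.5]
for T in (1.0, 2.0, 4.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: {[round(p, 2) for p in probs]}")
# T=1.0: [0.94, 0.05, 0.01]
# T=2.0: [0.75, 0.17, 0.08]
# T=4.0: [0.56, 0.26, 0.18]
```

The ranking of the tokens never changes with temperature; only the gap between them shrinks as T grows.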

# Compute the KL-divergence distillation loss between teacher and student
# Soft targets from the teacher transfer more information than hard labels
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Combined distillation and task loss for LLM training."""

    def __init__(self, temperature=2.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha  # Weight for distillation vs hard label loss
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

    def forward(self, student_logits, teacher_logits, labels):
        # Assumes labels are already aligned with the logits (shifted for
        # next-token prediction) and use -100 for positions to ignore
        T = self.temperature
        vocab_size = student_logits.size(-1)

        # Flatten to (num_tokens, vocab) so "batchmean" averages per token
        # instead of summing the KL term over the sequence dimension
        student_flat = student_logits.reshape(-1, vocab_size).float()
        teacher_flat = teacher_logits.reshape(-1, vocab_size).float()
        labels_flat = labels.reshape(-1)
        mask = labels_flat != -100  # Skip padding and other ignored tokens

        # Soft targets: soften both distributions with temperature
        student_soft = F.log_softmax(student_flat[mask] / T, dim=-1)
        teacher_soft = F.softmax(teacher_flat[mask] / T, dim=-1)

        # KL divergence loss (scaled by T^2)
        distill_loss = self.kl_loss(student_soft, teacher_soft) * (T ** 2)

        # Hard label cross-entropy loss
        hard_loss = F.cross_entropy(student_flat, labels_flat, ignore_index=-100)

        # Combined loss
        return self.alpha * distill_loss + (1 - self.alpha) * hard_loss
Code Fragment 16.1.1: Compute the KL-divergence distillation loss between teacher and student
Key Insight

The temperature parameter is critical. At T=1, the teacher's distribution is so peaked that it provides little more information than a hard label. At T=4, the distribution is smooth enough to reveal semantic relationships between tokens. However, too high a temperature washes out the signal entirely. Start with T=2.0 and tune based on validation performance. The T² scaling factor in the loss ensures consistent gradient magnitudes regardless of temperature.
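The role of the T² factor can be checked numerically. The sketch below uses the closed-form gradient of KL(teacher || student) with respect to the student logits, which is (q - p) / T for temperature-softened distributions p (teacher) and q (student). The logit values are illustrative assumptions, and the "roughly constant" behaviour of the rescaled gradient holds in the higher-temperature regime rather than exactly at every T.

```python
import math

def softmax(logits, T):
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_grad_norm(student_logits, teacher_logits, T):
    """L2 norm of d KL(p_T || q_T) / d student_logits, which equals (q - p) / T."""
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    grad = [(qi - pi) / T for pi, qi in zip(p, q)]
    return math.sqrt(sum(g * g for g in grad))

# Hypothetical logits for a three-token vocabulary
teacher = [5.0, 2.0, 0.5]
student = [3.0, 2.5, 1.0]
for T in (1.0, 2.0, 4.0, 8.0, 16.0):
    raw = kl_grad_norm(student, teacher, T)
    print(f"T={T:>4}: raw grad norm={raw:.4f}  after T^2 scaling={raw * T * T:.4f}")
```

In the printed table the raw gradient norm falls by about a factor of four each time T doubles, while the rescaled column stays within the same order of magnitude, which is what keeps the balance between the soft and hard loss terms stable when T is tuned.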

Real-World Scenario
Distilling GPT-4 into a Domain-Specific 1.5B Model

Who: ML platform team at an e-commerce company

Situation: Product search used GPT-4 API calls to rewrite user queries into structured filters (brand, size, color, price range), processing 2 million queries per day.

Problem: GPT-4 API costs exceeded $18,000 per month, and latency averaged 1.2 seconds per query, degrading the search experience.

Dilemma: A smaller fine-tuned model could reduce cost and latency, but the team lacked labeled training data for the structured query format.

Decision: They used black-box distillation: GPT-4 generated 200,000 query-to-filter pairs from production logs, then they fine-tuned Qwen 1.5B on that synthetic dataset.

How: The student model was trained with standard cross-entropy on GPT-4 outputs (hard labels only, since logits were unavailable). Temperature-scaled soft labels were approximated by asking GPT-4 to output top-5 alternative parsings with confidence scores.

Result: The distilled 1.5B model matched GPT-4 accuracy on 94% of queries, reduced latency to 45ms (roughly 27x faster), and cut monthly costs to $400 for self-hosted inference.

Lesson: Black-box distillation from API models is a practical path to production; generating diverse synthetic training data compensates for the lack of soft probability distributions.

Why does distillation work so well? The core intuition is that a teacher model's soft probability distribution contains far more information than hard labels. When a teacher assigns probabilities [0.6, 0.2, 0.1, 0.05, 0.05] to five candidate tokens, it is implicitly encoding semantic relationships: "happy" and "glad" are closely related (both get high probability), while "table" is semantically distant (near-zero probability). A student trained on these soft targets learns these relationships in every training step, effectively receiving a compressed version of the teacher's knowledge about language structure. Hard labels ("the answer is happy") convey none of this relational information. This is why distilled models can be surprisingly capable despite being 10x smaller: they inherit the teacher's understanding of token relationships, not just its final predictions.
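A small numerical sketch makes the relational argument concrete. Two hypothetical students both assign 0.6 to the correct token, so a hard label cannot tell them apart, but the soft-target loss strongly prefers the one that spreads its remaining probability over related words. The token list and the student distributions are assumptions invented for this illustration.

```python
import math

# Hypothetical five-token candidate set and teacher distribution
tokens = ["happy", "glad", "joyful", "cheerful", "table"]
teacher = [0.6, 0.2, 0.1, 0.05, 0.05]

# Both students give 0.6 to "happy" but distribute the rest differently
student_a = [0.6, 0.25, 0.1, 0.04, 0.01]  # leftover mass on related words
student_b = [0.6, 0.01, 0.04, 0.1, 0.25]  # leftover mass on "table"

def hard_label_ce(student, label_index):
    """Cross-entropy against a one-hot label."""
    return -math.log(student[label_index])

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

for name, student in (("A", student_a), ("B", student_b)):
    ce = hard_label_ce(student, tokens.index("happy"))
    kl = kl_divergence(teacher, student)
    print(f"student {name}: hard-label CE={ce:.3f}  KL to teacher={kl:.3f}")
```

The hard-label loss is identical for both students, whereas the KL term is roughly an order of magnitude larger for student B, so only the soft targets push the student away from confusing "happy" with "table".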

Key Insight

Distillation is fundamentally about compressing inference cost, not training cost. The distillation process itself is expensive (you must run the teacher model on your entire training set). The payoff comes at serving time, where the smaller student model delivers similar quality at a fraction of the latency and cost. This connects directly to inference optimization in Chapter 8: distillation and quantization are complementary techniques, and production deployments often apply both (distill first, then quantize the student).

2. White-Box vs. Black-Box Distillation

The distillation approach depends on what access you have to the teacher model. White-box distillation requires access to the teacher's internal logits; black-box distillation works only with the teacher's text outputs.

Side-by-side comparison of white-box distillation with full logit access versus black-box distillation using only text outputs from an API
Figure 16.1.6: White-box distillation gives the student access to the teacher's full probability distribution, while black-box distillation works only with the teacher's text outputs. The richer signal in white-box mode generally produces higher-quality students.

Figure 16.1.6 contrasts the two paradigms visually. The comparison table below summarizes their key differences in practice.

| Aspect | White-Box Distillation | Black-Box Distillation |
| --- | --- | --- |
| Teacher access | Full model weights and logits | API outputs (text only) |
| Loss signal | KL (Kullback-Leibler) divergence on full distribution | Cross-entropy on generated text |
| Information richness | Very high (full probability distribution) | Lower (only top-1 output) |
| Typical teachers | Open-weight models (Llama, Mistral) | API models (GPT-4, Claude) |
| Quality ceiling | Higher (more teacher knowledge transferred) | Lower (limited to surface behavior) |
| Scalability | Limited by GPU memory for teacher | Limited by API cost and rate limits |

In practice, most LLM distillation today is black-box, simply because the most capable teachers (GPT-4, Claude) are only available through APIs. The remainder of this section examines both approaches in detail.

2.1 White-Box Distillation

When you have full access to the teacher's weights and logits, white-box distillation transfers the maximum amount of information. The following code fragment shows a complete training loop that loads both teacher and student, generates soft targets from the teacher, and trains the student on the combined loss.

# Train a student on the teacher's logits (white-box distillation)
# Requires the DistillationLoss class from Code Fragment 16.1.1
import torch
from transformers import AutoModelForCausalLM
from torch.utils.data import DataLoader

def white_box_distillation(
    teacher_model_id: str,
    student_model_id: str,
    train_dataset,
    temperature: float = 2.0,
    alpha: float = 0.5,
    epochs: int = 3,
    lr: float = 2e-5,
):
    """Train student to match teacher logit distribution."""
    # Teacher and student must share a tokenizer so their logits line up
    # Load teacher (frozen, in eval mode)
    teacher = AutoModelForCausalLM.from_pretrained(
        teacher_model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    teacher.eval()
    for param in teacher.parameters():
        param.requires_grad = False

    # Load student (trainable)
    student = AutoModelForCausalLM.from_pretrained(
        student_model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    loss_fn = DistillationLoss(temperature=temperature, alpha=alpha)
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)
    dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    for epoch in range(epochs):
        student.train()
        total_loss = 0.0
        for batch in dataloader:
            input_ids = batch["input_ids"].to(student.device)
            labels = batch["labels"].to(student.device)

            # Get teacher logits (no gradient)
            with torch.no_grad():
                teacher_logits = teacher(input_ids=input_ids).logits

            # Get student logits
            student_logits = student(input_ids=input_ids).logits

            # Shift so position t is scored against token t+1 (assumes labels
            # mirror input_ids, with -100 marking positions to ignore)
            student_logits = student_logits[:, :-1, :].contiguous()
            teacher_logits = teacher_logits[:, :-1, :].to(student_logits.device).contiguous()
            shifted_labels = labels[:, 1:].to(student_logits.device).contiguous()

            # Compute combined loss
            loss = loss_fn(student_logits, teacher_logits, shifted_labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}: avg_loss = {total_loss/len(dataloader):.4f}")
    return student
Code Fragment 16.1.2: White-box distillation pipeline that loads a frozen teacher and a trainable student, then optimizes the student to match the teacher's logit distribution. The teacher remains in eval mode throughout to provide stable soft targets for each training batch.
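Holding a frozen teacher and a trainable student in memory at the same time is the main practical constraint of this setup. A back-of-the-envelope estimate, assuming bf16 weights and gradients and AdamW moment buffers stored in the same precision as the parameters (two buffers at 2 bytes each), can be sketched as follows. The byte counts are assumptions about the configuration, the helper names are invented for this example, and activations and framework overhead are ignored, so treat the results as lower bounds.

```python
def weights_gb(num_params, bytes_per_param=2):
    """Memory for inference-only weights (bf16 uses 2 bytes per parameter)."""
    return num_params * bytes_per_param / 1e9

def training_gb(num_params, weight_bytes=2, grad_bytes=2, optimizer_bytes=4):
    """Memory for weights, gradients, and two AdamW moment buffers."""
    return num_params * (weight_bytes + grad_bytes + optimizer_bytes) / 1e9

# Sizes taken from Figure 16.1.2: a 70B teacher and a 7B student
teacher_gb = weights_gb(70e9)
student_gb = training_gb(7e9)
print(f"teacher (frozen): {teacher_gb:.0f} GB")   # 140 GB
print(f"student (training): {student_gb:.0f} GB") # 56 GB
print(f"total before activations: {teacher_gb + student_gb:.0f} GB")
```

Under these assumptions the frozen teacher dominates the budget, which is why teacher logits are often precomputed and cached when the teacher cannot share hardware with the student.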

2.2 Black-Box Distillation

When the teacher is an API model (GPT-4, Claude, Gemini), you cannot access logits. Instead, you generate a large dataset of high-quality (input, output) pairs from the teacher, then fine-tune the student on these pairs using standard supervised training. The quality of the distilled student depends heavily on the diversity and quality of the generated training data. Code Fragment 16.1.3 shows this in practice.

# Generate teacher labels asynchronously for large-scale distillation
# Async batching maximizes throughput when labeling with an API-based teacher
import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI()

async def generate_distillation_data(
    prompts: list[str],
    teacher_model: str = "gpt-4o",
    system_prompt: str = "You are a helpful assistant.",
    max_concurrent: int = 10,
) -> list[dict]:
    """Generate training data from API teacher for black-box distillation."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def call_teacher(prompt):
        # The semaphore caps the number of requests in flight
        async with semaphore:
            response = await client.chat.completions.create(
                model=teacher_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.7,
                max_tokens=2048,
            )
            return {
                "instruction": prompt,
                "response": response.choices[0].message.content,
                "teacher": teacher_model,
            }

    tasks = [call_teacher(p) for p in prompts]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Filter out errors (failed calls come back as exception objects)
    valid = [r for r in results if isinstance(r, dict)]
    print(f"Generated {len(valid)}/{len(prompts)} training examples")
    return valid

# Generate dataset
# training_data = asyncio.run(generate_distillation_data(prompts))
# Then fine-tune student model on this data using standard SFT
Code Fragment 16.1.3: Asynchronous generation of (instruction, response) training pairs from an API teacher. A semaphore caps the number of concurrent requests, and failed calls are filtered out before the data is used for fine-tuning.
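The records produced above use instruction and response keys, while chat-style supervised fine-tuning (as in Code Fragment 16.1.4) expects a messages list per example. A minimal conversion sketch, assuming that record layout and using illustrative helper names (to_chat_example, build_sft_dataset, save_jsonl) that are not part of any library, might look like this:

```python
import json

def to_chat_example(record):
    """Convert one teacher record into the chat 'messages' format."""
    return {
        "messages": [
            {"role": "user", "content": record["instruction"]},
            {"role": "assistant", "content": record["response"]},
        ]
    }

def build_sft_dataset(records, min_response_chars=1):
    """Drop empty responses and duplicate prompts, then convert the rest."""
    seen = set()
    examples = []
    for record in records:
        response = (record.get("response") or "").strip()
        prompt = record["instruction"].strip()
        if len(response) < min_response_chars or prompt in seen:
            continue
        seen.add(prompt)
        examples.append(to_chat_example(record))
    return examples

def save_jsonl(examples, path):
    """Write one JSON object per line so the dataset can be streamed later."""
    with open(path, "w", encoding="utf-8") as f:
        for example in examples:
            f.write(json.dumps(example, ensure_ascii=False) + "\n")

# Hypothetical teacher outputs, including a duplicate and an empty response
records = [
    {"instruction": "Explain gradient descent.", "response": "Gradient descent is...", "teacher": "gpt-4o"},
    {"instruction": "Explain gradient descent.", "response": "It is an optimizer...", "teacher": "gpt-4o"},
    {"instruction": "Define overfitting.", "response": None, "teacher": "gpt-4o"},
]
examples = build_sft_dataset(records)
print(len(examples), "usable example(s)")
```

Deduplicating on the prompt is a deliberately simple choice; a production pipeline would more likely deduplicate on normalized or embedded text and add task-specific quality checks.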
Warning

Black-box distillation from proprietary API models raises important licensing considerations. Most API providers (OpenAI, Anthropic, Google) have terms of service that restrict using their outputs to train competing models. Always review the provider's usage policy before conducting distillation. Open-weight models (Llama, Mistral, Qwen) generally allow distillation, but check their specific licenses. The Llama 3 license, for example, allows derivative works but has specific restrictions for very large deployments.

3. Case Studies in LLM Distillation

3.1 Orca: Learning from Complex Explanations

Microsoft's Orca (2023) demonstrated that small models can dramatically improve by learning not just the teacher's answers but its reasoning process. Orca trained a 13B student on millions of examples from ChatGPT and GPT-4 that included detailed chain-of-thought explanations, step-by-step reasoning, and self-correction. The key innovations were: using system prompts to elicit rich explanations from the teacher, curating diverse and challenging prompts, and training the student on the full reasoning trace rather than just the final answer.

3.2 Phi Series: Textbook-Quality Data

Microsoft's Phi models (Phi-1, Phi-1.5, Phi-2, Phi-3) showed that data quality matters more than model size. Rather than distilling on conversational data, the Phi team used GPT-4 to generate "textbook-quality" synthetic training data: carefully structured explanations, worked examples, and exercises across diverse topics. Phi-3 (3.8B parameters) achieves performance competitive with much larger models on reasoning benchmarks, demonstrating that a small model trained on exceptional data can outperform a larger model trained on mediocre data.

3.3 Distilled DeepSeek-R1

DeepSeek distilled their large R1 reasoning model (671B MoE) into a family of smaller dense models (1.5B, 7B, 8B, 14B, 32B, 70B). The distillation process used 800K samples of the R1 teacher's chain-of-thought reasoning traces. The distilled models retain strong mathematical and coding reasoning abilities, with the 32B distilled variant outperforming many larger models on math benchmarks. This demonstrates that reasoning capabilities, which were previously thought to require enormous scale, can be effectively compressed through distillation. Figure 16.1.3 summarizes the key principles shared by successful distillation projects.

[Diagram: three case studies and their shared design principles. Orca (13B): matched ChatGPT on complex reasoning benchmarks. Phi-3 (3.8B): competitive with Llama-3 8B on MMLU and reasoning tasks. DeepSeek-R1 (32B): strong math reasoning from a 671B MoE teacher. Common design principles: (1) chain-of-thought: train on reasoning traces, not just final answers; (2) data quality: curated, diverse, challenging prompts yield better students; (3) scale of data: hundreds of thousands to millions of teacher examples; (4) system prompts: instruct the teacher to explain, reason, and show work.]
Figure 16.1.3: Successful distillation projects share common principles: rich reasoning traces, high-quality data, and diverse prompts.
Key Insight

The single most impactful lesson from distillation research is: distill the reasoning process, not just the answer. When a teacher model generates chain-of-thought explanations, step-by-step solutions, and self-corrections, the student learns much more effectively than from answer-only training data. This is why models like Orca and distilled DeepSeek-R1 dramatically outperform naive distillation approaches that only collect the teacher's final outputs.

4. Small-but-Capable Models

Distillation has enabled a new class of small models that achieve remarkable performance relative to their size. These models demonstrate that the right combination of architecture, training data, and distillation can produce efficient models for deployment on edge devices, mobile platforms, or high-throughput serving scenarios.

| Model Family | Sizes | Key Technique | Notable Capability |
| --- | --- | --- | --- |
| Phi (Microsoft) | 1.3B, 2.7B, 3.8B, 14B | Textbook-quality synthetic data | Strong reasoning for size |
| Gemma (Google) | 2B, 7B, 9B, 27B | Distilled from Gemini | Multilingual, coding |
| SmolLM (HF) | 135M, 360M, 1.7B | Curated web + synthetic data | Ultra-small deployment |
| Qwen2.5 (Alibaba) | 0.5B, 1.5B, 3B, 7B+ | Multi-stage distillation | Math, code, multilingual |
| Llama 3.2 (Meta) | 1B, 3B | Pruning + distillation | On-device, mobile |

Beyond Distillation: Training Efficient Small Models Directly

Not all small, capable models are created through distillation. Several complementary approaches produce efficient models by rethinking how small models are trained from scratch or derived from larger ones.

The common thread is the Chinchilla lesson applied at small scale (see Section 06.3): an undertrained 7B model is often worse than a well-trained 2B model. For deployment on edge devices and cost-sensitive applications, these approaches complement the distillation and quantization techniques from Section 09.1.

5. Practical Distillation Pipeline

Here is a complete pipeline that combines data generation from an API teacher with student training, demonstrating the end-to-end black-box distillation workflow. Code Fragment 16.1.4 shows this approach in practice.

# Load and prepare the distillation dataset with teacher-generated labels
# Each example pairs an input with the teacher model's text response
from datasets import Dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

# Step 1: Prepare distillation dataset
# (Assume we have generated data from teacher API)
distillation_data = [
    {
        "messages": [
            {"role": "user", "content": "Explain gradient descent."},
            {"role": "assistant", "content": "Gradient descent is an optimization..."},
        ]
    },
    # ... thousands more examples
]
dataset = Dataset.from_list(distillation_data)

# Step 2: Configure student with LoRA (parameter-efficient)
student_id = "meta-llama/Llama-3.2-3B-Instruct"
lora_config = LoraConfig(
    r=32,  # Higher rank for distillation
    lora_alpha=64,
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

# Step 3: Train student on teacher-generated data
sft_config = SFTConfig(
    output_dir="./distilled-student",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    max_seq_length=4096,  # Some newer TRL releases name this max_length
    logging_steps=10,
    save_strategy="epoch",
)
trainer = SFTTrainer(
    model=student_id,  # SFTTrainer loads the model and tokenizer from this ID
    args=sft_config,
    train_dataset=dataset,
    peft_config=lora_config,
)
trainer.train()
print("Distillation complete. Evaluate student against teacher.")
Code Fragment 16.1.4: Load and prepare the distillation dataset with teacher-generated labels
Note

For distillation, use a higher LoRA rank (32-64) than you would for standard fine-tuning. The student needs more capacity to absorb the teacher's knowledge. Also consider training for more epochs (3-5) with a larger and more diverse dataset. Distillation benefits from data scale more than standard fine-tuning because each example conveys information about the teacher's behavior across many dimensions.
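To get a feel for what a higher rank costs, the number of trainable parameters LoRA adds to one linear layer can be computed directly: a rank-r adapter on a d_out x d_in weight adds two matrices with r x d_in and d_out x r entries. The layer width below is a hypothetical value picked for illustration, not a statement about any particular model.

```python
def lora_params(d_in, d_out, rank):
    """Trainable parameters added by one LoRA adapter: A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)

d = 3072  # hypothetical width of a square projection layer
full = d * d  # parameters in the frozen base weight
for rank in (8, 32, 64):
    added = lora_params(d, d, rank)
    print(f"rank={rank:>2}: {added:,} trainable params ({added / full:.2%} of the layer)")
```

Trainable parameters grow linearly with the rank, so moving from rank 8 to rank 32 quadruples adapter capacity while still training only a small fraction of each layer.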

6. Distillation Licensing and Usage Policies

Before building a distillation pipeline, you must understand the legal constraints imposed by model providers. Most major LLM providers restrict or prohibit using their model outputs to train competing models. Violating these terms can result in service termination, legal liability, or both. The landscape varies significantly across providers.

6.1 Provider Policies (as of 2026)

| Provider | Output Usage for Training | Key Restriction |
| --- | --- | --- |
| OpenAI | Prohibited for competing models | You may not use outputs to "develop any artificial intelligence models that compete with our Products and Services." Fine-tuning through OpenAI's own API is permitted. |
| Anthropic | Restricted | Usage policy prohibits using outputs to train models that compete with Anthropic's services. Research and internal use cases should be reviewed against current terms. |
| Google (Gemini) | Restricted | Terms of Service prohibit using Gemini outputs to develop "competing generative AI products." Google Cloud enterprise agreements may include different terms. |
| Meta (Llama) | Permitted with conditions | The Llama open license permits distillation. However, licensees whose products or services exceed 700 million monthly active users must request a separate license. |
| Mistral | Permissive | Apache 2.0 licensed models have no output restrictions. Commercial models accessed through API may have separate terms. |
| DeepSeek | Permissive | MIT license places no restrictions on output usage for training. |

The practical implication is clear: if you plan to distill knowledge from a proprietary API into your own model, verify that the provider's terms permit it. For many production use cases, starting with an open-weight teacher (Llama, Mistral, DeepSeek) avoids licensing uncertainty entirely. When using a proprietary API, the safest approach is to use the provider's own fine-tuning service rather than extracting outputs for external training.

Warning

Licensing terms change frequently. Always check the current Terms of Service before starting a distillation project. A policy that was permissive six months ago may have been revised. Enterprise agreements sometimes override standard terms, so consult your legal team if the project involves proprietary API outputs.

7. Speculative Distillation

Speculative decoding uses a small "draft" model to propose multiple tokens at once, which a larger model then verifies in a single forward pass. Speculative distillation trains the draft model specifically to mimic the larger model's token distribution, improving the acceptance rate (how often the large model agrees with the draft) and thus the overall throughput. This technique turns distillation into a serving-time optimization rather than just a training-time technique. Figure 16.1.4 shows this propose-and-verify workflow.

[Diagram: a distilled 1B draft model quickly proposes four tokens (t1 to t4); the full 70B target model verifies all of them in a single forward pass. Result: 3 tokens accepted in 1 forward pass, roughly a 3x speedup.]
Figure 16.1.4: A distilled draft model proposes tokens that the target model verifies in parallel, multiplying throughput.
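The link between acceptance rate and throughput can be made concrete with a simple model. Assuming each draft token is accepted independently with the same probability a (a simplifying assumption that real workloads only approximate), a draft of k tokens yields an expected (1 - a^(k+1)) / (1 - a) tokens per target forward pass, counting the one token the target model always contributes itself. The function name is illustrative, and the estimate ignores the cost of running the draft model.

```python
def expected_tokens_per_pass(acceptance_rate, draft_len):
    """Expected tokens emitted per target forward pass under i.i.d. acceptance."""
    a = acceptance_rate
    if a >= 1.0:
        return float(draft_len + 1)  # every draft token accepted, plus one bonus token
    return (1 - a ** (draft_len + 1)) / (1 - a)

for a in (0.5, 0.6, 0.7, 0.8, 0.9):
    tokens = expected_tokens_per_pass(a, draft_len=4)
    print(f"acceptance rate {a:.1f}: about {tokens:.2f} tokens per target pass")
```

With four draft tokens, raising the acceptance rate from 0.6 to 0.8 lifts the expected yield from roughly 2.3 to roughly 3.4 tokens per pass, which is the gain that distilling the draft model toward the target's distribution is meant to capture.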

8. Distillation for Reasoning: Chain-of-Thought Preservation

Standard distillation transfers a teacher's input-output mapping to a smaller student, but this approach discards the teacher's intermediate reasoning. When a 70B model solves a math problem, it internally constructs a chain of logical steps before arriving at the answer. A student trained only on final answers learns what to output but not how to reason, limiting its ability to generalize to novel problems. Chain-of-thought (CoT) distillation addresses this by training the student to reproduce the teacher's explicit reasoning traces alongside the final answers.

The procedure is straightforward. First, the teacher model generates solutions to a large set of training problems using chain-of-thought prompting, producing step-by-step reasoning followed by a final answer. Second, these complete reasoning traces become the training targets for the student. The student learns to generate the full chain of thought, not just the conclusion. This approach was central to the Distilled DeepSeek-R1 models, where 800K reasoning traces from the R1 teacher were used to fine-tune Qwen and Llama base models. The resulting students (1.5B to 70B parameters) achieved reasoning performance competitive with models several times their size.

CoT distillation introduces a design choice: should the student be trained on the teacher's complete reasoning trace, or should the traces be filtered and curated? In practice, filtering is essential. Teacher models sometimes produce redundant steps, circular reasoning, or incorrect intermediate conclusions that happen to reach the right final answer. A robust pipeline filters traces by (1) verifying the final answer against a ground truth, (2) checking that intermediate steps are logically consistent, and (3) removing excessively long or repetitive reasoning chains. The Orca series demonstrated that curating high-quality explanations (with explicit reasoning structure: "First, ..., Therefore, ..., Finally, ...") produces better students than training on raw, unfiltered teacher outputs.
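The filtering steps above can be sketched as a small function. This is an illustrative outline under stated assumptions: each trace is a dictionary with steps, final_answer, and ground_truth keys (a layout invented for this example), the thresholds are arbitrary defaults, and the logical-consistency check in step (2) is left out because it typically needs a separate verifier model or task-specific rules.

```python
def normalize(answer):
    """Normalize an answer string for comparison."""
    return answer.strip().lower()

def repetition_ratio(steps):
    """Fraction of steps that duplicate an earlier step (0.0 means no repeats)."""
    if not steps:
        return 0.0
    unique = {s.strip().lower() for s in steps}
    return 1 - len(unique) / len(steps)

def keep_trace(trace, max_steps=30, max_repetition=0.2):
    """Apply the answer check (1) and the length and repetition checks (3)."""
    if normalize(trace["final_answer"]) != normalize(trace["ground_truth"]):
        return False  # (1) final answer does not match the ground truth
    if len(trace["steps"]) > max_steps:
        return False  # (3) excessively long chain
    if repetition_ratio(trace["steps"]) > max_repetition:
        return False  # (3) repetitive chain
    return True

# Hypothetical traces: one clean, one wrong, one repetitive
traces = [
    {"steps": ["2 + 3 = 5", "5 * 4 = 20"], "final_answer": "20", "ground_truth": "20"},
    {"steps": ["2 + 3 = 6", "6 * 4 = 24"], "final_answer": "24", "ground_truth": "20"},
    {"steps": ["Try again."] * 5, "final_answer": "20", "ground_truth": "20"},
]
kept = [t for t in traces if keep_trace(t)]
print(f"kept {len(kept)} of {len(traces)} traces")
```

Exact string matching on the final answer is the weakest part of this sketch; math and code tasks usually need a normalizing comparator or an execution-based check instead.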

A key finding from reasoning distillation research is that the student's reasoning format does not need to match the teacher's. A teacher that uses verbose, natural-language reasoning can be distilled into a student that produces compact, structured reasoning (numbered steps or symbolic notation) as long as the training data is reformatted accordingly. This flexibility allows practitioners to optimize the student's reasoning format for their deployment constraints: verbose reasoning for user-facing explanations, compact reasoning for latency-sensitive applications where the chain of thought is used internally but not shown to the user.

Key Insight

The critical success factor in CoT distillation is the quality and diversity of the reasoning traces, not the size of the dataset. Research consistently shows that 10K high-quality, verified reasoning traces outperform 100K unfiltered traces. Invest in building a pipeline that generates, verifies, and curates reasoning data from the teacher before scaling up the dataset size.

Self-Check
Q1: Why do soft targets (teacher probabilities with temperature) provide more useful training signal than hard labels?
Answer:
Soft targets encode the teacher's uncertainty and the relative similarity between possible outputs. When the teacher assigns 0.6 probability to "happy" and 0.2 to "glad," this reveals that these words are semantically related. Hard labels (one-hot vectors) provide only the correct answer and no information about the relationships between alternatives. This "dark knowledge" in the non-top predictions helps the student learn richer representations.
Q2: What is the difference between white-box and black-box distillation, and which produces higher-quality students?
White-box distillation has access to the teacher's full logit distribution and uses KL divergence loss. Black-box distillation only has access to the teacher's text outputs and uses standard cross-entropy loss. White-box generally produces higher-quality students because the full probability distribution contains far more information than a single output token. However, black-box is the only option when the teacher is an API model without logit access.
Q3: What was the key insight from Microsoft's Orca model that improved distillation quality?
Orca demonstrated that training the student on the teacher's full reasoning process (chain-of-thought explanations, step-by-step solutions, self-corrections) is far more effective than training on just the final answers. By using system prompts to elicit detailed explanations from GPT-4, Orca created training data that taught the student how to reason, not just what to answer.
Q4: Why does the distillation loss include a T-squared scaling factor?
Dividing the logits by T before the softmax shrinks the gradients of the soft-target term by a factor of roughly 1/T². Multiplying the KL divergence loss by T² compensates for this, so the soft-target gradients keep a comparable magnitude regardless of the temperature value. Without this scaling, raising T would shrink the soft-target term relative to the hard-label cross-entropy term, distorting the balance set by alpha and slowing down learning from the teacher.
Q5: How does speculative distillation differ from standard distillation in terms of its goal?
Standard distillation aims to create a small model that replaces the large model entirely for inference. Speculative distillation creates a small "draft" model that works alongside the large model during inference, proposing token candidates that the large model verifies in parallel. The goal is not replacement but acceleration: by training the draft model to closely match the target's distribution, more proposed tokens are accepted per verification step, increasing throughput by 2-4x while maintaining the exact output quality of the large model.
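To make the last answer concrete, the sketch below shows why a draft model that matches the target more closely gets more tokens accepted. It assumes the standard speculative-sampling acceptance rule, under which the expected acceptance probability for one drafted token is the sum over tokens of min(p, q), and it assumes each drafted token is accepted independently with the same probability. The distributions and the function names are made up for illustration.

```python
def acceptance_rate(target: list[float], draft: list[float]) -> float:
    """Expected probability that one drafted token is accepted.

    A token x drawn from the draft distribution q is accepted with
    probability min(1, p(x) / q(x)), so the expectation over x ~ q
    is sum_x min(p(x), q(x)).
    """
    return sum(min(p, q) for p, q in zip(target, draft))


def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens emitted per target verification pass.

    Assumes k drafted tokens per pass and an independent, constant
    per-token acceptance probability alpha (a simplifying assumption).
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1 - alpha ** (k + 1)) / (1 - alpha)


# Hypothetical next-token distributions over a 4-token vocabulary.
target = [0.70, 0.20, 0.05, 0.05]
weak_draft = [0.40, 0.30, 0.20, 0.10]
distilled_draft = [0.65, 0.22, 0.07, 0.06]

for name, draft in [("weak", weak_draft), ("distilled", distilled_draft)]:
    a = acceptance_rate(target, draft)
    print(f"{name:>9} draft: acceptance={a:.2f}, "
          f"tokens/step (k=4)={expected_tokens_per_step(a, 4):.2f}")
```

In this toy setting the distilled draft raises the acceptance rate from about 0.70 to about 0.95, which translates into noticeably more tokens per verification pass.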
Tip: Distill on Task-Specific Data

For knowledge distillation, use data from your target domain rather than generic pretraining data. A student model trained on 50K domain-specific examples from the teacher often outperforms one trained on 500K generic examples.

Key Takeaways

Lab: Distill a Teacher Model into a Student

Duration: ~60 minutes Advanced

Objective

Implement knowledge distillation by training a small student model (DistilGPT-2) to mimic the output distribution of a larger teacher model (GPT-2 Medium), using KL divergence on soft targets with temperature scaling.

What You'll Practice

  • Computing soft targets with temperature-scaled softmax
  • Implementing the KL divergence distillation loss
  • Balancing hard-label and soft-label loss components
  • Comparing student performance before and after distillation

Setup

The following cell installs the required packages and configures the environment for this lab.

pip install transformers torch datasets tqdm
Code Fragment 16.1.5: Install the required packages.

Steps

Step 1: Load teacher and student models

Load a larger teacher (frozen) and a smaller student (trainable).

# Load teacher (gpt2-medium, frozen) and student (distilgpt2, trainable).
# The student will learn to mimic the teacher's output distribution.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

teacher = AutoModelForCausalLM.from_pretrained("gpt2-medium").to(device)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

student = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

t_params = sum(p.numel() for p in teacher.parameters())
s_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {t_params:,} params")
print(f"Student: {s_params:,} params")
print(f"Compression: {t_params / s_params:.1f}x")
Teacher: 354,823,168 params
Student: 81,912,576 params
Compression: 4.3x
Code Fragment 16.1.6: Load teacher (gpt2-medium, frozen) and student (distilgpt2, trainable).
Hint

Both GPT-2 variants share the same tokenizer. The teacher must be frozen (eval mode, requires_grad=False) so only the student gets updated during training.

Step 2: Implement the distillation loss

Write the core loss that combines KL divergence on soft targets with cross-entropy on hard targets.

# Distillation loss: KL divergence on temperature-scaled soft targets
# blended with cross-entropy on hard (ground-truth) labels.
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=3.0, alpha=0.5):
    """Combined distillation loss with soft and hard targets."""
    # TODO: Compute soft targets from teacher (temperature-scaled softmax)
    # TODO: Compute soft predictions from student (temperature-scaled log-softmax)
    # TODO: KL divergence between student and teacher distributions
    # TODO: Standard cross-entropy on hard labels
    # TODO: Combine: alpha * T^2 * KL + (1-alpha) * CE
    pass

# Quick test: run this after completing the TODOs (the stub above returns None).
# The inputs are random, so the exact value varies from run to run.
b, s, v = 2, 10, 50257
loss = distillation_loss(torch.randn(b, s, v), torch.randn(b, s, v),
                         torch.randint(0, v, (b, s)))
print(f"Test loss: {loss.item():.4f}")
Test loss: 11.2438
Code Fragment 16.1.7: Distillation loss: KL divergence on temperature-scaled soft targets, blended with cross-entropy on hard (ground-truth) labels.
Hint

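Soft targets: F.softmax(teacher_logits / T, dim=-1). KL loss: F.kl_div(F.log_softmax(student_logits / T, dim=-1), soft_targets, reduction='batchmean'). With (batch, sequence, vocab) inputs, 'batchmean' divides only by the batch size, so the KL term is summed over sequence positions; flatten the logits to (batch × sequence, vocab) first if you want a per-token average on the same scale as the cross-entropy term. Hard loss: flatten the logits before calling F.cross_entropy, and for next-token prediction shift logits and labels by one position. Multiply KL by T² to keep gradients balanced across temperatures.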
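The T² factor can be checked numerically without any deep learning library. The sketch below uses the closed-form gradient of the temperature-scaled KL term with respect to the student logits, (q_i - p_i) / T, where p is the softened teacher distribution and q the softened student distribution. The logit values are arbitrary illustration values, and `kl_grad_norm` is a helper defined only for this example.

```python
import math


def softmax(logits, T):
    """Numerically stable temperature-scaled softmax."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_grad_norm(student_logits, teacher_logits, T):
    """L2 norm of the gradient of KL(p_T || q_T) w.r.t. the student logits.

    For a temperature-scaled softmax, the gradient for logit i is
    (q_i - p_i) / T, with p from the teacher and q from the student.
    """
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    grad = [(qi - pi) / T for pi, qi in zip(p, q)]
    return math.sqrt(sum(g * g for g in grad))


student = [2.0, 1.0, 0.1]    # arbitrary example logits
teacher = [3.0, 0.5, -1.0]

for T in (1, 5, 20, 50, 100):
    g = kl_grad_norm(student, teacher, T)
    print(f"T={T:>3}  |grad|={g:.6f}  T^2 * |grad|={T * T * g:.4f}")
```

The raw gradient norm collapses as T grows, while the T²-scaled norm levels off (at roughly 0.51 for these logits), which is the behavior the T² multiplier is meant to restore.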

Step 3: Run the distillation training loop

Load text data and train the student to mimic the teacher's outputs.

# Training loop: forward both teacher and student on each batch,
# compute the combined distillation loss, and update student only.
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = dataset.filter(lambda x: len(x['text'].strip()) > 50)

def tokenize(examples):
    return tokenizer(examples['text'], truncation=True,
                     max_length=128, padding='max_length')

tokenized = dataset.select(range(2000)).map(
    tokenize, batched=True, remove_columns=['text'])
tokenized.set_format('torch')
loader = DataLoader(tokenized, batch_size=8, shuffle=True)

optimizer = torch.optim.AdamW(student.parameters(), lr=5e-5)
student.train()

for epoch in range(2):
    total_loss = 0
    for batch in tqdm(loader, desc=f"Epoch {epoch+1}"):
        input_ids = batch['input_ids'].to(device)
        attn = batch['attention_mask'].to(device)
        # Padding reuses the EOS token, so mark padded positions with -100,
        # the default ignore_index of F.cross_entropy.
        labels = input_ids.masked_fill(attn == 0, -100)

        student_out = student(input_ids, attention_mask=attn)
        with torch.no_grad():
            teacher_out = teacher(input_ids, attention_mask=attn)

        loss = distillation_loss(
            student_out.logits, teacher_out.logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} avg loss: {total_loss / len(loader):.4f}")
Epoch 1: 100%|██████████| 250/250 [01:42<00:00, 2.44it/s]
Epoch 1 avg loss: 4.7261
Epoch 2: 100%|██████████| 250/250 [01:38<00:00, 2.53it/s]
Epoch 2 avg loss: 3.8934
Code Fragment 16.1.8: Training loop: forward both teacher and student on each batch, compute the combined distillation loss, and update the student only.
Hint

The teacher forward pass is wrapped in torch.no_grad() since we never update its weights. If you run out of memory, reduce batch size to 4.
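Because every sequence in this lab is padded to 128 tokens, a large share of positions carry no real text. One possible refinement, sketched below, averages both loss terms over real next-token positions only. This is a variant under stated assumptions, not the lab's reference implementation: the function name `masked_distillation_loss` and its signature are introduced here for illustration, and it assumes right-padded batches with a 0/1 attention mask.

```python
import torch
import torch.nn.functional as F


def masked_distillation_loss(student_logits, teacher_logits, input_ids,
                             attention_mask, temperature=3.0, alpha=0.5):
    """Distillation loss averaged over non-padding next-token positions."""
    # Logits at position t predict token t+1, so drop the last position
    # and align the labels and mask with the shifted targets.
    s = student_logits[:, :-1, :]
    t = teacher_logits[:, :-1, :]
    labels = input_ids[:, 1:]
    mask = attention_mask[:, 1:].to(s.dtype)
    n_tokens = mask.sum().clamp(min=1.0)

    # Per-position KL(teacher || student) on temperature-softened outputs.
    log_q = F.log_softmax(s / temperature, dim=-1)
    p = F.softmax(t / temperature, dim=-1)
    kl_per_pos = F.kl_div(log_q, p, reduction="none").sum(dim=-1)
    kl = (kl_per_pos * mask).sum() / n_tokens

    # Per-position cross-entropy on the hard next-token labels.
    ce_per_pos = F.cross_entropy(
        s.reshape(-1, s.size(-1)), labels.reshape(-1), reduction="none"
    ).view_as(labels)
    ce = (ce_per_pos * mask).sum() / n_tokens

    return alpha * temperature ** 2 * kl + (1 - alpha) * ce


# Toy check with random tensors (no pretrained models needed).
torch.manual_seed(0)
B, S, V = 2, 6, 11
ids = torch.randint(0, V, (B, S))
attn = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]])
s_logits = torch.randn(B, S, V, requires_grad=True)
t_logits = torch.randn(B, S, V)

loss = masked_distillation_loss(s_logits, t_logits, ids, attn)
loss.backward()
print(f"loss={loss.item():.4f}")
print("grad on padded positions:", s_logits.grad[0, 3:].abs().sum().item())
```

In the training loop this would be called as `masked_distillation_loss(student_out.logits, teacher_out.logits, input_ids, attn)`. Positions whose target is padding receive exactly zero gradient, so the student is not trained to reproduce the teacher's behavior on filler tokens.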

Step 4: Evaluate with perplexity

Compare perplexity for the original student, distilled student, and teacher.

# Evaluate perplexity: compare original student, distilled student,
# and teacher to quantify how much knowledge transferred.
import math

def compute_perplexity(model, eval_loader):
    model.eval()
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for batch in eval_loader:
            ids = batch['input_ids'].to(device)
            attn = batch['attention_mask'].to(device)
            # Exclude padding from the loss (-100 is the ignore index).
            labels = ids.masked_fill(attn == 0, -100)
            out = model(ids, attention_mask=attn, labels=labels)
            # out.loss is the mean over predicted (shifted) non-padding tokens.
            n_tokens = (labels[:, 1:] != -100).sum().item()
            total_loss += out.loss.item() * n_tokens
            total_tokens += n_tokens
    return math.exp(total_loss / total_tokens)

eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
eval_data = eval_data.filter(lambda x: len(x['text'].strip()) > 50)
eval_tok = eval_data.select(range(500)).map(
    tokenize, batched=True, remove_columns=['text'])
eval_tok.set_format('torch')
eval_loader = DataLoader(eval_tok, batch_size=8)

original = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
print(f"Teacher PPL: {compute_perplexity(teacher, eval_loader):.2f}")
print(f"Original student PPL: {compute_perplexity(original, eval_loader):.2f}")
print(f"Distilled student PPL: {compute_perplexity(student, eval_loader):.2f}")
Teacher PPL: 24.87
Original student PPL: 46.13
Distilled student PPL: 36.41
Code Fragment 16.1.9: Evaluate perplexity: compare original student, distilled student, and teacher to quantify how much knowledge transferred.
Hint

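Typical results: teacher ~25, original student ~45, distilled student ~35; exact numbers vary with random seeds, hardware, and library versions. The distilled student should show a perplexity improvement of 10 to 20% or slightly more (the sample run above improves by about 21%, from 46.13 to 36.41) while remaining about 4.3x smaller than the teacher.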

Expected Output

  • Distilled student perplexity improves by roughly 10 to 20% over the original student (about 21% in the sample run)
  • The student remains ~4.3x smaller than the teacher
  • Training loss should decrease steadily over 2 epochs

Stretch Goals

  • Experiment with temperatures (T=1, 3, 5, 10) and plot how each affects the distilled student's perplexity
  • Try different alpha values (0.3, 0.5, 0.7) and find the optimal balance
  • Implement intermediate layer distillation: also match hidden states between teacher and student layers
Complete Solution
# Complete distillation lab: load teacher/student, train with KL
# divergence loss, evaluate perplexity before and after distillation.
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch, torch.nn.functional as F, math
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
teacher = AutoModelForCausalLM.from_pretrained("gpt2-medium").to(device)
teacher.eval()
for p in teacher.parameters(): p.requires_grad = False
student = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def distillation_loss(s_logits, t_logits, labels, T=3.0, alpha=0.5):
    soft_t = F.softmax(t_logits / T, dim=-1)
    soft_s = F.log_softmax(s_logits / T, dim=-1)
    kl = F.kl_div(soft_s, soft_t, reduction='batchmean')
    # Logits at position t predict token t+1, so shift before the hard-label
    # loss. Labels equal to -100 (padding) are ignored by cross_entropy.
    V = s_logits.size(-1)
    ce = F.cross_entropy(s_logits[:, :-1].reshape(-1, V), labels[:, 1:].reshape(-1))
    return alpha * (T**2) * kl + (1 - alpha) * ce

ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
ds = ds.filter(lambda x: len(x['text'].strip()) > 50)
def tok(ex): return tokenizer(ex['text'], truncation=True, max_length=128, padding='max_length')
train_tok = ds.select(range(2000)).map(tok, batched=True, remove_columns=['text'])
train_tok.set_format('torch')
loader = DataLoader(train_tok, batch_size=8, shuffle=True)

opt = torch.optim.AdamW(student.parameters(), lr=5e-5)
student.train()
for epoch in range(2):
    total = 0
    for batch in tqdm(loader, desc=f"Epoch {epoch+1}"):
        ids, attn = batch['input_ids'].to(device), batch['attention_mask'].to(device)
        labels = ids.masked_fill(attn == 0, -100)  # ignore padding in the CE term
        s_out = student(ids, attention_mask=attn)
        with torch.no_grad(): t_out = teacher(ids, attention_mask=attn)
        loss = distillation_loss(s_out.logits, t_out.logits, labels)
        loss.backward(); opt.step(); opt.zero_grad(); total += loss.item()
    print(f"Epoch {epoch+1}: {total/len(loader):.4f}")

def ppl(model, dl):
    model.eval(); tl, tt = 0.0, 0
    with torch.no_grad():
        for b in dl:
            i, a = b['input_ids'].to(device), b['attention_mask'].to(device)
            lab = i.masked_fill(a == 0, -100)      # exclude padding from the loss
            o = model(i, attention_mask=a, labels=lab)
            n = (lab[:, 1:] != -100).sum().item()  # predicted non-padding tokens
            tl += o.loss.item()*n; tt += n
    return math.exp(tl/tt)

ev = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
ev = ev.filter(lambda x: len(x['text'].strip()) > 50)
ev_t = ev.select(range(500)).map(tok, batched=True, remove_columns=['text'])
ev_t.set_format('torch')
el = DataLoader(ev_t, batch_size=8)
orig = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
print(f"Teacher: {ppl(teacher, el):.2f}")
print(f"Original: {ppl(orig, el):.2f}")
print(f"Distilled: {ppl(student, el):.2f}")
Epoch 1: 100%|██████████| 250/250 [01:42<00:00, 2.44it/s]
Epoch 1: 4.7261
Epoch 2: 100%|██████████| 250/250 [01:38<00:00, 2.53it/s]
Epoch 2: 3.8934
Teacher: 24.87
Original: 46.13
Distilled: 36.41
Code Fragment 16.1.10: Complete distillation lab: load teacher/student, train with KL divergence loss, evaluate perplexity before and after distillation.
Research Frontier

The success of distillation in creating small reasoning models (like DeepSeek-R1-Distill and Phi-4-mini) has demonstrated that chain-of-thought capabilities can transfer from large to small models more effectively than previously assumed. Research on progressive distillation applies multiple rounds of compression, gradually shrinking models while preserving capabilities at each stage.

An open frontier is distilling not just model outputs but internal representations and reasoning patterns, with techniques like feature distillation showing promise for preserving capabilities that output matching alone misses.

Exercises

Exercise 16.1.1: Distillation fundamentals Conceptual

Explain the difference between hard labels and soft labels (the teacher's output probabilities) in knowledge distillation. Why do soft labels transfer more knowledge than hard labels?

Answer Sketch

Hard labels are one-hot vectors (e.g., [0, 0, 1, 0] for class 3). Soft labels are the teacher model's full probability distribution (e.g., [0.05, 0.10, 0.80, 0.05]). Soft labels convey inter-class relationships: a '7' that looks like a '1' gets a higher probability for '1' than for '8'. This dark knowledge teaches the student model about similarity structures that hard labels completely discard. The temperature parameter controls how much this dark knowledge is emphasized.

Exercise 16.1.2: Temperature in distillation Conceptual

What is the role of the temperature parameter T in distillation? What happens at T=1, T=5, and T=20?

Answer Sketch

Temperature T softens the teacher's probability distribution: p_i = exp(z_i/T) / sum(exp(z_j/T)). At T=1: standard softmax, probabilities are peaked on the correct class. At T=5: distribution is smoother, revealing the teacher's uncertainty and inter-class similarity. At T=20: very flat distribution, nearly uniform, which may lose too much signal. Typical values: T=2 to T=5. Higher T transfers more dark knowledge but may dilute the primary signal. The student's loss is a weighted sum of distillation loss (soft targets) and standard cross-entropy (hard targets).
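The effect of T on the teacher distribution is easy to see numerically. The sketch below applies the formula from the answer to a set of made-up teacher logits and reports the entropy of each softened distribution; the logit values and helper names are illustrative assumptions.

```python
import math


def softmax_with_temperature(logits, T):
    """p_i = exp(z_i / T) / sum_j exp(z_j / T), computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [6.0, 3.0, 1.0, 0.0]   # hypothetical teacher logits for 4 classes

for T in (1, 5, 20):
    probs = softmax_with_temperature(logits, T)
    rounded = ", ".join(f"{p:.3f}" for p in probs)
    print(f"T={T:>2}: [{rounded}]  entropy={entropy(probs):.3f}")
```

At T=1 almost all mass sits on the top class, at T=5 the ranking of the alternatives becomes clearly visible, and at T=20 the distribution is close to uniform (whose entropy for four classes is ln 4 ≈ 1.386).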

Exercise 16.1.3: Distillation pipeline Coding

Write the key components of a distillation training loop: forward pass through teacher (no grad), forward pass through student, combined loss computation with temperature scaling.

Answer Sketch

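Teacher: with torch.no_grad(): teacher_logits = teacher(input_ids).logits. Student: student_logits = student(input_ids).logits. Soft loss: soft_loss = F.kl_div(F.log_softmax(student_logits/T, dim=-1), F.softmax(teacher_logits/T, dim=-1), reduction='batchmean') * T*T. Hard loss: hard_loss = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1)); the logits are flattened because causal-LM logits have shape (batch, sequence, vocab), and for next-token prediction the logits and labels are shifted by one position first. Combined: loss = alpha * soft_loss + (1-alpha) * hard_loss. Typical alpha=0.5, T=3.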

Exercise 16.1.4: On-policy vs. off-policy distillation Analysis

Compare on-policy distillation (student generates, teacher scores) with off-policy distillation (teacher generates, student learns). When does each approach produce better results?

Answer Sketch

Off-policy: the teacher generates responses that the student learns to imitate via SFT. Simple to implement but the student only sees the teacher's distribution, not its own mistakes. On-policy: the student generates responses, the teacher provides token-level or sequence-level feedback, and the student learns from its own errors. On-policy is more compute-intensive but typically produces stronger students because the training data matches the student's own distribution. On-policy is preferred for larger distribution gaps between teacher and student.

Exercise 16.1.5: Student architecture selection Coding

You want to distill a 70B teacher into a smaller student. Compare distilling into a 7B model (same architecture, fewer layers) versus a 3B model (different architecture). What factors determine the choice?

Answer Sketch

Consider: (1) Target inference budget: a 3B model needs roughly 2.3x less weight memory than a 7B model (7/3 ≈ 2.3) and, when decoding is memory-bandwidth-bound, is faster by a similar factor. (2) Quality ceiling: 7B retains more of the teacher's capability (larger models are better students). (3) Task complexity: simple tasks (classification, extraction) work with 3B; complex reasoning benefits from 7B. (4) Architecture compatibility: same-architecture distillation can use layer-by-layer matching losses in addition to output-level distillation. Choose 7B for quality-critical tasks, 3B for cost/latency-critical deployment.
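The memory side of this trade-off is simple arithmetic. The sketch below estimates weight memory only, assuming 2 bytes per parameter (fp16 or bf16); it ignores the KV cache, activations, and runtime overhead, and the helper name `weight_memory_gb` is introduced just for this example.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Approximate memory for model weights only, in gigabytes (1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


models = {"70B teacher": 70e9, "7B student": 7e9, "3B student": 3e9}
for name, n in models.items():
    print(f"{name:>12}: ~{weight_memory_gb(n):.0f} GB of weights at 16-bit")

ratio = models["7B student"] / models["3B student"]
print(f"7B vs 3B weight-memory ratio: {ratio:.2f}x")
```

Under these assumptions the teacher needs about 140 GB for weights, the 7B student about 14 GB, and the 3B student about 6 GB, which is often the difference between multi-GPU serving and a single consumer GPU.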

What Comes Next

In the next section, Section 16.2: Model Merging & Composition, we explore model merging and composition, combining multiple fine-tuned models without additional training.

References & Further Reading
Foundational Papers

Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. NIPS Deep Learning Workshop.

The paper that introduced knowledge distillation using soft targets from a teacher network. Essential reading for understanding how temperature scaling on softmax outputs transfers richer information than hard labels alone.


West, P., Bhagavatula, C., Hessel, J., et al. (2022). Symbolic Knowledge Distillation: from General Language Models to Commonsense Models. NAACL 2022.

Demonstrates distillation of structured commonsense knowledge from large LMs into smaller specialized models. Shows how symbolic representations can bridge teacher and student architectures.

LLM Distillation Methods

Gu, Y., Dong, L., Wei, F., & Huang, M. (2024). MiniLLM: Knowledge Distillation of Large Language Models. ICLR 2024.

Proposes reverse KL divergence for LLM distillation, avoiding the mode-covering problem of standard KL. Practitioners working on distilling autoregressive models should start here.


Agarwal, R., Vieillard, N., Zhou, Y., et al. (2024). On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes. ICLR 2024.

Introduces on-policy distillation where the student learns from its own generated outputs, corrected by the teacher. Addresses the train/inference mismatch that plagues offline distillation approaches.


Hsieh, C.-Y., Li, C.-L., Yeh, C.-K., et al. (2023). Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes. ACL 2023.

Shows that extracting rationales from a teacher and using them as additional training signals lets student models outperform teachers that are 700x larger. A key result for practical chain-of-thought distillation.

Quantization & Compression

Kim, S., Gholami, A., Yao, Z., Mahoney, M. W., & Keutzer, K. (2021). I-BERT: Integer-only BERT Quantization. ICML 2021.

Presents integer-only quantization for transformer models, enabling efficient deployment on hardware without floating-point support. Useful context for understanding how distillation and quantization complement each other.
