"A proof is a proof. What kind of a proof? It's a proof. A proof is a proof, and when you have a good proof, it's because it's proven."
Chinchilla, Tautologically Rigorous AI Agent
Prerequisites
This section builds on the reasoning model architectures from Section 8.2, the RL training methods from Section 8.3 (especially RLVR and process reward models), and the evaluation framework from Section 8.5. Familiarity with formal logic or proof assistants (Lean, Isabelle, Coq) is helpful but not required; the section introduces the necessary concepts.
Formal theorem proving represents the gold standard for verifiable reasoning. Unlike natural language chain-of-thought, where each step might contain subtle errors, a formal proof is checked mechanically by a proof assistant: every step is either valid or rejected. This section explores how LLMs are being combined with proof assistants like Lean 4, Isabelle, and Metamath to produce reasoning that is not just plausible but provably correct. The key insight is that proof assistants provide a perfect verifier, eliminating the reward model bottleneck that limits test-time compute scaling for informal reasoning (as discussed in Section 8.5). When your verifier has zero false positives, more search always helps.
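To make the "more search always helps" claim concrete: if a single sampled proof attempt succeeds with probability p, the chance that at least one of n independent attempts succeeds is 1 - (1 - p)^n, and with a zero-false-positive verifier every accepted attempt can be trusted. A minimal sketch (the function name is ours):

```python
# With a perfect verifier, best-of-n success is 1 - (1 - p)^n:
# any proof the verifier accepts is genuinely correct, so more
# samples can only help (there are no false positives to filter).
def best_of_n_success(p: float, n: int) -> float:
    """Probability that at least one of n i.i.d. attempts succeeds."""
    return 1.0 - (1.0 - p) ** n

for n in [1, 10, 100, 1000]:
    print(f"n={n:5d}: success={best_of_n_success(0.01, n):.3f}")
```

Even a prover with a 1% per-attempt success rate passes 63% of problems at n=100, which is why search budget dominates formal proving results.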
1. The Intersection of LLMs and Formal Logic
Natural language reasoning and formal theorem proving occupy opposite ends of a spectrum. Natural language is flexible, expressive, and ambiguous. Formal logic is rigid, precise, and mechanically verifiable. The emerging research program at their intersection asks: can LLMs serve as the "intuition engine" that guides formal proof search, while proof assistants serve as the "verification engine" that guarantees correctness?
This combination addresses a fundamental limitation of LLM reasoning. When an LLM produces a chain-of-thought solution to a math problem, each step is generated autoregressively with no guarantee of logical validity. The model might skip steps, make sign errors, or apply theorems incorrectly. Process reward models (from Section 8.3) attempt to catch these errors, but they are themselves imperfect. In formal theorem proving, the proof assistant checks every step with mathematical certainty. A completed formal proof is correct by construction.
The challenge is that formal proofs are much harder to write than informal ones. A single page of informal mathematics might expand to hundreds of lines of Lean 4 code, with each logical step made fully explicit. LLMs can bridge this gap by translating mathematical intuition into formal tactic sequences, effectively serving as a "copilot" for the proof assistant.
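As a toy illustration of this gap (our own example, not drawn from mathlib): a proof that a mathematician would state in one sentence becomes, in Lean 4, an explicit case split that the kernel checks step by step.

```lean
-- Informal: "0 + n = n, by induction on n."  One sentence.
-- Formal Lean 4: every case is explicit and machine-checked.
theorem zero_add' (n : Nat) : 0 + n = n := by
  induction n with
  | zero => rfl
  | succ n ih => rw [Nat.add_succ, ih]
```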
The formal proving setup creates an ideal reinforcement learning environment. The proof assistant provides a binary, ground-truth reward signal (proof accepted or rejected) with no reward hacking possible. This is qualitatively different from training on informal math benchmarks, where the reward model can be gamed. Combined with self-play and MCTS (as in Section 8.5), this creates a feedback loop where the model generates proof attempts, the verifier filters for correctness, and the model improves from its own successes.
2. LeanDojo: Data Extraction and Retrieval-Augmented Proving
LeanDojo (Yang et al., 2023) addresses a foundational bottleneck in LLM-based theorem proving: the lack of high-quality training data extracted from proof assistants. Lean 4's mathlib library contains over 100,000 theorems, but extracting the proof states, tactic applications, and premise dependencies in a machine-readable format requires deep integration with the Lean compiler. LeanDojo provides this integration, along with a benchmark and a retrieval-augmented prover called ReProver.
2.1 Data Extraction from Lean
LeanDojo instruments the Lean 4 compiler to capture proof states at every tactic step. For each tactic application, it records the goal state before and after, the tactic text, and all premises (previously proven lemmas) referenced by the tactic. This produces a dataset of (state, tactic, premises) triples that can be used for supervised fine-tuning.
```python
# Conceptual illustration of LeanDojo data extraction.
# LeanDojo extracts proof states from Lean 4's mathlib;
# each entry captures a proof step with full context.
from dataclasses import dataclass
from typing import List

@dataclass
class ProofState:
    """A single proof state extracted from Lean 4."""
    goal: str              # Current proof goal in Lean syntax
    hypotheses: List[str]  # Available hypotheses
    tactic: str            # Tactic applied at this step
    premises: List[str]    # Lemmas referenced by the tactic
    result_goal: str       # Goal after tactic application

# Example in the style of a mathlib extraction
example_state = ProofState(
    goal="forall (n : Nat), n + 0 = n",
    hypotheses=[],
    tactic="intro n; induction n with | zero => rfl | succ n ih => simp [Nat.add_succ, ih]",
    premises=["Nat.add_succ", "Nat.succ_eq_add_one"],
    result_goal="no goals",  # proof complete
)

print(f"Goal: {example_state.goal}")
print(f"Tactic: {example_state.tactic}")
print(f"Premises used: {example_state.premises}")
```
2.2 ReProver: Retrieval-Augmented Theorem Proving
A key challenge in automated theorem proving is premise selection: given a proof goal, which of the 100,000+ available lemmas are relevant? Human mathematicians solve this by memory and intuition. ReProver solves it with retrieval. Given a proof state, ReProver encodes the goal with a transformer and retrieves the most relevant premises from the full mathlib library using dense retrieval (similar to the embedding-based retrieval in Section 18.1). The retrieved premises are prepended to the prompt for the tactic generator.
```python
# ReProver-style retrieval-augmented proving pipeline:
#   1. Encode the proof state
#   2. Retrieve relevant premises from mathlib
#   3. Generate a tactic conditioned on state + premises
import torch
import torch.nn.functional as F

class PremiseRetriever:
    """Dense retrieval of relevant lemmas for a proof goal."""

    def __init__(self, encoder, premise_embeddings, premise_names):
        self.encoder = encoder                # Transformer encoder
        self.premise_db = premise_embeddings  # Pre-computed embeddings
        self.premise_names = premise_names    # Lemma names

    def retrieve(self, goal_text: str, top_k: int = 20):
        """Retrieve the top-k most relevant premises for a goal."""
        # Encode the proof goal
        goal_embedding = self.encoder.encode(goal_text)
        # Compute similarity against all premises
        similarities = F.cosine_similarity(
            goal_embedding.unsqueeze(0),
            self.premise_db,
            dim=-1,
        )
        # Return the top-k premises with their scores
        top_indices = similarities.argsort(descending=True)[:top_k]
        return [
            (self.premise_names[i], similarities[i].item())
            for i in top_indices
        ]

class ReProverPipeline:
    """Simplified ReProver: retrieve premises, then generate a tactic."""

    def __init__(self, retriever, tactic_generator):
        self.retriever = retriever
        self.generator = tactic_generator

    def prove_step(self, proof_state: str) -> str:
        # Step 1: Retrieve relevant premises
        premises = self.retriever.retrieve(proof_state, top_k=20)
        premise_text = "\n".join(
            f"-- {name} (sim={score:.3f})" for name, score in premises
        )
        # Step 2: Generate a tactic conditioned on state + premises
        prompt = f"Relevant lemmas:\n{premise_text}\n\nGoal:\n{proof_state}\n\nTactic:"
        return self.generator.generate(prompt)
```
3. Formal Mathematics Benchmarks
Evaluating LLM-based theorem provers requires benchmarks with formal statements that can be checked by a proof assistant. Three benchmarks have become standard in this space: miniF2F for competition-level mathematics, ProofNet for undergraduate-level problems, and the broader LeanDojo benchmark derived from mathlib.
3.1 miniF2F: Cross-System Competition Mathematics
miniF2F (Zheng et al., 2022) contains 488 formalized mathematical statements drawn from the AMC, AIME, and IMO competitions, along with exercises at a comparable level. Each problem is available in multiple proof assistant languages (Lean, Isabelle, Metamath), enabling direct comparison across systems. The statements center on algebra and number theory at high-school olympiad difficulty, a level that remains challenging for automated provers and for most human solvers.
The cross-system aspect of miniF2F is particularly valuable. Because the same mathematical statement is formalized in multiple systems, researchers can isolate the effect of the proof assistant (Lean vs. Isabelle) from the effect of the LLM and search strategy. In practice, Lean 4 has become the dominant target because of its larger training corpus (mathlib) and better tooling.
3.2 ProofNet: Undergraduate Formal Mathematics
ProofNet (Azerbayev et al., 2023) fills the gap between simple arithmetic verification and competition-level problems. It contains 371 formalized exercises from standard undergraduate textbooks covering real analysis, abstract algebra, topology, and linear algebra. Each problem pairs a natural language statement with its formal counterpart in Lean (originally formalized in Lean 3; Lean 4 ports are now available).
ProofNet is designed to test autoformalization: the ability to translate a natural language mathematical statement into a formal one that a proof assistant can check. This is a prerequisite skill for any system that aims to assist human mathematicians, since most mathematical communication happens in natural language.
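An autoformalization pair might look like the following (an illustrative example in the spirit of ProofNet, not an actual benchmark item):

```lean
-- NL statement: "The sum of two even natural numbers is even."
-- Formal Lean 4 counterpart (evenness unfolded as ∃ k, · = 2 * k):
theorem even_add_even (a b : Nat)
    (ha : ∃ k, a = 2 * k) (hb : ∃ k, b = 2 * k) :
    ∃ k, a + b = 2 * k := by
  obtain ⟨m, hm⟩ := ha
  obtain ⟨n, hn⟩ := hb
  exact ⟨m + n, by rw [hm, hn, Nat.mul_add]⟩
```

The hard part is not proving such statements but producing a formalization that faithfully captures the informal meaning; a statement that type-checks can still formalize the wrong claim.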
```python
# Evaluation framework for formal proving benchmarks
import json
from dataclasses import dataclass, field
from typing import Dict, List, Optional

@dataclass
class FormalProvingResult:
    """Result of attempting to prove a formal statement."""
    statement_id: str
    proved: bool
    num_attempts: int
    proof_text: Optional[str]
    time_seconds: float

@dataclass
class BenchmarkEvaluation:
    """Aggregate evaluation across a formal proving benchmark."""
    results: List[FormalProvingResult] = field(default_factory=list)

    def pass_at_k(self, k: int) -> float:
        """Compute pass@k: fraction proved within k attempts."""
        proved = sum(
            1 for r in self.results if r.proved and r.num_attempts <= k
        )
        return proved / len(self.results) if self.results else 0.0

    def premise_retrieval_accuracy(
        self,
        predicted: Dict[str, List[str]],
        gold: Dict[str, List[str]],
    ) -> float:
        """Evaluate premise retrieval quality (average recall)."""
        recalls = []
        for stmt_id, gold_premises in gold.items():
            pred_premises = set(predicted.get(stmt_id, []))
            if gold_premises:
                recall = len(
                    pred_premises.intersection(gold_premises)
                ) / len(gold_premises)
                recalls.append(recall)
        return sum(recalls) / len(recalls) if recalls else 0.0

    def summary(self) -> dict:
        return {
            "total": len(self.results),
            "proved": sum(1 for r in self.results if r.proved),
            "pass@1": self.pass_at_k(1),
            "pass@10": self.pass_at_k(10),
            "pass@100": self.pass_at_k(100),
            "avg_time": sum(r.time_seconds for r in self.results)
            / max(len(self.results), 1),
        }

# Example evaluation
eval_run = BenchmarkEvaluation(results=[
    FormalProvingResult("minif2f_001", True, 3, "intro n; omega", 12.5),
    FormalProvingResult("minif2f_002", True, 1, "simp [add_comm]", 2.1),
    FormalProvingResult("minif2f_003", False, 100, None, 180.0),
    FormalProvingResult("minif2f_004", True, 47, "ring", 95.3),
])
print(json.dumps(eval_run.summary(), indent=2))
```
| Benchmark | Size | Difficulty | Systems | Primary Use |
|---|---|---|---|---|
| miniF2F | 488 problems | Competition (AMC/AIME/IMO) | Lean, Isabelle, Metamath | Cross-system comparison |
| ProofNet | 371 problems | Undergraduate | Lean 4 | Autoformalization |
| LeanDojo Bench | ~98,000 theorems | Mixed (mathlib) | Lean 4 | Premise selection, tactic prediction |
4. AlphaProof: Reinforcement Learning for Formal Mathematics
DeepMind's AlphaProof (2024) demonstrated that combining LLMs with reinforcement learning and formal verification can solve competition-level mathematics. At the 2024 International Mathematical Olympiad (IMO), AlphaProof solved three of the six problems; together with AlphaGeometry 2, which solved the geometry problem, the combined system scored 28 of 42 points, one point short of gold and equivalent to a silver medal. This was a landmark result: the first time an AI system performed at medal level on the full IMO, widely considered one of the hardest mathematical competitions in the world.
AlphaProof's architecture combines three components. First, a language model (Gemini) translates informal problem statements into formal Lean 4 specifications. Second, a value network (trained via self-play) estimates the probability of completing a proof from any given state. Third, an MCTS-style search (building on the AlphaZero framework) explores the space of possible tactic applications, guided by the value network and a policy network that suggests promising tactics.
The training loop follows a self-play paradigm. AlphaProof generates proof attempts for a large library of formalized problems, the Lean verifier checks which proofs are valid, and the model is updated via reinforcement learning on the successful proofs. Over many iterations, the model learns increasingly sophisticated proving strategies. Crucially, because the verifier is perfect (a valid Lean proof is mathematically correct by definition), there is no reward hacking: every proof the model learns from is genuinely correct.
```python
# AlphaProof-style self-play training loop (conceptual).
# Combines MCTS search with formal verification.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

@dataclass
class ProofNode:
    """A node in the proof search tree."""
    state: str   # Current Lean proof state
    tactic: str  # Tactic that led to this state
    visits: int = 0
    value: float = 0.0  # Estimated probability of proof completion
    children: List["ProofNode"] = field(default_factory=list)

class AlphaProofSearch:
    """MCTS-guided proof search with formal verification."""

    def __init__(self, policy_net, value_net, lean_verifier):
        self.policy = policy_net       # Suggests tactics
        self.value = value_net         # Estimates proof completion prob
        self.verifier = lean_verifier  # Lean 4 type checker
        self.c_puct = 1.5              # Exploration constant

    def search(self, root_goal: str, simulations: int = 800) -> Optional[str]:
        """Run MCTS to find a proof for the given goal."""
        root = ProofNode(state=root_goal, tactic="")
        for _ in range(simulations):
            # Selection: traverse the tree using UCB
            node, path = self._select(root)
            # Expansion: generate candidate tactics
            tactics = self.policy.suggest_tactics(node.state, k=32)
            for tactic in tactics:
                # Apply the tactic in Lean and get the resulting state
                result = self.verifier.apply_tactic(node.state, tactic)
                if result.is_proof_complete:
                    return self._extract_proof(path, tactic)
                if result.is_valid:
                    child = ProofNode(state=result.new_goal, tactic=tactic)
                    child.value = self.value.estimate(result.new_goal)
                    node.children.append(child)
            # Backpropagation: push the best child estimate up the path
            leaf_value = max((c.value for c in node.children), default=0.0)
            self._backpropagate(path, leaf_value)
        return None  # No proof found within budget

    def _select(self, root: ProofNode) -> Tuple[ProofNode, list]:
        """Select a leaf node using UCB."""
        node = root
        path = [node]
        while node.children:
            # Bind the current parent via a default argument
            node = max(node.children, key=lambda c, p=node: self._ucb(c, p))
            path.append(node)
        return node, path

    def _ucb(self, child: ProofNode, parent: ProofNode) -> float:
        exploitation = child.value
        exploration = self.c_puct * (parent.visits ** 0.5) / (1 + child.visits)
        return exploitation + exploration

    def _backpropagate(self, path: list, leaf_value: float):
        # Update visit counts and running value averages along the path
        for node in reversed(path):
            node.visits += 1
            node.value += (leaf_value - node.value) / node.visits

    def _extract_proof(self, path: list, final_tactic: str) -> str:
        tactics = [n.tactic for n in path if n.tactic] + [final_tactic]
        return "\n".join(tactics)
```
AlphaProof's architecture is not publicly available in full detail. The description here is based on DeepMind's blog post and the clear parallels with AlphaZero. The key architectural decisions (MCTS + value network + formal verifier) are well established; the proprietary details concern the specific model architecture, training data, and search hyperparameters. Reproducing the full system would require significant compute, but the conceptual framework can be applied at smaller scale using open tools like LeanDojo and ReProver.
5. Dataset Extraction and Proof Assistant Workflows
Building a formal proving system requires extracting structured data from proof assistant libraries. The workflow involves three stages: compiling the library to capture proof states, parsing tactic applications into (state, action, result) triples, and splitting data to avoid leakage between train and test sets.
```python
# Lean 4 data extraction workflow (simplified).
# In practice, LeanDojo handles the Lean compiler integration.
import random
from dataclasses import dataclass
from typing import List

@dataclass
class TacticStep:
    """One step in a Lean proof, extracted from the compiler trace."""
    theorem_name: str
    step_index: int
    goal_before: str
    tactic_text: str
    goal_after: str
    premises_used: List[str]

def extract_proof_data(lean_project_path: str) -> List[TacticStep]:
    """
    Extract tactic-level proof data from a Lean 4 project.
    In practice, this uses LeanDojo's tracing infrastructure,
    which instruments the Lean compiler to capture proof states.
    """
    steps = []
    # LeanDojo traces the Lean compilation process and captures
    # every tactic application with its context.
    # Pseudocode for the extraction loop (trace_lean_project is
    # a stand-in for LeanDojo's tracing API):
    traced_files = trace_lean_project(lean_project_path)
    for file_data in traced_files:
        for theorem in file_data["theorems"]:
            for i, step in enumerate(theorem["proof_steps"]):
                steps.append(TacticStep(
                    theorem_name=theorem["name"],
                    step_index=i,
                    goal_before=step["goal_before"],
                    tactic_text=step["tactic"],
                    goal_after=step["goal_after"],
                    premises_used=step["premises"],
                ))
    return steps

def split_by_theorem(
    steps: List[TacticStep], test_fraction: float = 0.1
) -> dict:
    """
    Split data by theorem (not by step) to prevent leakage:
    all steps from one theorem go to the same split.
    """
    theorem_names = sorted(set(s.theorem_name for s in steps))
    # Seeded shuffle: deterministic but not alphabetical
    random.Random(0).shuffle(theorem_names)
    split_point = int(len(theorem_names) * (1 - test_fraction))
    train_theorems = set(theorem_names[:split_point])
    train_steps = [s for s in steps if s.theorem_name in train_theorems]
    test_steps = [s for s in steps if s.theorem_name not in train_theorems]
    return {"train": train_steps, "test": test_steps}

# Example usage
print("Extraction pipeline stages:")
print("  1. Instrument the Lean compiler with LeanDojo tracing")
print("  2. Compile mathlib to capture all proof states")
print("  3. Parse (goal, tactic, premises) triples")
print("  4. Split by theorem to avoid train/test leakage")
print("  5. Build a premise retrieval index over all lemma statements")
```
6. RL and Self-Play Approaches to Formal Reasoning
The formal proving setting is uniquely well suited to reinforcement learning because the proof assistant provides a perfect reward signal. Three RL paradigms have been applied to formal theorem proving:
- Expert iteration (ExIt): The model generates proof attempts, the verifier filters for valid proofs, and the model is fine-tuned on the successful proofs. This is repeated iteratively, with each generation producing harder proofs that serve as training data for the next round. This is conceptually similar to the STaR bootstrapping method described in Section 8.3.
- MCTS with learned value functions: As in AlphaProof, a value network estimates proof completion probability from any state, guiding tree search toward promising tactic sequences. The value network is trained from the outcomes of previous searches (proofs found vs. search failures).
- Online RL with proof feedback: The model receives reward only when a complete proof is verified by Lean. This sparse reward signal is challenging to optimize, but curriculum learning (starting with easy theorems and progressing to harder ones) helps. GRPO (from Section 8.3) can be applied directly, with the proof verifier as the reward function.
```python
# Expert iteration for formal theorem proving
class ExpertIterationTrainer:
    """
    Train a theorem prover via expert iteration.
    Each round: generate proofs, verify, fine-tune on successes.
    """

    def __init__(self, prover_model, verifier, theorem_set):
        self.model = prover_model
        self.verifier = verifier
        self.theorems = theorem_set
        self.proof_buffer = []  # Accumulated successful proofs

    def run_iteration(self, num_attempts: int = 64) -> dict:
        """One round of expert iteration."""
        new_proofs = []
        for theorem in self.theorems:
            for attempt in range(num_attempts):
                # Generate a proof attempt
                proof = self.model.generate_proof(
                    theorem.statement,
                    temperature=0.8,
                    max_tactics=50,
                )
                # Verify with the proof assistant
                if self.verifier.check(theorem.name, proof):
                    new_proofs.append({
                        "theorem": theorem.statement,
                        "proof": proof,
                        "difficulty": theorem.difficulty,
                    })
                    break  # Move on to the next theorem
        # Fine-tune on newly discovered proofs
        self.proof_buffer.extend(new_proofs)
        self.model.fine_tune(self.proof_buffer)
        return {
            "theorems_attempted": len(self.theorems),
            "proofs_found": len(new_proofs),
            "success_rate": len(new_proofs) / len(self.theorems),
            "buffer_size": len(self.proof_buffer),
        }

# Typical expert iteration progression (illustrative numbers):
# Round 1: 20% of theorems proved (easy ones)
# Round 2: 35% proved (medium difficulty unlocked)
# Round 3: 45% proved (harder proofs from composed tactics)
# Round 4: 52% proved (diminishing returns begin)
```
Formal theorem proving with LLMs requires significant compute even at small scale. Compiling mathlib takes several hours, and each proof search involves multiple calls to the Lean type checker. For experimentation, start with a small subset of theorems (miniF2F's validation split has 244 problems) and limit search budget to 100 tactic attempts per problem. Scale up only after validating your pipeline on this smaller set.
7. Evaluation Metrics for Formal Proving
Evaluating formal provers uses metrics distinct from those for informal reasoning. The key metrics are:
- pass@k: The probability that at least one of k independent proof attempts succeeds. This is the standard metric, directly analogous to pass@k in code generation (Section 29.2). Typical values reported: pass@1, pass@10, pass@100.
- Proof validity: Every reported proof must be machine-checked by the target proof assistant. Unlike informal math evaluation, there are no "approximately correct" results. The proof either type-checks or it does not.
- Premise retrieval accuracy: For retrieval-augmented provers, the recall of relevant premises among the top-k retrieved. High premise recall is a prerequisite for proof success, since missing a critical lemma makes the proof impossible.
- Search efficiency: The number of tactic attempts or wall-clock time required to find a proof. Two systems with the same pass@100 may differ dramatically in how quickly they find proofs; the more efficient system is preferable for interactive use.
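Note that published pass@k numbers are usually computed with the unbiased estimator from the Codex paper (Chen et al., 2021) rather than by literally running k attempts: sample n ≥ k attempts, count c successes, and estimate the probability that a random size-k subset contains at least one success. A sketch:

```python
# Unbiased pass@k estimator (Chen et al., 2021): the probability
# that a random size-k subset of n attempts contains at least one
# of the c successful ones, i.e. 1 - C(n-c, k) / C(n, k).
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Estimate pass@k from n sampled attempts with c successes."""
    if n - c < k:
        return 1.0  # every size-k subset must contain a success
    return 1.0 - comb(n - c, k) / comb(n, k)

print(pass_at_k(100, 5, 1))   # exactly 0.05
print(pass_at_k(100, 5, 10))  # larger: more draws, more chances
```

This matters for comparisons: estimating pass@10 from n=100 samples has much lower variance than averaging over independent batches of 10.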
| System | pass@1 | pass@100 | Approach |
|---|---|---|---|
| ReProver (LeanDojo) | 26.5% | ~48% | Retrieval-augmented tactic generation |
| GPT-4 + Lean | ~29% | ~52% | Few-shot prompting with search |
| AlphaProof | N/A | ~65% (est.) | MCTS + RL + formal verification |
| DeepSeek-Prover-V2 | ~35% | ~60% | Expert iteration + subgoal decomposition |
The numbers above should be treated as approximate. Formal proving benchmarks are sensitive to search budget (more attempts always help when the verifier is perfect), the version of mathlib used (which affects available premises), and whether the system has seen the test problems during training. Direct comparisons across papers should account for these variables. AlphaProof's results are estimated from its IMO performance, as detailed benchmark numbers have not been publicly released.
8. Open Frontiers and Practical Implications
Formal proving with LLMs remains an active research frontier with several open challenges:
- Autoformalization: Translating informal mathematical text into formal Lean statements is a prerequisite for applying formal provers to real mathematical problems. Current systems achieve roughly 25% accuracy on ProofNet's autoformalization task, far below what is needed for practical use.
- Scaling to research-level mathematics: IMO problems, while difficult, follow well-established patterns. Open research problems require novel proof strategies that go beyond pattern matching on existing mathlib proofs.
- Integration with informal reasoning: The most promising direction may be hybrid systems that use informal chain-of-thought reasoning to develop a proof sketch, then formalize and verify each step. This mirrors how human mathematicians work: intuition first, rigor second.
- Applications beyond mathematics: Formal verification of software (using systems like Coq, Agda, or F*) shares the same proof search structure. LLM-based provers could accelerate the verification of critical software systems, from cryptographic protocols to operating system kernels.
Key Takeaways
- Formal theorem proving provides a perfect verification signal, making it an ideal testbed for RL-based reasoning. Unlike informal math, where reward models are approximate, a proof assistant accepts or rejects with mathematical certainty.
- LeanDojo extracts structured training data from Lean 4's mathlib, and its ReProver system demonstrates that retrieval-augmented premise selection significantly improves proving success.
- miniF2F and ProofNet provide standardized benchmarks at competition and undergraduate levels, enabling reproducible evaluation of formal proving systems.
- AlphaProof demonstrated that MCTS combined with RL and formal verification can solve IMO-level problems, achieving silver medal performance at the 2024 competition.
- Expert iteration and self-play are effective training paradigms for formal proving, bootstrapping from easy proofs to progressively harder ones over multiple rounds.
- The key open challenges are autoformalization (translating informal to formal math), scaling to research-level problems, and extending formal verification beyond pure mathematics to software and logical reasoning.
Open Questions:
- Can LLM-based provers discover genuinely novel mathematical results, or are they limited to reproving known theorems using known techniques?
- How can autoformalization accuracy be improved to make formal proving accessible to mathematicians who do not know Lean?
- Will formal verification become a standard component of LLM training, providing ground-truth reward signals for reasoning beyond mathematics (code correctness, logical consistency)?
Explore Further: Install Lean 4 and LeanDojo, extract the proof data for a small mathlib file, and train a simple tactic predictor. Measure pass@1 and pass@10 on the miniF2F validation split to establish a baseline for your own experiments.
Hands-On Lab: Generate Text with GPT-2
Objective
Implement greedy decoding and nucleus (top-p) sampling from scratch using raw GPT-2 logits, then compare your results with the transformers.pipeline("text-generation") shortcut to see how a few lines of library code replace dozens of manual steps.
What You'll Practice
- Loading a pretrained language model and extracting raw logits
- Implementing greedy decoding with an autoregressive loop
- Implementing nucleus (top-p) sampling with temperature scaling
- Comparing manual decoding against the HuggingFace pipeline abstraction
Setup
Install the transformers library and PyTorch. A GPU is not required; GPT-2 (124M parameters) runs comfortably on CPU.
pip install transformers torch
Steps
Step 1: Load GPT-2 and implement greedy decoding
Load the model, feed a prompt, and repeatedly select the highest-probability next token until you reach a length limit.
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

def greedy_decode(prompt, max_new_tokens=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]  # logits for the last position
        next_token = logits.argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

print("=== Greedy Decoding ===")
print(greedy_decode("The future of artificial intelligence is"))
```
Hint
Greedy decoding always picks the single most likely token. This produces deterministic output but often leads to repetitive text. Notice how the model tends to loop on the same phrases.
Step 2: Implement nucleus (top-p) sampling
Add temperature scaling and top-p filtering to produce more diverse, natural text.
```python
import torch.nn.functional as F

def nucleus_sample(prompt, max_new_tokens=50, temperature=0.8, top_p=0.9):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits[:, -1, :] / temperature
        # Sort logits and compute cumulative probabilities
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens whose cumulative probability, excluding themselves,
        # already exceeds top_p (so the top token is always kept)
        sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
        sorted_logits[sorted_mask] = float("-inf")
        # Sample from the filtered distribution
        probs = F.softmax(sorted_logits, dim=-1)
        next_index = torch.multinomial(probs, num_samples=1)
        next_token = sorted_indices.gather(-1, next_index)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

print("\n=== Nucleus Sampling (p=0.9, temp=0.8) ===")
for i in range(3):
    print(f"Sample {i+1}: {nucleus_sample('The future of artificial intelligence is')}")
    print()
```
Hint
Nucleus sampling keeps the smallest set of tokens whose cumulative probability exceeds top_p, then samples from that set. Lower top_p values (0.5) produce more focused text; higher values (0.95) allow more variety.
Step 3: Compare with the HuggingFace pipeline
Achieve the same results with the high-level pipeline API in just a few lines.
```python
from transformers import pipeline

# The library way: a few lines replace the manual loop
generator = pipeline("text-generation", model="gpt2")
results = generator(
    "The future of artificial intelligence is",
    max_new_tokens=50, do_sample=True, top_p=0.9, temperature=0.8,
    num_return_sequences=3,
)

print("=== Pipeline Output ===")
for i, r in enumerate(results):
    print(f"Sample {i+1}: {r['generated_text']}\n")
```
Hint
The pipeline handles tokenization, the autoregressive loop, sampling logic, and decoding internally. Under the hood it uses the same model.generate() method with configurable strategies including greedy, beam search, top-k, and top-p sampling.
Expected Output
- Greedy decoding produces deterministic, often repetitive text
- Nucleus sampling produces varied outputs across runs, with coherent prose
- Pipeline output matches the quality of manual nucleus sampling with far less code
Stretch Goals
- Implement top-k sampling and compare output diversity against nucleus sampling
- Add a repetition penalty that reduces the logit of any token already generated
- Plot token probability distributions at each step to visualize how temperature reshapes the distribution
What's Next?
Now that we have explored how models reason, from chain-of-thought prompting to formal theorem proving, we shift focus to making these models run efficiently. In Chapter 9: Inference Optimization and Efficient Serving, we cover quantization, KV caching, speculative decoding, and serving frameworks that let you deploy reasoning models in production without breaking the compute budget.
Yang et al. (2023), "LeanDojo: Theorem Proving with Retrieval-Augmented Language Models". Introduces LeanDojo for extracting data from Lean 4 and ReProver for retrieval-augmented theorem proving. Provides the foundational infrastructure for LLM-based formal proving research. Essential reading for anyone building formal proving systems.
Zheng, Han, and Polu (2022), "miniF2F: A Cross-System Benchmark for Formal Olympiad-Level Mathematics". The standard benchmark for evaluating formal provers on competition-level mathematics. Contains 488 problems formalized in Lean, Isabelle, and Metamath. Required context for interpreting any formal proving results.
Azerbayev et al. (2023), "ProofNet: Autoformalizing and Formally Proving Undergraduate-Level Mathematics". Introduces a benchmark for autoformalization and formal proving at the undergraduate level. Useful for evaluating systems on more routine mathematical reasoning beyond competition problems.
Google DeepMind (2024), "AI achieves silver-medal standard solving International Mathematical Olympiad problems" (blog post). Announcement of AlphaProof's IMO results. Demonstrates the power of combining MCTS, RL, and formal verification for mathematical reasoning. The most significant milestone in AI theorem proving to date.
Polu and Sutskever (2020), "Generative Language Modeling for Automated Theorem Proving". Early work on applying GPT-style language models to formal theorem proving in Metamath. Established the viability of the approach and inspired subsequent work including LeanDojo and AlphaProof.
Combines MCTS with expert iteration for Lean theorem proving, achieving strong results on miniF2F. Demonstrates that open-weight models can compete with proprietary systems on formal proving tasks.
