Pretraining Objectives & Paradigms

Section 6.2

Tell a model to predict the next word and it learns grammar. Mask a word and it learns meaning. Corrupt a span and it learns to complain about the corruption, then fix it anyway.

Scale, Gleefully Predictive AI Agent
Big Picture

Why does the training objective matter? A language model's pretraining objective is the task it solves trillions of times during training. This choice shapes everything: what the model learns to represent, what it can generate, and what downstream tasks it excels at. Causal language modeling produces powerful generators. Masked language modeling produces powerful encoders. Span corruption creates versatile encoder-decoders. Newer objectives like fill-in-the-middle and multi-token prediction push the boundaries further. Understanding these objectives is essential for selecting the right model for your application and for designing new training procedures. The tokenization choices from Chapter 1 directly affect how these objectives operate at the token level.

Key Insight: Remember

Next-token prediction is the world's most innocent-looking objective and the most powerful one we know. "Predict the next word" sounds shallow, but to do it well across the internet you have to learn grammar, facts, code, arithmetic, dialogue, and reasoning. The objective is local; the side effects are everything.

Prerequisites

This section builds directly on the Transformer architecture from Section 3.1 (encoder, decoder, and attention masks). Understanding of tokenization from Chapter 1 is assumed. The discussion of multi-token prediction connects forward to the production-model survey later in this part.

6.2.1 Causal Language Modeling (CLM)

Figure 6.2.1: Two pretraining objectives at the token level. Causal language modeling (left) predicts the next token from prior tokens only, the natural setup for autoregressive generation. Masked language modeling (right) hides random tokens and predicts them using bidirectional context, the natural setup for representation learning. Both objectives use the same Transformer backbone; only the attention mask and the loss differ.

Causal language modeling, also called autoregressive language modeling, trains a model to predict the next token given all previous tokens. Formally, given a sequence of tokens $x = (x_{1}, x_{2}, ..., x_{T})$, the model is trained to minimize the negative log-likelihood:

$$L_{\text{CLM}} = -\sum _{t=1}^{T} \log P(x_{t} | x_{1}, ..., x_{t-1}; \theta )$$

The model processes tokens left-to-right, with a causal attention mask that prevents each position from attending to future positions. This makes CLM naturally suited for text generation: at inference time, the model generates one token at a time, feeding each prediction back as input for the next step (see Section 4.1 for the mechanics of autoregressive decoding).

Why CLM Dominates Modern LLMs

Several properties make CLM the preferred objective for large-scale models:

# Implementing causal language modeling loss from scratch
import torch
import torch.nn.functional as F
def causal_lm_loss(logits, labels):
    """
    logits: (batch, seq_len, vocab_size)
    labels: (batch, seq_len) - the unshifted input token ids; the shift
            by one position happens inside this function
    """
    # Shift: predict token t+1 from position t
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # Cross-entropy loss, ignoring padding tokens (label = -100)
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100
    )
    return loss
# Example: batch of 2, sequence length 5, vocab size 100
logits = torch.randn(2, 5, 100)
labels = torch.randint(0, 100, (2, 5))
loss = causal_lm_loss(logits, labels)
print(f"CLM Loss: {loss.item():.4f}")
print(f"Perplexity: {torch.exp(loss).item():.2f}")
Output: CLM Loss: 4.6075 Perplexity: 100.27
Code Fragment 6.2.1a: Implementing causal language modeling loss from scratch.
# Causal language modeling loss from scratch.
# Predict next token at each position; only the SHIFTED-LEFT targets contribute.
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """logits: (batch, seq_len, vocab)
       input_ids: (batch, seq_len)
       Returns scalar cross-entropy averaged over all valid positions."""
    # Shift: predict input_ids[:, 1:] from logits[:, :-1, :]
    pred_logits = logits[:, :-1, :].contiguous()              # (B, T-1, V)
    targets     = input_ids[:, 1:].contiguous()               # (B, T-1)

    # Cross-entropy over flattened batch+positions
    loss = F.cross_entropy(
        pred_logits.view(-1, pred_logits.size(-1)),
        targets.view(-1),
    )
    return loss

# Quick sanity check
torch.manual_seed(0)
B, T, V = 2, 10, 50257
fake_logits = torch.randn(B, T, V)
fake_ids = torch.randint(0, V, (B, T))
print(f"Loss: {causal_lm_loss(fake_logits, fake_ids).item():.3f}")
# For an untrained random model on a 50K-token vocab, expect log(50257) ~= 10.8.
Output: Loss: 10.834
Code Fragment 6.2.2: Implementing causal language modeling loss from scratch
Tip: Production Alternative

The implementation above builds causal language modeling loss from scratch for pedagogical clarity. In production, use Hugging Face Transformers (install: pip install transformers), where the loss computation and label shifting are handled automatically (see Code Fragment 6.2.3 below). Code Fragment 6.2.4 sketches a multi-token-prediction head extension.

# Production equivalent: loss is computed internally
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The cat sat on the mat", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
print(f"Loss: {outputs.loss.item():.4f}")
Output: Loss: 4.1832
Code Fragment 6.2.3: Computing the causal language modeling loss with Hugging Face Transformers, where passing labels makes the model shift them and compute the loss internally.
# Multi-token prediction: conceptual implementation
import torch
import torch.nn as nn

class MultiTokenPredictionHead(nn.Module):
    """N independent prediction heads sharing a transformer backbone."""

    def __init__(self, hidden_dim, vocab_size, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        # Each head: LayerNorm + Linear projection to vocab
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, vocab_size),
            )
            for _ in range(n_heads)
        ])

    def forward(self, hidden_states, labels):
        """
        hidden_states: (batch, seq_len, hidden_dim)
        labels: (batch, seq_len) - original token ids
        """
        total_loss = 0.0
        for k, head in enumerate(self.heads, start=1):
            logits = head(hidden_states)  # (batch, seq_len, vocab)
            # Shift by k positions: predict token at t+k from position t
            shift_logits = logits[:, :-k, :].contiguous()
            shift_labels = labels[:, k:].contiguous()
            loss_k = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            total_loss += loss_k
        # Average only after every head has contributed its loss
        return total_loss / self.n_heads

# Example: 4 heads on random hidden states
mtp = MultiTokenPredictionHead(hidden_dim=512, vocab_size=32000, n_heads=4)
hidden = torch.randn(2, 128, 512)
labels = torch.randint(0, 32000, (2, 128))
loss = mtp(hidden, labels)
print(f"MTP Loss (4 heads): {loss.item():.4f}")
Output: MTP Loss (4 heads): 10.3742
Code Fragment 6.2.4: Multi-token prediction (MTP): the MultiTokenPredictionHead class wires N independent heads (LayerNorm followed by a linear projection) on top of a shared backbone, each predicting a different future-token offset.
Key Insight

Next-token prediction is, mathematically, an exercise in learning the conditional probability distribution of language. This connects directly to Shannon's foundational work on information theory (1948), which defined the entropy of a source as the limit of its per-symbol prediction difficulty. In a 1951 follow-up study, Shannon estimated that printed English carries roughly 0.6 to 1.3 bits of entropy per character. Modern LLMs, trained on next-token prediction, are essentially building ever more accurate approximations of this distribution. The expected cross-entropy loss is an upper bound on the true entropy of the data source, with equality only when the model matches the source distribution exactly. When a model achieves lower perplexity, it is, in information-theoretic terms, a better compression model of language. This is why Ilya Sutskever has argued that "compression is intelligence": a model that can predict the next token well must have internalized the statistical structure, grammar, facts, and reasoning patterns embedded in its training corpus.
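To make the link between loss and compression concrete, the sketch below converts a per-token cross-entropy in nats (the unit a natural-log cross-entropy reports) into bits per token, bits per character, and perplexity. The loss value of 2.0 and the average of 4 characters per token are illustrative assumptions, not measurements from any model.

```python
import math

def nats_to_bits(loss_nats: float) -> float:
    """Convert a cross-entropy value from nats to bits."""
    return loss_nats / math.log(2)

def bits_per_character(loss_nats: float, chars_per_token: float) -> float:
    """Spread the per-token cost in bits over the characters each token covers."""
    return nats_to_bits(loss_nats) / chars_per_token

def perplexity(loss_nats: float) -> float:
    """Perplexity is the exponential of the per-token loss in nats."""
    return math.exp(loss_nats)

# Hypothetical numbers, chosen only for illustration.
loss = 2.0             # nats per token
chars_per_token = 4.0  # assumed average token length in characters

print(f"bits/token: {nats_to_bits(loss):.3f}")
print(f"bits/char:  {bits_per_character(loss, chars_per_token):.3f}")
print(f"perplexity: {perplexity(loss):.3f}")
```

Dividing by the characters-per-token ratio is what makes losses comparable across tokenizers, since a tokenizer with longer tokens will show a higher per-token loss for the same per-character cost.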

6.2.2 Masked Language Modeling (MLM)

Fun Fact

Masked language modeling is essentially a fill-in-the-blank exercise at massive scale. BERT learned to read by doing the same kind of worksheet your third-grade teacher handed out, just across billions of sentences instead of a dozen.

Masked language modeling, introduced by BERT, randomly masks a fraction of input tokens and trains the model to reconstruct them from the surrounding (bidirectional) context. The standard recipe selects 15% of tokens for prediction; of those selected tokens, 80% are replaced by [MASK], 10% are replaced by a random token, and 10% are left unchanged.

Note: Encoder-only fixed window and [PAD] tokens

Encoder-only models like BERT operate on a fixed maximum token window (512 tokens for the original BERT models, extended to 1024 or 4096 in some later encoders). Inputs shorter than the window are padded on the right with a special [PAD] token, and an attention mask is passed alongside the input so that [PAD] positions receive zero attention weight (their scores are set to a large negative value before the softmax). Inputs longer than the window are either truncated or split into overlapping chunks. This contrasts with decoder-only models, which can in principle process any length up to their training context (modulo RoPE / positional limits, covered in Section 3.3).
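The padding and truncation behavior in this note can be sketched in a few lines of plain Python. The function name pad_batch, the pad id of 0, the toy window of 5 tokens, and the token ids are all assumptions made for illustration; a real tokenizer library handles this step for you.

```python
PAD_ID = 0  # assumed id of the [PAD] token

def pad_batch(sequences, max_len, pad_id=PAD_ID):
    """Truncate or right-pad each id list to max_len.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens
    and 0 for padding positions.
    """
    input_ids, attention_mask = [], []
    for seq in sequences:
        kept = seq[:max_len]                        # truncate long inputs
        n_pad = max_len - len(kept)
        input_ids.append(kept + [pad_id] * n_pad)   # pad short inputs on the right
        attention_mask.append([1] * len(kept) + [0] * n_pad)
    return input_ids, attention_mask

# Made-up token ids; the window is shrunk to 5 so the result is easy to read.
batch = [[101, 7592, 102], [101, 7592, 2088, 999, 102, 5]]
ids, mask = pad_batch(batch, max_len=5)
print(ids)
print(mask)
```

The first sequence gains two padding positions with mask value 0, and the second loses its final token to truncation.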

$$L_{\text{MLM}} = -\sum _{i \in M} \log P(x_{i} | \tilde{x}; \theta )$$

where $M$ is the set of masked positions and $\tilde{x}$ denotes the corrupted input (the full sequence with the masks applied).
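The 80/10/10 corruption recipe can be sketched as follows. This is a simplified illustration: it selects each position independently with probability 15% (so the selected fraction varies from sequence to sequence), it works on string tokens instead of ids, and the names bert_mask and IGNORE are hypothetical.

```python
import random

MASK_TOKEN = "[MASK]"
IGNORE = None  # marks positions that do not contribute to the loss

def bert_mask(tokens, vocab, select_prob=0.15, rng=None):
    """Return (corrupted_tokens, labels) following the 80/10/10 recipe.

    labels[i] holds the original token at selected positions and IGNORE elsewhere.
    """
    if rng is None:
        rng = random.Random()
    corrupted, labels = [], []
    for tok in tokens:
        if rng.random() >= select_prob:
            corrupted.append(tok)      # not selected: input unchanged, no loss
            labels.append(IGNORE)
            continue
        labels.append(tok)             # selected: the model must recover tok
        roll = rng.random()
        if roll < 0.8:
            corrupted.append(MASK_TOKEN)         # 80%: replace with [MASK]
        elif roll < 0.9:
            corrupted.append(rng.choice(vocab))  # 10%: replace with a random token
        else:
            corrupted.append(tok)                # 10%: keep the original token
    return corrupted, labels

tokens = "the cat sat on the mat and looked at the dog".split()
vocab = sorted(set(tokens))
corrupted, labels = bert_mask(tokens, vocab, rng=random.Random(0))
print(corrupted)
print(labels)
```

Only positions whose label is not IGNORE enter the sum over $M$ in the loss; the kept-unchanged and random-replacement cases are what force the model to build a useful representation for every input token, not just for [MASK].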

Strengths and Limitations

MLM's bidirectionality is its greatest strength for understanding tasks. A model that sees both left and right context can build richer representations of each token. This is why BERT-style models dominate in classification, named entity recognition, and other tasks where the full context is available.

However, MLM has significant limitations. Only 15% of tokens provide training signal per forward pass, making it less sample-efficient than CLM. The [MASK] token never appears during inference, creating a train-test mismatch. And MLM models cannot naturally generate text because they assume access to future context during prediction.
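The sample-efficiency gap is simple arithmetic, sketched below for a single sequence. The helper names are hypothetical, and the MLM count is an expectation under the standard 15% selection rate.

```python
def clm_targets(seq_len: int) -> int:
    """Every position except the last has a next token to predict."""
    return seq_len - 1

def mlm_targets(seq_len: int, mask_rate: float = 0.15) -> int:
    """Expected number of selected positions, rounded to a whole token."""
    return round(seq_len * mask_rate)

seq_len = 512
print(f"CLM targets: {clm_targets(seq_len)}")
print(f"MLM targets: {mlm_targets(seq_len)}")
print(f"ratio: {clm_targets(seq_len) / mlm_targets(seq_len):.1f}x")
```

For a 512-token sequence this gives 511 CLM targets against roughly 77 MLM targets, so each forward pass under CLM carries several times more prediction targets.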

Figure 6.2.2a: CLM sees only left context and trains on every token. MLM sees full context but trains only on masked positions.
Real-World Scenario
Choosing Between CLM and MLM for a Legal Document Search Engine

Who: An AI team at a legal technology firm building a semantic search engine for court opinions and case law.

Situation: The system needed to understand nuanced legal language to match queries with relevant precedents, requiring deep bidirectional understanding of context.

Problem: A GPT-style (CLM) model fine-tuned for embeddings performed poorly on legal retrieval tasks because it processed text left-to-right, missing backward contextual cues critical for understanding legal clauses like "notwithstanding the foregoing."

Dilemma: CLM models offered better generation (useful for summarization features), but MLM models produced superior embeddings for retrieval. Running both models doubled serving costs.

Decision: They adopted a two-stage approach: a fine-tuned DeBERTa (MLM-based) encoder for retrieval and ranking, paired with a small CLM model for optional summary generation.

How: The team pretrained DeBERTa on 2M legal documents using MLM, then fine-tuned it with contrastive learning on query-document pairs. The CLM model was only invoked when users requested case summaries.

Result: Retrieval recall@10 improved from 72% to 89% compared to the CLM-only baseline. The hybrid approach added only 15% to infrastructure costs versus the original CLM-only system.

Lesson: The pretraining objective fundamentally shapes what a model is good at: MLM produces better encoders for understanding tasks, while CLM produces better generators. Choose based on your primary use case.

CLM and MLM represent two ends of a spectrum: predict the next token versus predict masked tokens. But what if we masked entire spans of text rather than individual tokens? This hybrid idea led to a third family of pretraining objectives that combines the strengths of both approaches.

BERT's Joint Objective: MLM + Next Sentence Prediction

The original BERT was not trained on MLM alone. It optimized a joint loss combining masked language modeling with Next Sentence Prediction (NSP), a binary classification objective designed to teach the model relations between sentences. Each training example is a pair of segments $A$ and $B$ packed as [CLS] A [SEP] B [SEP], with the special [CLS] (classification) token prepended to the input and [SEP] separating the two segments. Half the time $B$ is the actual sentence that followed $A$ in the corpus, and half the time it is a random sentence drawn from elsewhere. A single binary classification head sitting on top of the context vector at the [CLS] position predicts whether $B$ is the true next sentence. The combined objective is:

$$L_{\text{BERT}} = L_{\text{MLM}} + L_{\text{NSP}}$$

The NSP loss is the standard binary cross-entropy on the [CLS] head, and the MLM loss is computed as before over the masked positions in the concatenated $A$ and $B$ sequence. The intent of NSP is to force the model to build a single context vector at the [CLS] position that summarizes the entire input, not just a single token's neighborhood. That summary is then directly usable as a sentence-pair (or single-sentence) representation for downstream tasks.
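The construction of NSP training pairs can be sketched as below. This is a simplification under stated assumptions: sentences are whitespace-tokenized strings, the negative segment is drawn from the same sentence list (BERT draws it from elsewhere in the corpus), and the function name make_nsp_examples is hypothetical. It needs at least three sentences so that a negative candidate always exists.

```python
import random

def make_nsp_examples(sentences, rng=None):
    """Build (packed_tokens, is_next) pairs from an ordered list of sentences."""
    if rng is None:
        rng = random.Random()
    examples = []
    for i in range(len(sentences) - 1):
        a = sentences[i]
        if rng.random() < 0.5:
            # Positive pair: B is the sentence that actually follows A.
            b, is_next = sentences[i + 1], 1
        else:
            # Negative pair: B is any sentence other than A or its true successor.
            candidates = [j for j in range(len(sentences)) if j not in (i, i + 1)]
            b, is_next = sentences[rng.choice(candidates)], 0
        packed = ["[CLS]"] + a.split() + ["[SEP]"] + b.split() + ["[SEP]"]
        examples.append((packed, is_next))
    return examples

sentences = [
    "the court granted the motion",
    "the defendant then appealed",
    "rain is expected tomorrow",
    "the appeal was dismissed",
]
for packed, is_next in make_nsp_examples(sentences, rng=random.Random(0)):
    print(is_next, " ".join(packed))
```

The is_next label is the target for the binary head on the [CLS] position, while MLM masking would be applied to the packed token sequence before it is fed to the model.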

[CLS] as a Sentence Embedding (and Mean Pooling as an Alternative)

Because NSP trained the [CLS] context vector to encode the whole input, the natural recipe for using BERT as a sentence encoder is to take that single vector as the sentence embedding. Concretely, the post-training procedure for classification or retrieval is:

  1. Tokenize the input as [CLS] tokens... [SEP].
  2. Run a forward pass through BERT.
  3. Read off the final-layer hidden state at position 0 (the [CLS] position). That vector is the sentence embedding.
  4. Either feed it into a linear classifier head (fine-tuning), use it as a frozen feature, or compare two such vectors with cosine similarity for retrieval.

An equally common alternative is mean pooling: average the final-layer hidden states across all non-padding token positions to produce the sentence vector. Mean pooling often outperforms the raw [CLS] vector when the model was not specifically trained with NSP (for example, RoBERTa), and it is the default pooling strategy used in the Sentence-Transformers library (see Section 31.1 for sentence embeddings in retrieval, and Section 31.2 for the contrastive fine-tuning that turns raw [CLS] vectors into high-quality embeddings).
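The two pooling strategies differ only in which final-layer vectors they read. The sketch below uses plain Python lists as stand-ins for hidden states so that it runs without any model; the toy vectors and the function names cls_pool and mean_pool are assumptions for illustration.

```python
def cls_pool(hidden_states):
    """Return the vector at position 0, the [CLS] position."""
    return list(hidden_states[0])

def mean_pool(hidden_states, attention_mask):
    """Average the vectors whose attention_mask entry is 1 (non-padding)."""
    dim = len(hidden_states[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(hidden_states, attention_mask):
        if keep:
            count += 1
            for d in range(dim):
                total[d] += vec[d]
    if count == 0:
        raise ValueError("attention_mask selects no tokens")
    return [t / count for t in total]

# Toy final-layer states for "[CLS] hello world [SEP] [PAD]" with hidden size 2.
hidden = [[1.0, 0.0], [2.0, 2.0], [4.0, 0.0], [1.0, 2.0], [9.0, 9.0]]
mask = [1, 1, 1, 1, 0]
print(cls_pool(hidden))         # [1.0, 0.0]
print(mean_pool(hidden, mask))  # [2.0, 1.0]
```

Excluding the padding position matters: averaging all five vectors would pull the result toward the meaningless [PAD] state.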

Note: Why later models dropped NSP

RoBERTa (2019) showed empirically that NSP contributed little and sometimes hurt downstream performance, and that the gains attributed to it were mostly the byproduct of training on longer concatenated inputs. RoBERTa, DeBERTa, and almost every subsequent encoder dropped NSP entirely. ALBERT replaced it with Sentence Order Prediction (predict whether two consecutive segments were swapped), a harder task that focuses on inter-sentence coherence rather than topical relatedness. The [CLS] token survived because the pooling convention proved useful even without an explicit NSP loss.

Warning
Raw [CLS] vectors are not semantic embeddings out of the box

A pretrained BERT's [CLS] vector is a great input to a fine-tuned classifier, but it is a poor semantic embedding by itself. Cosine similarities between raw [CLS] vectors of similar sentences are often near zero or even negative, because the pretraining objectives never required the vectors to live in a metric-meaningful space. To use BERT for semantic search or clustering, fine-tune it with a contrastive or triplet objective (Sentence-BERT, SimCSE), or use mean pooling combined with whitening / normalization, as covered in Section 31.2.

6.2.3 Span Corruption and Denoising Objectives

Figure 6.2.3a: Span corruption turns your training text into Swiss cheese: random chunks get removed, and the model learns to fill the holes.

T5 introduced span corruption, a variant of MLM that masks contiguous spans of tokens rather than individual tokens. A random 15% of tokens are selected, grouped into contiguous spans, and each span is replaced with a single unique sentinel token (like <extra_id_0>). The model then generates the corrupted spans in order, separated by sentinel tokens.

Why Spans Are Better Than Single Tokens

Span corruption is more efficient than single-token masking for several reasons. First, the target sequence is shorter because multiple masked tokens are represented by a single sentinel, reducing the computational cost of the decoder. Second, the model must predict multiple consecutive tokens per span, encouraging it to learn phrase-level and sentence-level patterns rather than just word-level predictions.

# Simulating T5 span corruption
import random

def span_corruption(tokens, mask_ratio=0.15, mean_span_length=3):
    """Apply T5-style span corruption to a token list."""
    n = len(tokens)
    num_masked = int(n * mask_ratio)
    num_spans = max(1, num_masked // mean_span_length)
    # Generate random span starts and lengths
    mask = [False] * n
    masked_so_far = 0
    for _ in range(num_spans):
        if masked_so_far >= num_masked:
            break
        span_len = random.randint(1, mean_span_length * 2)
        start = random.randint(0, n - 1)
        for i in range(start, min(start + span_len, n)):
            mask[i] = True
            masked_so_far += 1
    # Build corrupted input and target
    corrupted, target = [], []
    sentinel_id = 0
    in_span = False
    for i, tok in enumerate(tokens):
        if mask[i]:
            if not in_span:
                # First token of a span: emit one sentinel on both sides
                corrupted.append(f"<extra_id_{sentinel_id}>")
                target.append(f"<extra_id_{sentinel_id}>")
                sentinel_id += 1
                in_span = True
            target.append(tok)
        else:
            corrupted.append(tok)
            in_span = False
    return corrupted, target

tokens = "The quick brown fox jumps over the lazy dog".split()
random.seed(42)
corrupted, target = span_corruption(tokens)
print(f"Input:     {' '.join(corrupted)}")
print(f"Target:    {' '.join(target)}")
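Illustrative input/target format (the exact spans depend on the random draws, and a nine-token sentence at the default 15% ratio yields only one span):
Input: The quick <extra_id_0> jumps over <extra_id_1> dog
Target: <extra_id_0> brown fox <extra_id_1> the lazy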
Code Fragment 6.2.5: Simulating T5 span corruption.

UL2: Mixture of Denoisers

UL2 (Unifying Language Learning Paradigms, 2022) took the denoising approach further by mixing multiple corruption strategies during pretraining. It combined three modes: (1) R-denoiser (regular denoising, like T5, with short spans), (2) S-denoiser (sequential denoising, similar to prefix LM, masking a suffix), and (3) X-denoiser (extreme denoising, with long spans and high mask ratios). A mode token prepended to each example tells the model which type of denoising to perform. This produced a single model that excelled at both understanding and generation tasks.

6.2.4 Prefix Language Modeling

Prefix LM is a hybrid approach used by models like UniLM and GLM. The input is divided into a prefix (which uses bidirectional attention) and a suffix (which uses causal attention). The prefix provides full context for encoding the input, while the suffix generates output autoregressively. This combines the bidirectional encoding strength of MLM with the generation capability of CLM.

Note

Prefix LM is implemented simply by modifying the attention mask. For a sequence of length T where the first P tokens are the prefix, positions 1 through P attend to all positions 1 through P (bidirectional), while positions P+1 through T attend only to positions 1 through their own index (causal). No architectural changes are needed.
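The mask described in this note can be built directly from its definition. The sketch below uses 0-indexed positions and a boolean matrix where mask[i][j] is True when position i may attend to position j; the function name prefix_lm_mask is hypothetical.

```python
def prefix_lm_mask(seq_len, prefix_len):
    """Build a prefix-LM attention mask as a list of boolean rows."""
    if not 0 <= prefix_len <= seq_len:
        raise ValueError("prefix_len must lie between 0 and seq_len")
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            if i < prefix_len:
                row.append(j < prefix_len)  # prefix: bidirectional within the prefix
            else:
                row.append(j <= i)          # suffix: causal, sees the prefix and its own past
        mask.append(row)
    return mask

# T = 5 tokens with a prefix of P = 2; "x" marks an allowed attention edge.
for row in prefix_lm_mask(5, 2):
    print("".join("x" if allowed else "." for allowed in row))
```

Setting prefix_len to 0 recovers the ordinary causal mask, and setting it to seq_len gives fully bidirectional attention, which is why no architectural change is required to move between the regimes.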

6.2.5 Fill-in-the-Middle (FIM)

Fill-in-the-middle is a training objective designed specifically for code models (like Codex, StarCoder, and CodeLlama, discussed further in Section 27.1). The key observation is that programmers frequently need to insert code at an arbitrary position within existing code, but standard CLM only supports appending tokens at the end.

FIM works by splitting a document into three parts: prefix, middle, and suffix. During training, these parts are rearranged so the model sees the prefix and suffix first, then generates the middle. The most common variant, called PSM (Prefix-Suffix-Middle), presents the input as:

$$\texttt{<PRE>} \; \text{prefix} \; \texttt{<SUF>} \; \text{suffix} \; \texttt{<MID>} \; \text{middle}$$

An alternative variant, SPM (Suffix-Prefix-Middle), places the suffix before the prefix. The model learns to condition on both surrounding context to generate the infilling content.

# Fill-in-the-Middle (FIM) transformation
import random

def apply_fim(document, fim_rate=0.5, mode="PSM"):
    """Transform a document for FIM training."""
    if random.random() > fim_rate:
        return document  # Keep as regular CLM with probability (1 - fim_rate)
    # Choose two random character offsets that delimit the middle section
    n = len(document)
    split1 = random.randint(0, n)
    split2 = random.randint(split1, n)
    prefix = document[:split1]
    middle = document[split1:split2]
    suffix = document[split2:]
    if mode == "PSM":
        return f"<PRE>{prefix}<SUF>{suffix}<MID>{middle}"
    elif mode == "SPM":
        return f"<SUF>{suffix}<PRE>{prefix}<MID>{middle}"
    raise ValueError(f"Unknown FIM mode: {mode}")

# Example with code
code = """def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)"""
random.seed(0)
fim_example = apply_fim(code, fim_rate=1.0)
print(fim_example)
Output: <PRE>def fibonacci(n): if n <= 1: return n retu<SUF>onacci(n-2)<MID>rn fibonacci(n-1) + fib
Code Fragment 6.2.6: Fill-in-the-Middle (FIM) transformation.
Key Insight

FIM is remarkably efficient to add. The original paper showed that applying FIM transformations to 50% of training documents during standard CLM pretraining adds the infilling capability with essentially zero degradation to left-to-right generation quality. This makes it a "free" capability that all code models should include.

6.2.6 Multi-Token Prediction

Standard CLM predicts only one token ahead. Multi-token prediction (MTP), introduced by Meta in 2024 and later adopted by DeepSeek V3, trains the model to predict several future tokens simultaneously. The architecture adds N independent prediction heads to a shared transformer backbone. Head k predicts token $x_{t+k}$ given tokens $x_{1}, ..., x_{t}$.

$$L_{\text{MTP}} = -\sum _{k=1}^{N} \sum _{t=1}^{T-k} \log P_{k}(x_{t+k} | x_{1}, ..., x_{t}; \theta )$$

Why Predict Multiple Tokens?

The benefits of multi-token prediction are both theoretical and practical:

Figure 6.2.4a: Multi-token prediction uses N independent heads on top of a shared backbone, each predicting a different future token.
# Multi-token prediction (Meta, "Better & Faster LLMs via Multi-Token Prediction" 2024).
# Train the model to predict the NEXT k tokens at every position, using k parallel heads.
# Improves sample efficiency and downstream code/math performance.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTokenPredictionHead(nn.Module):
    """k separate linear heads sharing the same trunk; head[i] predicts token (t+i+1)."""

    def __init__(self, d_model: int, vocab: int, n_future: int = 4):
        super().__init__()
        self.n_future = n_future
        self.heads = nn.ModuleList([nn.Linear(d_model, vocab) for _ in range(n_future)])

    def forward(self, hidden: torch.Tensor) -> list[torch.Tensor]:
        """hidden: (B, T, d_model) -> list of n_future tensors each (B, T, vocab)"""
        return [head(hidden) for head in self.heads]

    def loss(self, hidden: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Average cross-entropy across the k future positions, dropping the tail."""
        logits_list = self(hidden)
        B, T, _ = hidden.shape
        total = 0.0
        for i, logits in enumerate(logits_list, start=1):
            # head i predicts token at position (t+i); drop the last i positions
            valid_logits = logits[:, :T - i, :].contiguous().view(-1, logits.size(-1))
            targets      = input_ids[:, i:].contiguous().view(-1)
            total = total + F.cross_entropy(valid_logits, targets)
        return total / self.n_future
Code Fragment 6.2.7: Multi-token prediction loss: per-head cross-entropy averaged across the N future-token heads (head i predicts position t+i, with the last i positions dropped to avoid off-the-end targets).

6.2.7 Continual Pretraining

All of the objectives described above focus on training a model from scratch. But a growing body of work demonstrates that continual pretraining (also called domain-adaptive pretraining or continued pretraining) is a powerful and cost-effective technique for adapting an existing foundation model to a new domain, language, or capability without the expense of a full training run.

The core idea is straightforward: take a pretrained model's checkpoint and resume CLM training on a new corpus that emphasizes the target domain. For example, a general-purpose model trained primarily on English web text can be continually pretrained on a large corpus of biomedical papers, legal documents, or code. The model retains its general capabilities while absorbing domain-specific knowledge, vocabulary patterns, and reasoning styles. This is far cheaper than pretraining from scratch because the model already has strong language understanding; the continual phase only needs to adjust the distribution, typically requiring 10 to 100 billion tokens rather than trillions.

Several practical considerations govern successful continual pretraining:

Notable examples include Meta's Code Llama (continual pretraining of Llama-2 on 500B tokens of code), Llama-3.1's context extension (continual pretraining from 8K to 128K context using progressively longer sequences), and domain-specific models in medicine (BioMistral) and law (SaulLM), both of which continue pretraining from a Mistral checkpoint. Continual pretraining occupies a middle ground between fine-tuning (which adjusts model behavior with relatively little data) and full pretraining (which builds capabilities from the ground up). It is increasingly the recommended first step when adapting a foundation model to a specialized domain.

6.2.8 Comparison of Pretraining Objectives

Table 6.2.1b: Comparison of Pretraining Objectives (as of 2026).
| Objective | Architecture | Context | Training Efficiency | Best For |
|---|---|---|---|---|
| CLM | Decoder-only | Left-to-right | 100% tokens train | Generation, prompting |
| MLM | Encoder-only | Bidirectional | 15% tokens train | Classification, NER |
| Span Corruption | Encoder-Decoder | Bidirectional enc | 15% tokens, shorter target | Seq2seq, translation |
| Prefix LM | Decoder-only | Hybrid | Suffix tokens train | Conditional generation |
| FIM | Decoder-only | Prefix + Suffix | 100% tokens train | Code infilling |
| MTP | Decoder-only | Left-to-right | N x signals per position | Better representations |
Note: MTP Validated at Scale

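DeepSeek V3 (2024) provided the strongest validation of multi-token prediction at scale, training a 671B-parameter MoE model with an MTP objective. Unlike the independent parallel heads of the original Meta formulation, DeepSeek V3 uses a sequential MTP module with a prediction depth of 1, so each position predicts one additional token beyond the next one. The module serves double duty: it densifies the training signal during pretraining, and at inference it can either be discarded or reused for speculative decoding, removing the need for a separate draft model. See Section 7.3 for the full DeepSeek V3 architecture discussion.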

Note: Beyond Attention

All objectives discussed here assume a transformer backbone. An active line of research explores alternative architectures like Mamba (Gu and Dao, 2023) and RWKV, which replace attention with state-space models (SSMs) or linear recurrences. These achieve linear scaling with sequence length (versus quadratic for attention) and process tokens in constant memory during inference. Hybrid architectures like Jamba interleave attention and Mamba layers for the best of both worlds. While transformers remain dominant, SSMs are a rapidly maturing alternative covered by Stanford CS336 and Berkeley CS294.

Key Insight
Prediction as Compression, Compression as Understanding

All pretraining objectives share a common mathematical foundation: they train the model to compress its training data. CLM compresses by predicting the next token (minimizing cross-entropy is equivalent to minimizing the coding length under the model's distribution). MLM compresses by predicting masked tokens from context. This connects to Solomonoff's theory of inductive inference (1964), which formalizes prediction as weighting every program that generates the observed sequence, with shorter programs receiving more weight. A language model that achieves low perplexity has, in a precise sense, found a compressed representation of human language. This view explains why pretraining on "mere" next-token prediction produces models capable of reasoning, translation, and code generation: to predict text well, the model must build internal representations of grammar, world knowledge, and logical structure, because those are exactly the regularities that enable compression. Understanding and compression are two perspectives on the same phenomenon.

Tip: Monitor Loss Curves, Not Just Final Loss

Plot your training loss every few hundred steps. A healthy curve should decrease smoothly. Sudden spikes indicate data quality issues or learning rate problems. Save these plots; they are invaluable for debugging failed runs after the fact.
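Alongside the plots, a simple automated check can flag suspicious steps. The sketch below marks any step whose loss exceeds a multiple of the mean over the preceding window; the function name find_loss_spikes, the window of 50 logged steps, and the 1.5x threshold are arbitrary assumptions that would need tuning for a real run.

```python
def find_loss_spikes(losses, window=50, threshold=1.5):
    """Return indices where the loss exceeds threshold x the mean of the preceding window."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = sum(losses[i - window:i]) / window
        if losses[i] > threshold * baseline:
            spikes.append(i)
    return spikes

# Synthetic curve: a slow, smooth decrease with one injected spike at step 120.
losses = [3.0 - 0.001 * step for step in range(200)]
losses[120] = 9.0
print(find_loss_spikes(losses))
```

Logging the flagged step indices next to the batch ids makes it much easier to trace a spike back to a specific shard of data or a learning-rate schedule boundary.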

See Also

For decoding strategies that shape the model output once trained, see Section 4.1. For inference-side optimizations that ride on these training choices, see Section 9.3. For how the modern landscape applies these training recipes, see Section 7.3.

Research Frontier

Beyond next-token prediction. Multi-token prediction (Gloeckle et al., 2024) demonstrated that predicting multiple future tokens simultaneously improves both representation quality and inference speed through speculative decoding (Section 10.3). DeepSeek V3 adopted this approach at scale. Meanwhile, diffusion-based language models (MDLM, SEDD) explore non-autoregressive generation that can edit text in parallel rather than left-to-right. Hybrid objectives combining masked and causal pretraining (UL2) suggest that the next generation of foundation models may blend multiple paradigms within a single training run, choosing the objective dynamically based on the data distribution.

Key Takeaways
Self-Check
1. Why does CLM provide more training signal per sequence than MLM?
In CLM, every token in the sequence contributes to the loss (predicting the next token at each position), so 100% of tokens provide gradient signal. In MLM, only the 15% of tokens that are masked contribute to the loss. For a 512-token sequence, CLM gets ~511 prediction targets while MLM gets only ~77.
2. How does FIM avoid degrading standard left-to-right generation quality?
FIM applies the fill-in-the-middle transformation to only a fraction (typically 50%) of training documents. The other 50% remain as standard left-to-right sequences. The model thus learns both capabilities simultaneously. Experiments show that this split introduces essentially zero degradation to autoregressive generation while adding the infilling capability.
3. What advantage does multi-token prediction offer for inference speed?
The additional prediction heads from MTP training can serve as a built-in draft model for speculative decoding. During inference, the auxiliary heads propose several future tokens in parallel, and the main next-token head verifies them. Drafted tokens that agree with what the main head would have produced are accepted together, so several tokens can be emitted per forward pass. Gloeckle et al. report speedups of up to about 3x, and because every accepted token is verified by the main head, output quality is preserved. This is especially valuable because no separate draft model needs to be loaded.
4. Explain the difference between T5's span corruption and BERT's token-level masking.
BERT masks individual tokens independently (each token has a 15% chance of being selected for masking). T5's span corruption selects contiguous spans of tokens and replaces each span with a single sentinel token. As a result, T5's corrupted input is shorter than the original (one sentinel per span rather than one mask per token), and the target contains only the dropped spans rather than the full sequence. Additionally, T5 must predict multiple consecutive tokens per span, learning phrase-level patterns rather than just individual word predictions.
5. How does UL2's mixture-of-denoisers approach combine the strengths of multiple pretraining objectives?
UL2 mixes three denoising modes during pretraining: R-denoiser (regular short-span corruption like T5), S-denoiser (sequential denoising similar to prefix LM, masking a suffix), and X-denoiser (extreme denoising with long spans and high mask ratios). A mode token prepended to each example tells the model which type of denoising to perform. By training on all three modes simultaneously, UL2 produces a single model that excels at both understanding tasks (served by R and X modes) and generation tasks (served by S mode), avoiding the need to choose a single objective upfront.
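The target counts quoted in the answer to question 1 can be reproduced directly. This is a small sketch, with `prediction_targets` as a hypothetical helper name; it counts one CLM target for every position that has a successor and uses the expected number of masked positions for MLM.

```python
def prediction_targets(seq_len, mask_rate=0.15):
    """Return (CLM targets, expected MLM targets) for one sequence."""
    clm = seq_len - 1           # every position except the last predicts its successor
    mlm = seq_len * mask_rate   # expected number of masked positions
    return clm, mlm

clm, mlm = prediction_targets(512)
print(f"CLM targets: {clm}, MLM targets: {mlm:.1f}, ratio: {clm / mlm:.1f}x")
```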

Exercises

Exercise 6.2.1: CLM vs MLM Sample Efficiency Conceptual

A causal language model produces a training signal at every token position, while a masked language model with 15% masking only learns from the masked positions. Naively this suggests MLM is ~6.7x less sample-efficient per token. (a) Why didn't BERT die in 2019 from this disadvantage? (b) Why did the field eventually move toward CLM-style objectives anyway? (c) For an embedding-only use case, which objective would you still prefer in 2026?

Answer Sketch

(a) Each masked position uses bidirectional context, which is a much richer signal per prediction than CLM's left-only context, partly compensating for the lower density. BERT also typically processed every example in a single forward pass for both pretraining and fine-tuning, so wall-clock efficiency was competitive on NLU benchmarks. (b) CLM scales more cleanly: every token is a label, so adding more data linearly grows the loss signal, and the same model architecture handles generation, classification, and extraction. (c) For pure embedding work (semantic search, dense retrieval), MLM-style bidirectional encoders like DeBERTa-v3 or modern sentence-transformer recipes often still produce stronger per-parameter retrieval scores, because every token's representation is conditioned on both its left and its right context.

Exercise 6.2.2: Span Corruption Token Budget Calculation

T5 uses span corruption with mean span length 3 and 15% corruption rate over 512-token inputs. (a) How many spans does an average example contain? (b) What is the expected length of the encoder input and decoder target after sentinel substitution? (c) Why does this objective produce shorter encoder inputs than the original sequence?

Answer Sketch

(a) 15% of 512 = 76.8 corrupted tokens; with mean span length 3, that is roughly 26 spans per example. (b) Encoder input: 512 - 76.8 + 26 sentinels ~= 461 tokens. Decoder target: 76.8 corrupted tokens + 26 leading sentinels + 1 final sentinel ~= 104 tokens. (c) Each span of N tokens collapses to a single sentinel in the encoder, so any span longer than 1 shortens the input. This is why span corruption is more compute-efficient per useful prediction than per-token masking: the encoder does less work, and every sentinel position in the decoder produces a multi-token prediction trained as a single autoregressive run.
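The arithmetic in this answer sketch can be reproduced with a short function. This is a minimal sketch that works with expected values only; `span_corruption_budget` is a hypothetical helper, and rounding the span count to the nearest integer is a simplifying assumption.

```python
def span_corruption_budget(seq_len=512, corruption_rate=0.15, mean_span_len=3):
    """Expected span count, encoder input length, and decoder target length."""
    corrupted = seq_len * corruption_rate        # tokens removed from the input
    spans = round(corrupted / mean_span_len)     # one sentinel per span
    encoder_len = seq_len - corrupted + spans    # kept tokens plus sentinels
    decoder_len = corrupted + spans + 1          # dropped tokens, leading sentinels, final sentinel
    return spans, encoder_len, decoder_len

spans, enc, dec = span_corruption_budget()
print(f"spans: {spans}, encoder: {enc:.0f} tokens, decoder: {dec:.0f} tokens")
```

Raising `mean_span_len` at a fixed corruption rate reduces the sentinel count, which shortens both the encoder input and the decoder target.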

Exercise 6.2.3: Implement Fill-in-the-Middle Reformatting Code Tweak

Sketch a Python function to_fim(text, prefix_len, suffix_len) that converts a text sample into the OpenAI FIM training format used for code completion: <|fim_prefix|>PREFIX<|fim_suffix|>SUFFIX<|fim_middle|>MIDDLE. The function should pick a random middle span, leave at least prefix_len and suffix_len tokens of context, and return the reordered string. Five lines of pseudocode is enough.

Answer Sketch
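import random

def to_fim(text, prefix_len, suffix_len):
    toks = text.split()  # whitespace split stands in for a real tokenizer
    if len(toks) < prefix_len + suffix_len + 1:
        raise ValueError("text too short for the requested context")
    # The middle span is toks[start:end]; both bounds preserve the required context.
    start = random.randint(prefix_len, len(toks) - suffix_len - 1)
    end = random.randint(start + 1, len(toks) - suffix_len)
    P, M, S = (" ".join(t) for t in (toks[:start], toks[start:end], toks[end:]))
    return f"<|fim_prefix|>{P}<|fim_suffix|>{S}<|fim_middle|>{M}"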
Code Fragment 6.2.8: A sketch of to_fim(text, prefix_len, suffix_len), which picks a random middle span and reorders a text sample into the prefix, suffix, middle layout of the FIM training format.

The key trick is that the model still trains autoregressively (left to right), but the reordered format means the loss target appears after both prefix and suffix in the token stream. At inference time, you supply prefix and suffix in the same template and let the model generate until it emits the EOS or fim end token.
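The inference-time side of this can be sketched in the same style. The helper names below are hypothetical, and the stop token `<|endoftext|>` is an assumption; the actual end marker depends on the tokenizer of the model in use. The first function builds the prompt with an empty middle slot, and the second splices the generated middle back between prefix and suffix.

```python
def fim_prompt(prefix, suffix):
    """Build the inference prompt: same template as training, middle left open."""
    return f"<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>"

def splice(prefix, generated, suffix, stop_token="<|endoftext|>"):
    """Cut the generation at the (assumed) stop token and rebuild the document."""
    middle = generated.split(stop_token)[0]
    return prefix + middle + suffix

prompt = fim_prompt("def add(a, b):\n    return ", "\n")
# Hypothetical model output for the prompt above.
completed = splice("def add(a, b):\n    return ", "a + b<|endoftext|>", "\n")
print(completed)
```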

Exercise 6.2.4: Continual Pretraining Catastrophic Forgetting Failure Mode

You take a strong instruction-tuned Llama-3-8B and continue pretraining it on 50B tokens of medical literature to build a domain expert. After training, the medical benchmark scores rise by 12 points but the model has forgotten how to write Python and refuses to follow basic instructions. (a) Diagnose the two distinct failure modes. (b) Suggest one mitigation for each. (c) Why is this risk smaller for fine-tuning than for continual pretraining?

Answer Sketch

(a) Failure 1 is catastrophic forgetting of pretraining knowledge (Python skill loss) caused by the gradient updates pulling weights toward the medical-only data manifold. Failure 2 is loss of instruction-following because the medical corpus contains no instruction-style data, so the chat-tuning is overwritten. (b) For forgetting: mix in 5-15% of the original pretraining corpus (a "data replay" buffer). For instruction-following: re-apply a short SFT pass on the chat dataset after the domain phase, or interleave instruction data throughout. (c) Fine-tuning typically uses much smaller learning rates (for example 1e-5 versus 1e-4) and far fewer tokens, so the weights stay close to their starting point; continual pretraining takes many more gradient steps at a higher learning rate and moves the weights much farther.
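The data replay mitigation in (b) amounts to sampling each training document from one of two pools. The sketch below is illustrative only: `sample_mixture` is a hypothetical helper, the document pools are placeholder strings, and the 10% replay fraction is one point in the 5-15% range mentioned above.

```python
import random

def sample_mixture(domain_docs, replay_docs, n, replay_fraction=0.1, seed=0):
    """Draw n documents, taking each from the replay pool with
    probability `replay_fraction` and from the domain pool otherwise."""
    rng = random.Random(seed)
    batch = []
    for _ in range(n):
        pool = replay_docs if rng.random() < replay_fraction else domain_docs
        batch.append(rng.choice(pool))
    return batch

# Placeholder pools standing in for real corpora.
medical = ["medical_doc"]
general = ["general_doc"]
stream = sample_mixture(medical, general, n=10_000)
share = stream.count("general_doc") / len(stream)
print(f"replay share: {share:.3f}")
```

Sampling per document keeps the mixture ratio stable throughout training, rather than appending a replay block at the end where it would only partially undo earlier drift.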

What's Next?

In the next section, Section 6.3: Scaling Laws and Compute-Optimal Training, we study the empirical relationships that guide how to allocate resources for maximum capability.

Further Reading

Core Pretraining Objectives

Radford, A. et al. (2018). "Improving Language Understanding by Generative Pre-Training." OpenAI. The original GPT paper that established autoregressive language modeling as a pretraining objective for downstream NLP tasks. Showed that generative pretraining followed by discriminative fine-tuning yields strong transfer performance.
Devlin, J. et al. (2019). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." NAACL 2019. Defines the masked language modeling (MLM) objective that enables bidirectional context during pretraining. Contrasts with causal LM by allowing tokens to attend in both directions, giving encoders richer representations for classification tasks.
Raffel, C. et al. (2020). "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer." JMLR. The T5 paper that systematically compares span corruption, prefix LM, and other pretraining objectives in a unified text-to-text framework. An invaluable reference for understanding which objectives work best for which task types.

Advanced & Hybrid Objectives

Bavarian, M. et al. (2022). "Efficient Training of Language Models to Fill in the Middle." arXiv preprint arXiv:2207.14255. Shows how to add fill-in-the-middle (FIM) capability to autoregressive models by simply rearranging training sequences. Enables code completion at arbitrary cursor positions without sacrificing left-to-right generation quality.
Gloeckle, F. et al. (2024). "Better & Faster Large Language Models via Multi-token Prediction." ICML 2024. Trains models to predict multiple future tokens simultaneously using auxiliary prediction heads. Improves sample efficiency during training and enables speculative decoding at inference time for faster generation.
Tay, Y. et al. (2023). "UL2: Unifying Language Learning Paradigms." ICLR 2023. Proposes a mixture-of-denoisers framework that combines causal LM, prefix LM, and span corruption objectives during pretraining. Demonstrates that mixing paradigms produces models that excel at both generation and understanding tasks.