Part 2: Understanding LLMs
Chapter 06: Pretraining and Scaling Laws

Production LLM Training Systems: Megatron, Elastic Training, and Fault Tolerance

"Training a 405-billion-parameter model is not a research problem. It is a systems engineering problem with research inside it."

Scale Scale, Industrially Scaled AI Agent

Prerequisites

This section builds directly on the distributed training primitives (DDP, FSDP, ZeRO, tensor and pipeline parallelism) covered in Section 6.6: Distributed Training at Scale. Familiarity with mixed-precision training (FP16, BF16) and gradient checkpointing from that section is assumed. Understanding of pretraining data pipelines from Section 6.1 and scaling laws from Section 6.3 provides additional context.

Big Picture

Section 6.6 covered what distributed training is. This section covers how production teams actually do it. Training a frontier LLM (70B+ parameters) over multiple weeks on thousands of GPUs introduces engineering challenges that go far beyond choosing a parallelism strategy. You need a framework that composes multiple parallelism axes efficiently, a compiler stack that extracts maximum hardware utilization, fault tolerance that survives inevitable node failures, checkpointing that does not bottleneck training, and data pipelines that stream deterministically across ranks. This section covers the production-grade tools and techniques that make multi-week LLM training runs possible: Megatron-LM for multi-axis parallelism, torch.compile and kernel optimization for throughput, Torch Distributed Elastic for fault tolerance, PyTorch Distributed Checkpointing (DCP) for efficient state management, and streaming dataset libraries for scalable data loading.

1. Megatron-LM and Multi-Axis Parallelism

Section 6.6 introduced tensor parallelism (TP), pipeline parallelism (PP), and data parallelism (DP) as independent strategies. In practice, training a frontier LLM requires composing all of these simultaneously. Megatron-LM, developed by NVIDIA (Shoeybi et al., 2020; Narayanan et al., 2021), is the most widely adopted framework for multi-axis parallelism in LLM pretraining. Its successor, Megatron-Core, provides a modular library that can be integrated into custom training loops.

Mental Model: The Parallelism Grid

Think of multi-axis parallelism as a three-dimensional grid. The TP axis splits each layer across GPUs within a single node (fast NVLink). The PP axis distributes groups of layers across nodes in a pipeline. The DP axis replicates the entire pipeline across independent groups that process different data shards. A fourth axis, Context Parallelism (CP), splits long sequences across GPUs for models with extended context windows. Each axis addresses a different bottleneck: TP reduces per-layer memory, PP reduces total model memory per device, DP scales throughput, and CP handles sequence-length memory.

1.1 The Four Parallelism Axes

A typical Megatron-LM configuration specifies four parallelism degrees: TP, PP, DP, and (optionally) CP. The total GPU count equals $\text{TP} \times \text{PP} \times \text{DP} \times \text{CP}$. Choosing the right combination depends on model size, sequence length, cluster topology, and interconnect bandwidth.
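Because the product of the four degrees must equal the GPU count, the DP degree is usually derived rather than chosen directly. A few lines of Python make the arithmetic concrete (the helper name is ours; the 512-GPU example mirrors the 70B row of Table 6.8.1):

```python
def data_parallel_degree(world_size: int, tp: int, pp: int, cp: int = 1) -> int:
    """DP is whatever remains after TP, PP, and CP claim their GPUs."""
    assert world_size % (tp * pp * cp) == 0, "degrees must divide the GPU count"
    return world_size // (tp * pp * cp)

# 512 H100s with TP=8 (one node), PP=4, CP=1 leaves DP=16
print(data_parallel_degree(512, tp=8, pp=4))       # 16
# 8192 GPUs at TP=8, PP=16, CP=4 (long-context 405B) leaves DP=16
print(data_parallel_degree(8192, tp=8, pp=16, cp=4))  # 16
```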

Tensor Parallelism (TP) splits individual matrix multiplications across GPUs. For a transformer's attention and MLP layers, TP partitions the weight matrices along the hidden dimension. This requires an all-reduce after each layer, so it works best within a node where NVLink provides 900 GB/s bandwidth (on H100 nodes). Typical TP degree: 2, 4, or 8 within a single node.

Pipeline Parallelism (PP) assigns contiguous groups of transformer layers to different devices. Megatron uses an interleaved 1F1B (one-forward-one-backward) schedule to minimize pipeline bubbles. The bubble fraction is approximately $\frac{p - 1}{m}$ where $p$ is the number of pipeline stages and $m$ is the number of microbatches. With enough microbatches, bubble overhead drops below 5%.

Data Parallelism (DP) replicates the model (or model shard) and distributes data batches across replicas. When combined with ZeRO Stage 1 (optimizer state sharding), DP scales throughput linearly with minimal memory overhead. Megatron-Core integrates distributed data parallelism with optional gradient compression.

Context Parallelism (CP) splits the input sequence across GPUs along the sequence dimension, distributing the attention computation. This is essential for models trained on sequences longer than 8K tokens, where the quadratic attention memory becomes a bottleneck even with FlashAttention. CP uses ring attention or similar algorithms to exchange KV blocks between GPUs.
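The pipeline-bubble claim is easy to check numerically. The exact bubble fraction for a non-interleaved 1F1B schedule is $\frac{p-1}{m+p-1}$, of which $\frac{p-1}{m}$ is the large-$m$ approximation used above:

```python
def bubble_fraction(p: int, m: int) -> float:
    """Exact pipeline bubble fraction for a 1F1B schedule:
    (p - 1) idle slots out of (m + p - 1) total slots."""
    return (p - 1) / (m + p - 1)

for m in (4, 16, 64):
    print(f"p=4, m={m:3d}: bubble = {bubble_fraction(4, m):.1%}")
# The bubble shrinks as microbatches grow; at p=4, m=64 it is under 5%.
```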

1.2 Practical Configurations for Common Model Sizes

The following table shows recommended parallelism configurations for training common model sizes on clusters of NVIDIA H100 GPUs. These represent starting points; actual configurations require profiling on your specific hardware and interconnect.

Table 6.8.1: Recommended Megatron parallelism configurations for H100 clusters
| Model Size | TP | PP | DP | CP | GPUs | Micro-batch | Global Batch |
|---|---|---|---|---|---|---|---|
| 7B | 1 | 1 | 64 | 1 | 64 | 4 | 256 |
| 13B | 2 | 1 | 64 | 1 | 128 | 2 | 128 |
| 70B | 8 | 4 | 16 | 1 | 512 | 1 | 1024 |
| 70B (128K ctx) | 8 | 4 | 8 | 4 | 1024 | 1 | 512 |
| 405B | 8 | 16 | 32 | 1 | 4096 | 1 | 2048 |
| 405B (128K ctx) | 8 | 16 | 16 | 4 | 8192 | 1 | 1024 |
Practical Example

Llama 3.1 405B Training Configuration. Meta's Llama 3.1 405B was trained on 16,384 H100 GPUs using TP=8, PP=16, DP=128 (approximately). Each node of 8 GPUs handled one TP group, 16 nodes formed a pipeline, and 128 pipeline replicas ran in parallel. The training consumed approximately 30.8 million GPU-hours over several months. Achieving 38-43% Model FLOPS Utilization (MFU) on this scale required careful overlap of computation and communication, interleaved pipeline schedules, and a custom network topology to minimize cross-rack traffic.

1.3 Megatron-Core Configuration

Code Fragment 6.8.1 shows how to configure multi-axis parallelism in Megatron-Core. The TransformerConfig object specifies the model architecture, and parallel_state initializes the process groups for each parallelism axis.

# Code Fragment 6.8.1: Megatron-Core multi-axis parallelism setup
import megatron.core as mcore
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core import parallel_state

# Initialize distributed process groups
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=8,    # TP: split layers within node
    pipeline_model_parallel_size=4,  # PP: distribute layers across nodes
    context_parallel_size=1,         # CP: sequence-dimension splitting
    # DP is inferred: world_size / (TP * PP * CP)
)

# Configure transformer architecture
config = TransformerConfig(
    num_layers=80,                   # 70B-class model
    hidden_size=8192,
    num_attention_heads=64,
    num_query_groups=8,              # GQA: 8 KV heads
    ffn_hidden_size=28672,
    max_position_embeddings=8192,
    bf16=True,                       # BF16 mixed precision
    # Pipeline parallelism schedule
    pipeline_dtype="bfloat16",
    num_microbatches=16,             # Reduce bubble fraction
    overlap_p2p_comm=True,           # Overlap pipeline communication
    # Activation checkpointing to reduce memory
    recompute_granularity="selective",
    recompute_method="uniform",
)

# Build the model with Megatron-Core's GPT architecture
model = mcore.models.gpt.GPTModel(
    config=config,
    transformer_layer_spec=mcore.models.gpt.gpt_layer_specs.get_gpt_layer_spec(),
    vocab_size=128256,
    max_sequence_length=8192,
)

1.4 MoE Load Balancing

Mixture-of-Experts (MoE) models (such as Mixtral 8x7B, DeepSeek-V3, and Arctic) add another layer of parallelism complexity. In an MoE transformer, each token is routed to a subset of experts (typically 2 out of 8 or 16). Without load balancing, some experts receive disproportionately more tokens, creating stragglers that slow the entire training step.

Megatron-Core implements Expert Parallelism (EP), where experts are distributed across GPUs and tokens are communicated via all-to-all collectives. Load balancing is enforced through an auxiliary loss term added to the training objective:

$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \sum_{i=1}^{N} f_i \cdot p_i$$

where $N$ is the number of experts, $f_i$ is the fraction of tokens routed to expert $i$, $p_i$ is the average routing probability for expert $i$, and $\alpha$ is a balancing coefficient (typically 0.01). This loss encourages uniform token distribution across experts without interfering with the primary language modeling objective.
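A minimal sketch of this loss, using plain Python lists in place of tensors, shows that when routing probabilities match realized fractions, balanced routing drives the loss down to exactly $\alpha$:

```python
def moe_aux_loss(token_fractions, router_probs, alpha=0.01):
    """Auxiliary load-balancing loss: alpha * N * sum(f_i * p_i)."""
    n = len(token_fractions)
    return alpha * n * sum(f * p for f, p in zip(token_fractions, router_probs))

n = 8
uniform = [1 / n] * n                    # perfectly balanced routing
skewed = [0.50, 0.30, 0.05, 0.05, 0.04, 0.03, 0.02, 0.01]

print(moe_aux_loss(uniform, uniform))    # 0.01 == alpha: the minimum
print(moe_aux_loss(skewed, skewed))      # larger: imbalance is penalized
```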

Fun Fact

DeepSeek-V3 (671B total, 37B active) achieved competitive performance with dense models while using only 1/18th of the parameters per forward pass. The trick was not just the MoE architecture but an auxiliary-loss-free load balancing strategy that used per-expert bias terms instead of the traditional auxiliary loss, avoiding the subtle performance degradation that auxiliary losses can cause.

2. Compiler and Kernel Optimization for LLM Training

Raw parallelism strategy determines how work is distributed across GPUs. Compiler and kernel optimization determine how efficiently each GPU executes its share of the work. The gap between theoretical peak FLOPS and actual training throughput (MFU) is where compiler and kernel optimizations make their impact.
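MFU itself can be estimated from the standard approximation of about 6 FLOPs per parameter per token (forward plus backward, ignoring attention FLOPs). The sketch below uses an illustrative throughput figure, not a measured benchmark; 989 TFLOPS is the H100's dense BF16 peak:

```python
def mfu(tokens_per_sec_per_gpu: float, n_params: float, peak_flops: float) -> float:
    """Model FLOPS Utilization: achieved training FLOPS over hardware peak,
    using the ~6 FLOPs per parameter per token estimate."""
    achieved = 6 * n_params * tokens_per_sec_per_gpu
    return achieved / peak_flops

# Hypothetical: an 8B model at 7,000 tokens/sec/GPU on H100 (989 TFLOPS dense BF16)
print(f"{mfu(7_000, 8e9, 989e12):.1%}")
```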

2.1 torch.compile for LLM Training Loops

PyTorch 2.x introduced torch.compile, which uses TorchDynamo (a Python bytecode analyzer) and TorchInductor (a code generation backend) to automatically fuse operations and generate optimized GPU kernels. For LLM training, torch.compile can provide 10-30% throughput improvements by fusing elementwise operations, reducing kernel launch overhead, and optimizing memory access patterns.

# Code Fragment 6.8.2: Applying torch.compile to an LLM training loop
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Compile the model for training
# mode="reduce-overhead" minimizes kernel launch overhead
# fullgraph=True ensures the entire forward+backward is captured
compiled_model = torch.compile(
    model,
    mode="reduce-overhead",
    fullgraph=True,
    backend="inductor",  # TorchInductor: generates Triton kernels
)

# Training loop proceeds normally; compiled model is a drop-in replacement
optimizer = torch.optim.AdamW(compiled_model.parameters(), lr=3e-4)
for batch in dataloader:
    outputs = compiled_model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Key Insight

TorchInductor generates Triton kernels under the hood. Triton is a Python-based language for writing GPU kernels that compiles to PTX (NVIDIA) or AMDGPU IR (AMD). This means torch.compile benefits are portable across GPU vendors. The Inductor backend analyzes the computation graph, identifies fusion opportunities (such as combining layer norm, dropout, and residual addition into a single kernel), and generates Triton code that eliminates redundant memory reads and writes.

2.2 FlashAttention Kernel Selection

FlashAttention (Dao et al., 2022; Dao, 2023) rewrites the attention computation to minimize HBM (High Bandwidth Memory) accesses by tiling the attention matrix and fusing softmax into the GEMM loop. FlashAttention-2 improved throughput by 2x over the original through better work partitioning across GPU thread blocks. FlashAttention-3 (2024) further optimizes for H100 Hopper architecture features, including FP8 tensor cores and asynchronous operations.

The choice of attention kernel depends on the hardware generation and sequence length:

Table 6.8.2: Attention kernel selection guide
| Hardware | Sequence Length | Recommended Kernel | Notes |
|---|---|---|---|
| A100 | < 8K | FlashAttention-2 | Standard choice for Ampere GPUs |
| A100 | 8K-128K | FlashAttention-2 + CP | Combine with context parallelism |
| H100 | < 8K | FlashAttention-3 | Exploits Hopper async features |
| H100 | 8K-128K | FlashAttention-3 + Ring Attention | Distributed attention across devices |
| H100 (FP8) | Any | FlashAttention-3 FP8 | 2x throughput, requires careful scaling |

2.3 FP8 and FP4 Training with Transformer Engine

NVIDIA's Transformer Engine library enables mixed-precision training with FP8 (8-bit floating point) on Hopper and later architectures. FP8 training nearly doubles the throughput of BF16 training by exploiting the H100's FP8 tensor cores, which deliver 3958 TFLOPS compared to 1979 TFLOPS for BF16 (Micikevicius et al., 2022).

FP8 training uses two formats: E4M3 (4 exponent bits, 3 mantissa bits) for forward activations and weights, and E5M2 (5 exponent bits, 2 mantissa bits) for gradients, which need the larger dynamic range. A per-tensor scaling factor is maintained and updated each iteration to keep values within the FP8 representable range.
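The mechanics of delayed scaling reduce to simple arithmetic: pick a scale that maps the largest recent magnitude onto the FP8 representable range. The sketch below is not the Transformer Engine API; the function name is ours, but the format maxima (E4M3 ≈ 448, E5M2 ≈ 57,344) are standard:

```python
E4M3_MAX = 448.0    # max representable magnitude in E4M3
E5M2_MAX = 57344.0  # E5M2 trades mantissa bits for dynamic range

def fp8_scale(amax_history, fp8_max=E4M3_MAX, margin=0):
    """Delayed-scaling style scale factor: map the max observed |value|
    over a window of recent iterations onto the FP8 range."""
    amax = max(amax_history)
    return (fp8_max / amax) / (2 ** margin)

history = [3.1, 2.7, 4.0, 3.6]   # per-iteration amax of one tensor
scale = fp8_scale(history)
print(scale)                     # 112.0: multiply values by this before the cast
print(4.0 * scale)               # 448.0: the window max lands at the format max
```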

# Code Fragment 6.8.3: Enabling FP8 training with Transformer Engine
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# FP8 recipe: delayed scaling updates the scale factor
# based on the maximum absolute value from the previous iterations
fp8_recipe = DelayedScaling(
    margin=0,
    fp8_format=Format.HYBRID,  # E4M3 for fwd, E5M2 for bwd
    amax_history_len=16,       # Track max values over 16 iterations
    amax_compute_algo="max",   # Use max of history for scaling
)

# Replace nn.Linear with Transformer Engine's FP8-aware version
# in a transformer block
class TransformerBlock(te.TransformerLayer):
    def __init__(self, hidden_size, num_heads):
        super().__init__(
            hidden_size=hidden_size,
            ffn_hidden_size=4 * hidden_size,
            num_attention_heads=num_heads,
            # FP8 is enabled via the context manager, not here
        )

# During training, wrap the forward pass in the fp8 context
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = model(input_ids, labels=labels)
    loss = output.loss
loss.backward()  # backward runs outside the autocast context
Practical Example

FP8 training throughput gains. When training a Llama-3.1-8B model on a single DGX H100 node (8 GPUs), switching from BF16 to FP8 with Transformer Engine increased training throughput from approximately 420 tokens/sec/GPU to approximately 730 tokens/sec/GPU, a 74% improvement. Loss curves remained virtually identical, with the FP8 run's final loss within 0.02% of the BF16 baseline. The key to this stability is the delayed scaling mechanism, which adapts the per-tensor scaling factors smoothly rather than reacting to outliers.

3. Elastic and Fault-Tolerant Training

A multi-week LLM pretraining run on thousands of GPUs will experience hardware failures. At Meta's scale (16,384 GPUs for Llama 3.1), the team reported encountering over 400 unexpected interruptions during a 54-day training run. At this scale, fault tolerance is not optional; it is a fundamental requirement.

3.1 Torch Distributed Elastic (TorchElastic)

TorchElastic (accessed via torchrun) provides elastic training capabilities for PyTorch distributed jobs. Unlike standard torch.distributed.launch, TorchElastic can handle node failures and additions without requiring a complete restart of the training job. Code Fragment 6.8.4 shows the key launch flags:

# Code Fragment 6.8.4: Launching an elastic LLM training job with torchrun
# --nproc_per_node: GPUs per node
# --nnodes: min:max node range for elasticity
# --rdzv_backend: rendezvous backend (c10d or etcd)
# --rdzv_endpoint: rendezvous coordinator address
# --max_restarts: number of times to restart on failure

torchrun \
    --nproc_per_node=8 \
    --nnodes=4:8 \
    --rdzv_backend=c10d \
    --rdzv_endpoint=master-node:29400 \
    --max_restarts=3 \
    --monitor_interval=5 \
    train_llm.py \
        --model_size=7b \
        --checkpoint_dir=/shared/checkpoints \
        --resume_from_latest

The --nnodes=4:8 flag specifies that the job can run with anywhere from 4 to 8 nodes. If a node fails, training continues (after a brief restart) with the remaining nodes, as long as at least 4 are available. The data parallelism degree adjusts automatically.
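On each restart, every surviving rank must agree on which checkpoint to resume from. A minimal sketch, assuming a hypothetical `step_<N>` directory layout on shared storage (the helper name and layout are illustrative, not part of TorchElastic):

```python
import os
import re

def latest_checkpoint(checkpoint_dir):
    """Return the path of the newest checkpoint, assuming a hypothetical
    step_<N> directory layout on shared storage, or None if none exists."""
    steps = []
    for name in os.listdir(checkpoint_dir):
        m = re.fullmatch(r"step_(\d+)", name)
        if m:
            steps.append(int(m.group(1)))
    if not steps:
        return None  # fresh start: no checkpoint yet
    return os.path.join(checkpoint_dir, f"step_{max(steps)}")
```

Because every rank scans the same shared directory after re-rendezvous, all replicas deterministically resume from the same step.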

3.2 Failure Detection and Recovery Strategies

Production LLM training systems employ multiple layers of failure detection and recovery:

Table 6.8.3: Failure types and recovery strategies in LLM training
| Failure Type | Detection | Recovery | Typical Frequency |
|---|---|---|---|
| GPU memory error (ECC) | NVIDIA DCGM monitoring | Drain node, replace, restart | 1-3 per day at 4K+ GPUs |
| Network link failure | NCCL timeout + health check | Re-route or restart affected group | Weekly at 4K+ GPUs |
| NaN in loss | Loss monitoring hook | Roll back to previous checkpoint | Varies by model/data |
| Straggler node | Iteration time variance | Drain slow node, restart | Daily at scale |
| Full node crash | Heartbeat timeout | Elastic re-rendezvous | 2-5 per week at 4K+ GPUs |
Fun Fact

Meta's Llama 3.1 training report revealed that their automated failure recovery system handled the vast majority of the 419 interruptions during the 54-day run. The median time to recover was under 10 minutes, meaning the system spent less than 5% of total wall-clock time on recovery. The team built a custom "training supervisor" that could diagnose the failure type, decide whether to restart on the same nodes or exclude the failed node, and resume from the latest checkpoint, all without human intervention.

4. Distributed Checkpointing

Checkpointing is the primary mechanism for both fault recovery and experiment management. For large LLM training runs, naive checkpointing (gathering all state to a single writer) is prohibitively slow. A 405B model's full state (parameters + optimizer + gradients) can exceed 8 TB. Writing this to storage as a single file would take tens of minutes, during which training is blocked.

4.1 PyTorch Distributed Checkpoint (DCP)

PyTorch Distributed Checkpoint (DCP), introduced in PyTorch 2.1, provides a native solution for saving and loading sharded model state. Each rank writes its own shard independently, and DCP handles the coordination and metadata management. A key property is that checkpoints are reshardable: a run saved under one parallelism layout can be loaded into a different one.

# Code Fragment 6.8.5: Saving and loading with PyTorch DCP
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)

# --- Saving a sharded checkpoint ---
def save_checkpoint(model, optimizer, step, checkpoint_dir):
    """Save a reshardable distributed checkpoint."""
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "step": step,
    }
    dcp.save(
        state_dict=state_dict,
        storage_writer=dcp.FileSystemWriter(
            f"{checkpoint_dir}/step_{step}",
            single_file_per_rank=True,  # One file per rank for parallel I/O
        ),
    )

# --- Loading with potentially different parallelism ---
def load_checkpoint(model, optimizer, checkpoint_dir, step):
    """Load checkpoint; DCP handles resharding automatically."""
    state_dict = {
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        "step": 0,  # Placeholder; dcp.load fills it from the checkpoint
    }
    dcp.load(
        state_dict=state_dict,
        storage_reader=dcp.FileSystemReader(
            f"{checkpoint_dir}/step_{step}"
        ),
    )
    # Apply loaded state back to model and optimizer
    set_model_state_dict(
        model, state_dict["model"],
        options=StateDictOptions(strict=True),
    )
    set_optimizer_state_dict(
        model, optimizer, state_dict["optimizer"],
    )
    return state_dict.get("step", 0)

4.2 Async Checkpointing and Staging

For multi-week training runs, even a few minutes of blocking per checkpoint adds up to hours of wasted GPU time. Async checkpointing eliminates this bottleneck by overlapping the storage write with the next training iteration:

  1. Snapshot to CPU: Copy the model and optimizer state from GPU to pinned CPU memory. This takes seconds for a 70B model.
  2. Continue training: The GPU immediately begins the next training iteration.
  3. Background write: A separate thread writes the CPU snapshot to the parallel file system or object storage (S3, GCS).
  4. Verification: Once the write completes, the checkpoint metadata is atomically updated to mark the new checkpoint as valid.
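The four steps above can be sketched in a few lines. This is a schematic stand-in, with plain Python objects and JSON in place of GPU tensors and DCP, not the real checkpointing API; recent PyTorch releases expose an asynchronous save in DCP that follows the same staging pattern:

```python
import copy
import json
import os
import threading

def async_checkpoint(state: dict, path: str) -> threading.Thread:
    """Schematic async checkpoint: snapshot synchronously (standing in for
    the GPU-to-CPU copy), then write in a background thread while the
    training loop continues."""
    snapshot = copy.deepcopy(state)      # step 1: brief blocking snapshot
    tmp = path + ".tmp"

    def _write():
        with open(tmp, "w") as f:        # step 3: background write
            json.dump(snapshot, f)
        os.replace(tmp, path)            # step 4: atomic publish of the checkpoint

    t = threading.Thread(target=_write)
    t.start()                            # step 2: caller continues training
    return t
```

Writing to a temporary file and atomically renaming it means a crash mid-write can never leave a half-written checkpoint marked as valid.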

4.3 Checkpoint Frequency Trade-offs

Choosing checkpoint frequency involves a trade-off between recovery time and storage cost. If you checkpoint every $C$ iterations and a failure occurs, you lose on average $C/2$ iterations of work. For a training run with $N$ total iterations and a mean time between failures (MTBF) of $F$ iterations, the expected wasted compute is:

$$\text{Wasted fraction} \approx \frac{C}{2F} + \frac{t_{\text{ckpt}}}{t_{\text{iter}} \cdot C}$$

where $t_{\text{ckpt}}$ is the checkpoint duration and $t_{\text{iter}}$ is the iteration time. The first term represents lost work from rollback; the second represents time spent checkpointing. Minimizing their sum gives the optimal checkpoint interval. With async checkpointing, $t_{\text{ckpt}}$ drops to near zero (only the GPU-to-CPU copy time), making more frequent checkpointing practical.
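Setting the derivative of the wasted fraction to zero gives the optimal interval $C^* = \sqrt{2 F \, t_{\text{ckpt}} / t_{\text{iter}}}$. A small helper (the function name is ours; the numbers are illustrative) makes the trade-off concrete:

```python
import math

def optimal_interval(mtbf_iters: float, t_ckpt: float, t_iter: float) -> float:
    """Minimize C/(2F) + t_ckpt/(t_iter * C); the derivative vanishes at
    C* = sqrt(2 * F * t_ckpt / t_iter)."""
    return math.sqrt(2 * mtbf_iters * t_ckpt / t_iter)

# MTBF of 3,000 iterations, 30 s per iteration, 5-minute blocking checkpoint
print(round(optimal_interval(3_000, 300, 30)))  # 245
# Async checkpointing: only ~10 s of blocking time per checkpoint
print(round(optimal_interval(3_000, 10, 30)))   # 45
```

Note how shrinking $t_{\text{ckpt}}$ pulls the optimal interval down: cheaper checkpoints should be taken more often.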

Practical Example

Checkpoint strategy for a 70B pretraining run. A team training a 70B model on 512 H100 GPUs with an MTBF of approximately 2,000 iterations (about 8 hours) uses async DCP checkpoints every 200 iterations (about 50 minutes). Each async checkpoint takes 15 seconds of blocking time (GPU-to-CPU copy) plus 3 minutes of background I/O. With this configuration, the expected wasted compute from failures is approximately 5% (100 iterations of rollback out of 2,000), and the checkpoint overhead is less than 0.5%. The team also saves "milestone" checkpoints every 10,000 iterations to a separate object storage bucket for long-term retention and experiment tracking.

5. Streaming Dataset Pipelines

LLM pretraining datasets are too large to fit in local storage on training nodes. The Llama 3 training corpus was approximately 15 trillion tokens, and the full tokenized dataset can exceed 30 TB. Streaming dataset libraries solve this by loading data on-the-fly from cloud object storage, eliminating the need to pre-download the entire dataset.
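The storage figure follows from simple bytes-per-token arithmetic; the helper below is illustrative. Token IDs fit in 2 bytes only while the vocabulary stays within 65,536 entries; larger vocabularies need 4 bytes unless a packed encoding is used:

```python
def tokenized_size_tb(n_tokens: float, vocab_size: int) -> float:
    """Rough on-disk size of a tokenized corpus in TB: 2 bytes per token
    for vocabularies up to 65,536 entries, else 4 bytes."""
    bytes_per_token = 2 if vocab_size <= 65_536 else 4
    return n_tokens * bytes_per_token / 1e12

print(tokenized_size_tb(15e12, 32_000))   # 30.0 TB at 2 bytes/token
print(tokenized_size_tb(15e12, 128_256))  # 60.0 TB at 4 bytes/token
```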

5.1 Mosaic Streaming (StreamingDataset)

MosaicML's streaming library (now part of the Databricks ecosystem) provides a streaming dataloader purpose-built for distributed LLM training: deterministic shuffling across ranks, exact mid-run resumption, and weighted mixing of multiple data sources.

# Code Fragment 6.8.6: Streaming dataloader for distributed LLM training
from streaming import Stream, StreamingDataset, StreamingDataLoader

# Define the dataset with multiple data sources and mixing proportions
dataset = StreamingDataset(
    streams=[
        Stream(
            remote="s3://training-data/web-text/",
            local="/local-nvme/cache/web-text/",
            proportion=0.50,  # 50% web text
        ),
        Stream(
            remote="s3://training-data/code/",
            local="/local-nvme/cache/code/",
            proportion=0.25,  # 25% code
        ),
        Stream(
            remote="s3://training-data/books/",
            local="/local-nvme/cache/books/",
            proportion=0.15,  # 15% books
        ),
        Stream(
            remote="s3://training-data/conversations/",
            local="/local-nvme/cache/conversations/",
            proportion=0.10,  # 10% conversational data
        ),
    ],
    shuffle=True,
    shuffle_seed=42,         # Deterministic shuffling
    num_canonical_nodes=64,  # Controls sharding granularity
    batch_size=4,            # Per-device micro-batch size
)

# StreamingDataLoader handles distributed rank assignment automatically
dataloader = StreamingDataLoader(
    dataset=dataset,
    batch_size=4,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=4,  # Prefetch 4 batches ahead
)

# Resume from the exact sample index after a failure
saved_state = dataloader.state_dict()    # Save alongside the model checkpoint
# ... after restart ...
dataloader.load_state_dict(saved_state)  # Resume exactly where training stopped

5.2 WebDataset

WebDataset takes a different approach: it stores data as tar archives containing individual samples, and streams them sequentially. This format is particularly efficient for multimodal training (image-text pairs) where each sample consists of multiple files. For text-only LLM training, WebDataset works well but offers less sophisticated mixing and resumability compared to Mosaic Streaming.

5.3 Data Integrity and Validation

At the scale of trillions of tokens, data corruption can silently degrade model quality. Production training pipelines therefore include integrity checks at multiple levels, from checksums on individual data shards to end-to-end token counts validated against the dataset index.

Key Insight

The choice between Mosaic Streaming and WebDataset often comes down to the training framework. Mosaic Streaming integrates tightly with the Composer/MosaicML ecosystem and offers superior resumability and mixing. WebDataset is more framework-agnostic and widely used in multimodal pipelines. Both support deterministic iteration, which is critical for reproducing training runs and debugging data-related issues. For LLM-specific pretraining, Mosaic Streaming's built-in curriculum learning support (dynamically changing data proportions during training) gives it an edge for teams that want to upweight code or instruction data in later training phases, following the approach used in Llama 3 and similar models.

6. Putting It All Together: Production Training Architecture

A production LLM training system integrates all of the components discussed in this section. The following diagram shows how they fit together for a typical 70B+ model pretraining run:

Figure 6.8.1: Production LLM training architecture. Data flows from object storage through streaming loaders. Megatron-Core handles multi-axis parallelism with compiler optimizations. TorchElastic provides fault tolerance. DCP manages distributed checkpoints. The monitoring layer tracks training health across all components.
Self-Check Questions
  1. You have 64 GPUs with 80 GB each and a 30B parameter model. Roughly how much memory do the model weights, gradients, and AdamW states require in FP32? Which parallelism strategies (TP, PP, DP, FSDP) would you combine, and why?
  2. Explain why pipeline parallelism introduces "bubble" time and how interleaving micro-batches reduces it.
  3. A training run on 512 GPUs loses one node every 8 hours on average. Describe how TorchElastic and asynchronous distributed checkpointing keep the run progressing despite these failures.


Research Frontier

Several active research directions aim to push training efficiency further. Fully asynchronous pipeline parallelism (e.g., PipeDream-2BW, zero-bubble PP) seeks to eliminate pipeline bubble overhead entirely. Communication-computation overlap research focuses on hiding all-reduce latency behind computation through fine-grained scheduling. FP4 training is an emerging area where 4-bit floating point could double throughput again on next-generation hardware, though maintaining training stability at this precision remains challenging. Heterogeneous training explores using a mix of GPU generations (A100 + H100) in the same cluster, which requires adaptive load balancing to account for different compute speeds.

Exercises

Exercise 6.8.1: Parallelism configuration Conceptual

You have a cluster of 256 H100 GPUs (32 nodes with 8 GPUs each) and want to train a 13B parameter model with a 4K context length. Design a parallelism configuration (TP, PP, DP, CP) and justify each choice. What changes if you increase the context to 32K tokens?

Answer Sketch

For 13B at 4K context: TP=2 (model fits in 2 GPUs per layer), PP=1 (small enough for single pipeline stage), DP=128 (256/2/1), CP=1 (short context). For 32K context: increase TP=4 to handle larger activation memory, consider CP=2, giving DP=256/(4*1*2)=32. The key constraint is that TP should stay within a single node for NVLink bandwidth, and CP helps when attention memory becomes the bottleneck at longer sequences.

Exercise 6.8.2: Checkpoint frequency Analysis

A 70B model training run has an MTBF of 3,000 iterations. Each iteration takes 30 seconds. A blocking checkpoint takes 5 minutes, and an async checkpoint takes 10 seconds of GPU blocking time. Calculate the optimal checkpoint frequency for both blocking and async cases. How much total GPU-time is saved by using async checkpointing over a 100,000-iteration run?

Answer Sketch

For blocking: optimize $C/6000 + 300/(30 \cdot C)$, giving $C \approx \sqrt{300 \cdot 6000/30} \approx 245$ iterations. For async: $C/6000 + 10/(30 \cdot C)$, giving $C \approx \sqrt{10 \cdot 6000/30} \approx 45$ iterations. Over 100K iterations: blocking checkpoints ~408 times at 5 min each = 34 hours blocking. Async checkpoints ~2,222 times at 10 sec = 6.2 hours. Net saving approximately 28 hours of GPU time across the cluster.

Exercise 6.8.3: MoE load balancing Analysis

Explain why unbalanced expert routing creates a training bottleneck in MoE models. Derive the auxiliary loss formula and explain how the balancing coefficient $\alpha$ affects the trade-off between load balance and model quality.

Answer Sketch

In MoE, each training step waits for all experts to finish. If expert $i$ receives 3x more tokens than expert $j$, the step time is determined by the slowest (most loaded) expert, wasting compute on the underloaded ones. The auxiliary loss $\alpha \cdot N \sum f_i \cdot p_i$ penalizes routing probability ($p_i$) that concentrates tokens ($f_i$) on few experts. Too high $\alpha$ forces uniform routing but prevents specialization (experts cannot differentiate). Too low $\alpha$ allows imbalance. Typical values of $\alpha = 0.01$ to $0.1$ provide a good compromise.

Exercise 6.8.4: Streaming data pipeline Coding

Write a configuration for a Mosaic StreamingDataset that mixes three data sources: 60% web text, 30% code, and 10% instruction-following data. Include local caching, deterministic shuffling, and resumability support. Explain how the data proportions would change if you were implementing a curriculum that upweights code data in the final 20% of training.

Answer Sketch

Configure three streams with proportions [0.6, 0.3, 0.1], shuffle_seed for determinism, local cache paths on NVMe. For curriculum: at 80% completion, change proportions to [0.4, 0.5, 0.1] to upweight code. Mosaic Streaming supports dynamic proportion changes via the StreamingDataset API. Save dataset.state_dict() before changing proportions to ensure exact resumability.

What's Next?

With a solid understanding of how LLMs are pre-trained at scale, including scaling laws, data curation, and distributed training infrastructure, we turn to the models themselves. In Chapter 7: Modern LLM Landscape and Model Internals, we survey the major model families (GPT, Llama, Gemini, Mistral, and more), compare their architectures and training recipes, and explore what makes each one unique.

References & Further Reading
Megatron-LM and Multi-Axis Parallelism

Shoeybi, M., Patwary, M., Puri, R., et al. (2020). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv preprint arXiv:1909.08053.

The foundational Megatron-LM paper introducing efficient intra-layer model parallelism (tensor parallelism) for transformer models. Demonstrated that careful parallelism design enables training models up to 8.3B parameters with high throughput on GPU clusters.

📄 Paper

Narayanan, D., Shoeybi, M., Casper, J., et al. (2021). Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM. SC '21.

Extends Megatron-LM with pipeline parallelism and the interleaved 1F1B schedule. Provides the theoretical framework for combining TP, PP, and DP, and demonstrates efficient training of trillion-parameter models.

📄 Paper

Dubey, A., Jauhri, A., Pandey, A., et al. (2024). The Llama 3 Herd of Models. arXiv preprint arXiv:2407.21783.

Meta's detailed training report for Llama 3 and 3.1 (up to 405B parameters). Provides invaluable production details on parallelism configurations, failure rates, recovery procedures, and training infrastructure at scale.

📄 Paper
FlashAttention and Kernel Optimization

Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024.

Improves FlashAttention throughput by 2x through better work partitioning across GPU thread blocks and warps. The standard attention kernel for all modern LLM training.

📄 Paper

Dao, T., Fu, D. Y., Ermon, S., Rudra, A., Re, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

The original FlashAttention paper that introduced IO-aware attention computation, reducing HBM accesses from quadratic to linear in sequence length. Foundational for all subsequent attention kernel optimizations.

📄 Paper
FP8 Training

Micikevicius, P., Stosic, D., Burgess, N., et al. (2022). FP8 Formats for Deep Learning. arXiv preprint arXiv:2209.05433.

Defines the E4M3 and E5M2 FP8 formats and demonstrates that deep learning training converges with FP8 precision when combined with appropriate per-tensor scaling strategies. The basis for NVIDIA Transformer Engine.

📄 Paper
Distributed Checkpointing and Fault Tolerance

PyTorch Team (2024). PyTorch Distributed Checkpoint (DCP) Documentation.

Official documentation for PyTorch's distributed checkpointing API, covering sharded saves, resharding on load, and async staging. The reference implementation for production LLM checkpointing.

📖 Documentation

PyTorch Team (2024). TorchElastic (torchrun) Documentation.

Official documentation for TorchElastic, covering elastic launch, rendezvous protocols, and fault tolerance for distributed training jobs.

📖 Documentation
Streaming Datasets

MosaicML (2023). Mosaic StreamingDataset: Fast, Accurate Streaming from Cloud Storage.

The Mosaic Streaming library for deterministic, resumable data loading from object storage. Supports data mixing, curriculum learning, and efficient distributed sharding for LLM pretraining at scale.

💻 Library

Aizman, A., Maltby, G., Breuel, T. (2023). WebDataset: High-Performance Dataset I/O.

A tar-based streaming dataset format optimized for sequential I/O patterns. Widely used in multimodal training pipelines for efficiently loading image-text pairs from cloud storage.

💻 Library