Saving, Loading, and Deployment

Section E.10

A trained model is not yet a deployed product. Between the last training step and the first production request there are decisions about serialization format, export target, latency budget, and hardware. This section covers the PyTorch side of that transition: how to save a model so it can be loaded reliably, how to export to ONNX for cross-framework deployment, how TorchScript fits in (and when to skip it), and, briefly, the mobile and edge paths. Production serving systems (vLLM, TensorRT-LLM, llama.cpp) that consume these exports are covered in Chapter 9.

state_dict vs Full Pickle

PyTorch supports two ways of serializing a model. The first is torch.save(model, path), which pickles the entire model object. The second is torch.save(model.state_dict(), path), which serializes only the tensor values. The two save and load idioms look nearly identical, but their implications differ considerably.

Pickling the whole model is brittle. The pickle file contains a reference to the Python class that defined the model; loading the file requires that exact class to be importable under exactly the same module path. Refactoring the code (moving a class to a new module, renaming an attribute) breaks every checkpoint that depends on the old layout. The pickle also captures the module's entire attribute state (submodules, hooks, and any other Python objects attached to it), which is unnecessary because the architecture is already in the source code. Loading such a file also requires weights_only=False, which re-enables arbitrary code execution at load time.

Saving only the state_dict is portable. The file is a flat dictionary of parameter tensors keyed by qualified name. Loading requires reconstructing the architecture in code and calling model.load_state_dict(torch.load(path)). The reconstruction step is the price for portability: the same state dict loads into an unchanged class today and into a refactored class tomorrow, as long as parameter names line up.
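To make "keyed by qualified name" concrete, the following sketch prints the keys of a state dict and shows what happens when a refactor renames a submodule. TinyNet and RefactoredNet are hypothetical classes invented for this illustration; they are not the NeuralNetwork class used elsewhere in this appendix.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical two-layer model, used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(torch.relu(self.encoder(x)))


class RefactoredNet(nn.Module):
    """Same architecture after a refactor renamed `head` to `classifier`."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 2)

    def forward(self, x):
        return self.classifier(torch.relu(self.encoder(x)))


state = TinyNet().state_dict()
keys = list(state.keys())
print(keys)  # one "<submodule>.<parameter>" entry per tensor

# strict=False reports name mismatches instead of raising.
report = RefactoredNet().load_state_dict(state, strict=False)
print("missing:", report.missing_keys)
print("unexpected:", report.unexpected_keys)

# Renaming the keys is enough to load the old weights into the new class.
remapped = {k.replace("head.", "classifier.", 1): v for k, v in state.items()}
refactored = RefactoredNet()
refactored.load_state_dict(remapped, strict=True)
```

With strict=True (the default) the mismatched load would raise instead of returning a report, which is usually the safer behavior for production checkpoints.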

import torch

# Save: state dict only, never the full model object.
torch.save(model.state_dict(), "checkpoint.pt")

# Load: reconstruct the architecture, then inject weights.
model = NeuralNetwork(num_inputs=50, num_outputs=10)
state = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state, strict=True)
model.eval()
Output: (no stdout; model is loaded with its trained parameters)
Code Fragment E.10.1: The portable save and load pattern. The map_location="cpu" argument allows the file to be loaded on hosts that lack the GPU the model was trained on. The weights_only=True flag (the default since PyTorch 2.6) restricts unpickling to tensors and basic types, which blocks arbitrary code execution at load time.

Warning: weights_only=True Is a Security Boundary

In older PyTorch versions, torch.load defaulted to full unpickling, which can execute arbitrary Python code embedded in the file. A maliciously crafted checkpoint could, on load, exfiltrate data or install a backdoor. PyTorch 2.6 and later default to weights_only=True, which permits only tensor and basic type deserialization; never turn this off when loading untrusted checkpoints from the internet. For checkpoints that legitimately contain custom objects (training state with custom optimizer hyperparameters, for example), explicitly allow-list those classes with torch.serialization.add_safe_globals rather than disabling the safety check.

Library Shortcut: safetensors for Safe, Fast Checkpoints

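The safetensors format (now the default on the Hugging Face Hub) replaces pickle entirely: zero-copy memory mapping, deterministic layout, and no code execution at load time. The save and load API is a drop-in for state_dict serialization. One caveat: save_file rejects tensors that share memory, so for models with tied weights use safetensors.torch.save_model instead.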

from safetensors.torch import save_file, load_file

save_file(model.state_dict(), "model.safetensors")
state = load_file("model.safetensors", device="cpu")
model.load_state_dict(state)

Atomic Saves Revisited

Section E.5 introduced the temp-file-and-rename pattern for atomic checkpoints. The pattern matters even more for the "best so far" model that production inference will load: if a save is interrupted by a crash or preemption, a half-written file at the canonical path will cause every subsequent load to fail. os.replace performs an atomic rename on both POSIX and Windows NTFS, as long as the temp file sits on the same filesystem as the target, so the canonical path contains either the previous valid checkpoint or the new valid checkpoint, never a corrupt mid-write file.
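As a reminder of the shape of that pattern, here is a minimal sketch. The helper name atomic_save and its signature are hypothetical and may differ from the version in Section E.5; the fsync call is an extra precaution, not something the text above requires.

```python
import os
import tempfile

import torch


def atomic_save(obj, path):
    """Write `obj` with torch.save so that `path` is never left half-written.

    Hypothetical helper: the name and signature are illustrative.
    """
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file next to the target so both are on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(obj, f)
            f.flush()
            os.fsync(f.fileno())  # push bytes to disk before the rename
        os.replace(tmp_path, path)  # swap the finished file into place
    except BaseException:
        # On any failure, drop the temp file; `path` keeps the old checkpoint.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


with tempfile.TemporaryDirectory() as workdir:
    target = os.path.join(workdir, "best.pt")
    atomic_save({"w": torch.zeros(3)}, target)  # first "best" checkpoint
    atomic_save({"w": torch.ones(3)}, target)   # overwrite with a better one
    restored = torch.load(target, map_location="cpu", weights_only=True)
    leftovers = [name for name in os.listdir(workdir) if name.endswith(".tmp")]
```

Creating the temp file in the target's own directory is the design choice that keeps the rename on a single filesystem.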

ONNX Export

ONNX (Open Neural Network Exchange) is a cross-framework graph format. Exporting a PyTorch model to ONNX makes it loadable by inference runtimes that do not depend on Python or PyTorch: ONNX Runtime, TensorRT, OpenVINO, CoreML, and many edge-device runtimes. For deployment outside the PyTorch ecosystem, ONNX is the lingua franca.

import torch

model = NeuralNetwork(num_inputs=50, num_outputs=10).eval()
dummy = torch.randn(1, 50)                  # example input for tracing

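# torch.onnx.export drives both exporters. The dynamo-based exporter
# (dynamo=True; the default only in newer releases) handles control flow
# more robustly than the legacy tracer. dynamic_axes is the legacy way to
# mark dynamic dimensions; the newer exporter also accepts dynamic_shapes.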
torch.onnx.export(
    model,
    (dummy,),
    "model.onnx",
    input_names=["features"],
    output_names=["logits"],
    dynamic_axes={"features": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=18,
)

# Optional: verify the exported model.
import onnx
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
Output: (no stdout; model.onnx is written and validated)
Code Fragment E.10.2: ONNX export with a dynamic batch dimension. The export is driven by a representative input; for models with data-dependent control flow, the dynamo-based exporter handles many cases the legacy tracer could not.

The dynamic_axes argument marks which dimensions of which inputs are allowed to vary at inference time. Without it, the exported ONNX model would be locked to whatever shape the dummy input had, which is rarely what production needs. Common patterns are dynamic batch (always), dynamic sequence length (for NLP), and dynamic image size (for vision models that accept variable resolutions).

Practical Example: Validating an ONNX Export

The export can succeed quietly even when the ONNX model produces different outputs than the source. After exporting, always run the same input through both the PyTorch model and the ONNX model (via onnxruntime.InferenceSession), convert the ONNX Runtime output (a NumPy array) to a tensor, and verify with torch.allclose(torch_out, onnx_out, atol=1e-4). Mismatches usually trace back to unsupported ops, missing dynamic_axes, or precision drift between PyTorch and ONNX Runtime. Catching the mismatch at export time is much cheaper than catching it after the model is in production.

TorchScript: Trace vs Script

TorchScript was PyTorch's original answer to "deploy without Python." It comes in two flavors. torch.jit.trace(model, example_input) records the operations performed on the example input and bakes them into a static graph. Tracing is simple but unfaithful: any data-dependent control flow (a Python if branching on tensor values) is locked to the path taken by the example. torch.jit.script(model) instead parses the model's Python source code and compiles it to TorchScript, faithfully preserving control flow at the cost of supporting only a TorchScript-compatible subset of Python.
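The tracing pitfall is easy to reproduce. In this sketch, SignGate is a hypothetical module invented for the illustration; it branches on the sign of the input sum, and the trace is recorded with an input that takes the positive branch.

```python
import warnings

import torch
from torch import nn


class SignGate(nn.Module):
    """Hypothetical module whose output depends on the sign of the input sum."""

    def forward(self, x):
        if x.sum() > 0:  # data-dependent Python branch
            return x + 1
        return x - 1


eager = SignGate()
positive = torch.ones(3)
negative = -torch.ones(3)

with warnings.catch_warnings():
    # Tracing warns about the tensor-to-bool conversion; silence it here.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(eager, positive)

eager_out = eager(negative)    # takes the `x - 1` branch
traced_out = traced(negative)  # replays the recorded `x + 1` branch
print(eager_out, traced_out)
```

The two outputs differ because the traced graph contains only the additions recorded for the example input; the branch itself is gone.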

TorchScript is largely superseded by torch.compile for performance and by ONNX export for cross-framework deployment. Its remaining use cases are mobile deployment via PyTorch Mobile (which loads TorchScript) and a few legacy production systems. For new code, prefer torch.compile + state_dict save for PyTorch-resident deployment, and ONNX export for everything else.

Mobile and Edge Deployment

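Three paths exist for running PyTorch models on mobile and edge devices: ExecuTorch, the current on-device runtime and the choice for new code; PyTorch Mobile, which loads TorchScript and is now in maintenance; and ONNX export to an edge-device runtime.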

Regardless of runtime, mobile deployment usually requires quantization to int8 or smaller to meet memory and latency budgets. PyTorch's torch.ao.quantization toolkit handles post-training quantization and quantization-aware training; the resulting model can then be exported to whichever runtime the device supports. The quantization tradeoffs are covered in detail in Section 9.1.

Key Insight: Eager-Mode Inference Is Often Enough

For server-side deployment of medium-size models, plain PyTorch in eager mode (with model.eval() and torch.inference_mode()) is often the simplest, fastest option. Adding torch.compile typically gives a 1.5x to 2x speed-up for a one-line change, at the cost of a compilation warm-up on the first calls. Beyond that, the marginal complexity of TorchScript, ONNX, or TensorRT pays off only when latency budgets are tight, when deployment is to a Python-free runtime, or when the model is large enough that compiler-specific kernel fusion matters. Starting with the simplest path and adding complexity only when measurements demand it saves significant engineering time.
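A minimal sketch of the eager-mode baseline looks like this. The model is a hypothetical stand-in with random weights, and the predict helper is an illustrative name, not an API from the text.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical stand-in for a trained model.
model = nn.Sequential(
    nn.Linear(50, 64),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(64, 10),
)
model.eval()  # disables dropout; normalization layers use running statistics


@torch.inference_mode()  # no autograd bookkeeping inside this function
def predict(batch):
    """Return the predicted class index for each row of `batch`."""
    return model(batch).argmax(dim=-1)


batch = torch.randn(8, 50)
first = predict(batch)
second = predict(batch)  # identical to `first` because dropout is off

# Optional next step, not run here: model = torch.compile(model)
```

Measuring this baseline first gives a reference point for judging whether any heavier deployment path is worth its complexity.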

Sharing via the Hugging Face Hub

For models that will be shared publicly, the Hugging Face Hub provides versioned storage, model cards, and integration with the transformers library. The pattern is to subclass PreTrainedModel, save with model.save_pretrained(path), and push with model.push_to_hub(repo_id). Loading from the Hub is then a one-liner: AutoModel.from_pretrained(repo_id). For PyTorch-only models that do not need transformers, the Hub still works through huggingface_hub's upload_file API; the uploaded file can then be fetched on any machine with hf_hub_download, given the repository ID and filename.

Key Insight

For PyTorch-resident deployment, save state_dict (never the full model), reconstruct the architecture at load time, and write atomically. For cross-framework deployment, export to ONNX with dynamic_axes set for the dimensions that vary, and always validate the export against the PyTorch source. For mobile and edge, use ExecuTorch for new code and PyTorch Mobile for maintenance. Skip TorchScript unless legacy systems require it. The simplest deployment path that meets the latency budget is almost always the right one; reach for compilers and quantization only when measurements justify the complexity.

Exercise E.10.1: Round-Trip via ONNX with Dynamic Batch Size

Objective. Export a model to ONNX, run it in onnxruntime, and verify numerical equivalence to the PyTorch source.

Task. Train a small image classifier on FashionMNIST. Export via torch.onnx.export(model, dummy_input, "model.onnx", input_names=["input"], dynamic_axes={"input": {0: "batch"}}). Load with onnxruntime.InferenceSession, run with batch sizes 1, 16, and 64, and confirm with torch.allclose(torch_out, torch.from_numpy(ort_out), atol=1e-4) for all three. Then time both paths over 100 batches and report the speed-up (or slow-down).

Hint. If allclose fails, the most common culprits are operators ONNX does not yet support and BatchNorm in training mode. Always call model.eval() before export.

Exercise E.10.2: Atomic Save plus Cross-Device Load

Objective. Demonstrate the two save-and-load disciplines that prevent the most common production bugs.

Task. Save a small model trained on CUDA via the atomic tmp + os.replace pattern. Confirm the file exists. Then load it back with torch.load(path, map_location="cpu") and verify the parameters land on CPU. Finally, write a 5-line "save the previous best to a backup before overwriting" wrapper and demonstrate that a simulated crash (raise an exception between save and rename) leaves the previous best intact.

Stretch. Push the model to the Hugging Face Hub with huggingface_hub and reload it from a fresh process via hf_hub_download. Confirm that the SHA-256 hash of the downloaded file matches that of the local checkpoint.

Further Reading

Deployment References

PyTorch Tutorial: Saving and Loading Models. Covers state_dict vs full-model patterns, cross-device loading, and the optimizer checkpointing idioms. The starting point for any serialization decision.
PyTorch Documentation: ONNX Exporter. The current export API, including the dynamo-based exporter that replaces the legacy tracer for most use cases.
ExecuTorch Documentation. The on-device runtime: architecture, export workflow, and integration guides for iOS, Android, and embedded targets.
Hugging Face Hub Documentation. Programmatic upload, download, and versioning of models, datasets, and spaces. The pragmatic option for sharing PyTorch checkpoints.