Federated Learning for Privacy-Preserving Training

Section 50.3

"I keep my weights close, but my gradients closer. Sharing is caring, as long as nobody sees the data."

A Farsighted Guard, Gradient-Hoarding AI Agent
Big Picture

Federated learning (FL) enables multiple parties to collaboratively train or fine-tune a model without sharing their raw data. Each participant trains locally on their own data and shares only model updates (gradients or adapter weights) with a central server, which aggregates them into a global model. For LLMs, this addresses a fundamental tension: organizations want models that benefit from diverse, domain-rich data, but privacy regulations (Section 49.4), competitive concerns, and data sovereignty laws prevent them from pooling that data in one location.

Prerequisites

This section builds on the privacy and differential privacy concepts from Section 50.1: Privacy Attacks and Differential Privacy and the distributed training fundamentals from PySpark for LLM Data Pipelines. Familiarity with fine-tuning workflows and LoRA is assumed.

Multiple cartoon robots in different countries represented by small landmarks like the Eiffel Tower, pagoda, and pyramid, all connected by a network.
Figure 50.3.1: In federated learning, the model travels to the data instead of the other way around. Each participant shares only gradient updates, never raw data.

50.3.1 Federated Learning Fundamentals

Centralized training collects all data on a central server. Federated learning inverts this: the model travels to the data, not the other way around.

Key Insight: Mental Model: The Visiting Professor

Think of FedAvg as a visiting professor who tours a hundred private libraries. Each library refuses to lend its books, but it will let the professor read in-house for an evening and then summarize what was learned. The professor visits every library, takes notes, then averages those notes back at the university. The books (raw training data) never leave the libraries; only the professor's distilled notes (gradient updates) move. The original FedAvg paper (McMahan et al., 2017) showed this works on the MNIST digit task at 100 simulated clients, reaching 99 percent test accuracy with 10x to 100x less communication than naive SGD-per-step.

Where this model breaks down: a real visiting professor never reveals what they read elsewhere, but gradients can leak training data through inversion attacks (Section 50.3.3), so the analogy understates the threat model.

The FedAvg Algorithm

McMahan et al. (2017) introduced Federated Averaging (FedAvg), the foundational FL algorithm. Each communication round runs four steps:

  1. The server sends the current global model $w_t$ to a subset of $K$ clients
  2. Each client $k$ trains the model on its local data for $E$ local epochs, producing updated weights $w_{t+1}^k$
  3. Clients send their updates $\Delta w^k = w_{t+1}^k - w_t$ back to the server
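  4. The server aggregates updates using a weighted average: $w_{t+1} = w_t + \sum_{k=1}^{K} \frac{n_k}{n} \Delta w^k$, where $n_k$ is the number of training examples on client $k$ and $n = \sum_{k=1}^{K} n_k$

The weights $n_k / n$ are proportional to each client's dataset size, so clients with more data contribute proportionally more to the global update. When every client holds the same amount of data, the rule reduces to the plain mean $\frac{1}{K} \sum_{k=1}^{K} \Delta w^k$.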
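The server-side aggregation step can be sketched in a few lines. This is a minimal illustration, assuming model weights are flattened into plain Python lists; the function name `fedavg_aggregate` and the toy numbers are hypothetical and not taken from the FedAvg paper.

```python
def fedavg_aggregate(global_weights, client_updates, client_sizes):
    """Apply w_{t+1} = w_t + sum_k (n_k / n) * delta_w^k to flat weight vectors."""
    total = sum(client_sizes)
    new_weights = list(global_weights)
    for delta, n_k in zip(client_updates, client_sizes):
        share = n_k / total  # this client's fraction of all training examples
        for i, d in enumerate(delta):
            new_weights[i] += share * d
    return new_weights


w_t = [0.0, 1.0]
deltas = [[1.0, 0.0], [0.0, -2.0]]  # hypothetical updates from two clients
sizes = [300, 100]                  # client 0 holds three times as much data
w_next = fedavg_aggregate(w_t, deltas, sizes)
print(w_next)  # [0.75, 0.5]
```

Client 0 holds 75 percent of the data, so its update moves the first coordinate by 0.75 of its proposed step, while client 1's update is scaled by 0.25.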

Fun Fact

Google's Gboard keyboard uses federated learning to improve next-word prediction across an enormous population of Android devices. When your phone is idle, charging, and on an unmetered network, it trains a small model update locally on your recent typing and sends only that protected update to Google's servers. Google learns that people are starting to type "rizz" more often, but never sees any individual's messages. It is autocomplete training at planetary scale, powered by phones that are mostly sitting on chargers.

Key Challenges

50.3.2 Federated Fine-Tuning of LLMs

Full federated pretraining of LLMs is prohibitively expensive: the communication cost of exchanging billions of parameters each round is impractical. Instead, the dominant approach is federated fine-tuning, where a pretrained base model is adapted to domain-specific data held by multiple parties.

Federated LoRA (FFA-LoRA)

LoRA adapters are a natural fit for federated learning because they are small. Instead of exchanging 7B parameters each round, clients exchange only the low-rank adapter matrices (typically 1-10M parameters), which cuts communication cost by roughly three orders of magnitude.
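The size gap can be checked with back-of-the-envelope arithmetic. The sketch below assumes a 7B-class transformer with 32 layers and hidden size 4096, square `q_proj` and `v_proj` projections, and FP32 transfer; those shape values are illustrative assumptions, not figures from this section.

```python
def lora_param_count(hidden_size, rank, num_layers, modules_per_layer):
    """Each adapted d x d projection adds A (rank x d) and B (d x rank)."""
    per_module = 2 * rank * hidden_size
    return per_module * modules_per_layer * num_layers


# Assumed shape: 32 layers, hidden size 4096, q_proj and v_proj adapted.
adapter_params = lora_param_count(
    hidden_size=4096, rank=16, num_layers=32, modules_per_layer=2
)
full_params = 7_000_000_000
bytes_per_param = 4  # FP32

adapter_mb = adapter_params * bytes_per_param / 1e6
full_gb = full_params * bytes_per_param / 1e9

print(adapter_params)                       # 8388608
print(round(adapter_mb, 1))                 # 33.6
print(full_gb)                              # 28.0
print(round(full_params / adapter_params))  # 834
```

Under these assumptions a rank-16 adapter is about 8.4M parameters, or roughly 34 MB per round, against 28 GB for the full model.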

import copy
import torch

# apply_lora and initialize_lora are placeholder helpers (not defined here):
# the first attaches an adapter to the frozen base model, the second builds a fresh adapter.

def federated_lora_round(global_adapter, client_datasets, base_model, config):
    """One round of federated LoRA fine-tuning."""
    client_adapters = []
    client_sizes = []
    for dataset in client_datasets:
        # Each client starts from the global adapter
        local_adapter = copy.deepcopy(global_adapter)
        local_model = apply_lora(base_model, local_adapter)
        # Local training (E epochs on client data)
        optimizer = torch.optim.AdamW(local_adapter.parameters(), lr=config.lr)
        for epoch in range(config.local_epochs):
            for batch in dataset:
                loss = local_model(**batch).loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        # Record the trained adapter once per client, after local training finishes
        client_adapters.append(local_adapter)
        client_sizes.append(len(dataset))

    # Server aggregation: weighted average of adapter weights
    total_size = sum(client_sizes)
    client_params = [dict(adapter.named_parameters()) for adapter in client_adapters]
    new_adapter = copy.deepcopy(global_adapter)
    with torch.no_grad():
        for name, param in new_adapter.named_parameters():
            weighted_sum = sum(
                (size / total_size) * params[name]
                for params, size in zip(client_params, client_sizes)
            )
            param.copy_(weighted_sum)
    return new_adapter

# Training loop
global_adapter = initialize_lora(rank=16, target_modules=["q_proj", "v_proj"])
for round_num in range(config.num_rounds):
    global_adapter = federated_lora_round(
        global_adapter, client_datasets, base_model, config
    )
Code Fragment 50.3.1a: Federated LoRA fine-tuning. Each client trains only the adapter matrices locally, and the server aggregates them via weighted averaging. Communication cost per round is proportional to adapter size (e.g., 30 MB for rank-16 on a 7B model), not the full model size.

Heterogeneous Federated Fine-Tuning

In practice, FL clients span a wide hardware range: some run A100 GPUs, others run consumer cards or CPUs. Heterogeneous FL lets each client pick its own adapter configuration (a different LoRA rank, for example) to match its compute budget, and the server then runs rank-adaptive aggregation to combine updates of different shapes.
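One simple way to combine adapters of different ranks is zero-padding: every client's matrix is padded up to the largest rank before averaging. The sketch below illustrates that idea only; it is not necessarily the rank-adaptive scheme this section has in mind. It handles only the LoRA $A$ matrix (shape rank x d) as nested lists, and the function names are hypothetical. The $B$ matrix would be padded with zero columns in the same way.

```python
def pad_rows(matrix, target_rows):
    """Pad a (rank x d) LoRA A matrix with zero rows up to target_rows."""
    width = len(matrix[0])
    extra = [[0.0] * width for _ in range(target_rows - len(matrix))]
    return [row[:] for row in matrix] + extra


def aggregate_padded(client_As, client_sizes):
    """Zero-pad every client's A matrix to the largest rank, then size-weight average."""
    max_rank = max(len(A) for A in client_As)
    width = len(client_As[0][0])
    total = sum(client_sizes)
    out = [[0.0] * width for _ in range(max_rank)]
    for A, n_k in zip(client_As, client_sizes):
        padded = pad_rows(A, max_rank)
        for r in range(max_rank):
            for c in range(width):
                out[r][c] += (n_k / total) * padded[r][c]
    return out


A_small = [[1.0, 1.0]]              # rank-1 client (limited hardware)
A_large = [[3.0, 3.0], [2.0, 2.0]]  # rank-2 client
global_A = aggregate_padded([A_small, A_large], [1, 1])
print(global_A)  # [[2.0, 2.0], [1.0, 1.0]]
```

Note the design tradeoff: rows that only high-rank clients train are diluted by the zeros contributed by low-rank clients, which is one reason more elaborate aggregation rules exist.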

50.3.3 Privacy-Preserving Techniques in Federated LLM Training

Federated learning alone does not guarantee privacy. Gradient updates can be reverse-engineered to reconstruct training data, especially for text. Additional privacy mechanisms are essential. To see why these mechanisms are not optional, it helps to make the threat concrete: the very gradient a well-behaved client uploads can be inverted back into the private example that produced it.

Gradient Inversion: The Threat the Defenses Answer

The warning above ("gradients can leak training data") names an attack without showing it. The attack is gradient inversion, also called gradient leakage, formalized as Deep Leakage from Gradients (DLG) by Zhu et al. (2019). Consider an honest-but-curious server (or any party that observes one client's update). The client computed its gradient on a private example $(x, y)$ by backpropagating a loss $\mathcal{L}(f_\theta(x), y)$ through the shared model $f_\theta$, and uploaded the resulting per-example gradient $g^{*} = \nabla_\theta \mathcal{L}(f_\theta(x), y)$. The attacker knows $\theta$ (it is the global model everyone shares) and observes $g^{*}$, but not $x$ or $y$. The question is whether $(x, y)$ can be recovered from $g^{*}$ alone.

DLG answers yes by treating reconstruction as an optimization problem. The attacker initializes a dummy input $x'$ and dummy label $y'$ from random noise, runs the same forward and backward pass to obtain a dummy gradient $\nabla_\theta \mathcal{L}(f_\theta(x'), y')$, and then descends on $(x', y')$ to make that dummy gradient match the observed one. The objective is the gradient-matching loss:

$$(x'^{\,*}, y'^{\,*}) \;=\; \arg\min_{x',\, y'} \; \bigl\lVert\, \nabla_\theta \mathcal{L}\bigl(f_\theta(x'), y'\bigr) \,-\, g^{*} \,\bigr\rVert^{2}.$$

Crucially the model weights $\theta$ are frozen; the optimization variables are the synthetic data $(x', y')$ themselves. Each step of gradient descent on this objective nudges $x'$ toward an input whose gradient signature reproduces $g^{*}$, and because the gradient is a near-unique fingerprint of the example for a small batch (the regime below), the recovered $x'^{\,*}$ converges to a pixel-level or token-level copy of the client's private $x$. The refinement iDLG (Zhao et al., 2020) removes the need to optimize the label at all: for the standard cross-entropy loss, the sign of the gradient of the final (classification) layer reveals the ground-truth class directly, because only the true-label logit receives a negative gradient. Reading $y$ off the last-layer gradient sign makes the remaining reconstruction of $x$ both faster and more reliable.

Why it works, and when it fails. A gradient is a high-dimensional, smooth function of the input: for a single example the map $x \mapsto \nabla_\theta \mathcal{L}(f_\theta(x), y)$ has so many output coordinates (one per parameter, often millions) constraining so few input coordinates (the pixels or token embeddings of one example) that it is locally near-invertible. The system is heavily over-determined, so matching the gradient pins down the input almost uniquely. This immediately explains the two structural defenses. First, large batches: when a client's update is the average gradient over $B$ examples, the observed vector $\frac{1}{B}\sum_{i} \nabla_\theta \mathcal{L}(f_\theta(x_i), y_i)$ is a sum, and the attacker must disentangle $B$ unknown inputs from one averaged signature, an under-determined problem that degrades rapidly as $B$ grows. Second, aggregation: secure aggregation forces the server to observe only the sum over many clients, pushing the effective batch to thousands and destroying the per-example signal the attack depends on. The threat is therefore sharpest exactly where FL is most naive: per-example or tiny-batch updates sent in the clear.
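For a single linear layer with softmax cross-entropy, the leak can be shown in closed form without any optimization. The sketch below assumes a toy 3-class layer with hand-picked weights; all names and numbers are illustrative. It reads the label from the sign of the bias gradient (the iDLG observation), recovers the input by dividing one row of the weight gradient by the matching bias gradient, and then shows that the same recipe fails on a gradient averaged over two examples.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def layer_gradients(W, b, x, label):
    """Cross-entropy gradients of a linear layer: dL/db = p - onehot, dL/dW = (p - onehot) x^T."""
    logits = [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]
    p = softmax(logits)
    grad_b = [pi - (1.0 if i == label else 0.0) for i, pi in enumerate(p)]
    grad_W = [[gb * xi for xi in x] for gb in grad_b]
    return grad_W, grad_b


def invert(grad_W, grad_b):
    """Pick the most negative bias-gradient row, then divide: x = grad_W[row] / grad_b[row]."""
    row = min(range(len(grad_b)), key=lambda i: grad_b[i])
    return [g / grad_b[row] for g in grad_W[row]], row


def average_gradients(grads):
    """Element-wise mean of (grad_W, grad_b) pairs, as a client would send for a batch."""
    n = len(grads)
    rows, cols = len(grads[0][1]), len(grads[0][0][0])
    avg_W = [[sum(g[0][r][c] for g in grads) / n for c in range(cols)] for r in range(rows)]
    avg_b = [sum(g[1][r] for g in grads) / n for r in range(rows)]
    return avg_W, avg_b


W = [[0.2, -0.1, 0.4], [0.0, 0.3, -0.2], [-0.5, 0.1, 0.1]]
b = [0.0, 0.1, -0.1]
private_x, private_y = [1.5, -2.0, 0.5], 2
other_x, other_y = [-1.0, 0.5, 2.0], 0

# Single example: label and input are recovered exactly (up to float rounding).
g_single = layer_gradients(W, b, private_x, private_y)
rec_x, rec_y = invert(*g_single)
print(rec_y, [round(v, 6) for v in rec_x])  # 2 [1.5, -2.0, 0.5]

# Batch of two: the same recipe returns a blend that matches neither input.
g_batch = average_gradients([g_single, layer_gradients(W, b, other_x, other_y)])
mixed, _ = invert(*g_batch)
```

With one example the gradient rows are exact multiples of the input, so division recovers it. With two examples each row is a sum of two scaled inputs, and a single division can no longer separate them.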

Code Fragment 50.3.3 below implements the gradient-matching loop of the objective above so the leakage is not merely asserted. It is deliberately a skeleton: it assumes a tiny model and a single example, prints the matching loss decreasing as $x'$ converges, and omits the cosine-similarity refinements and total-variation image priors that production attacks (and the Geiping et al. 2020 "Inverting Gradients" follow-up) add for high-resolution recovery.

# Gradient inversion (DLG) skeleton: recover a private input x
# from an observed gradient g_star, with model weights frozen.
# Optimization variable is the DUMMY input x_dummy, not the model.
import torch

def gradient_inversion(model, g_star, input_shape, true_label,
                       steps=300, lr=0.1):
    # Start the dummy input from random noise; label is read off
    # the last-layer gradient sign (iDLG), so we treat it as known.
    x_dummy = torch.randn(input_shape, requires_grad=True)
    optimizer = torch.optim.Adam([x_dummy], lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()

    for step in range(steps):
        optimizer.zero_grad()
        # Forward + backward on the dummy example to get its gradient.
        pred = model(x_dummy)
        dummy_loss = loss_fn(pred, true_label)
        dummy_grad = torch.autograd.grad(
            dummy_loss, model.parameters(), create_graph=True
        )
        # Gradient-matching objective: || dummy_grad - g_star ||^2
        match = sum(((dg - gs) ** 2).sum()
                    for dg, gs in zip(dummy_grad, g_star))
        match.backward()      # backprop THROUGH the gradient, into x_dummy
        optimizer.step()
        if step % 50 == 0:
            print(f"step {step:3d}  match_loss = {match.item():.4f}")
    return x_dummy.detach()

# Representative output, first five lines shown (loss falls as x_dummy approaches the private x):
# step   0  match_loss = 14.8271
# step  50  match_loss =  2.1043
# step 100  match_loss =  0.3389
# step 150  match_loss =  0.0412
# step 200  match_loss =  0.0036
Code Fragment 50.3.3: The DLG gradient-matching loop. The optimizer updates x_dummy (never the model) to minimize match, the squared distance between the dummy gradient and the observed g_star; create_graph=True is what lets the outer .backward() differentiate through the inner gradient. The printed match_loss decreasing toward zero corresponds to x_dummy converging to the client's private input.

With the threat demonstrated, the two defenses already named in this section read as direct countermeasures rather than abstract good practice. Secure aggregation (detailed below) attacks the observability of $g^{*}$: the server never sees an individual gradient, only the masked sum, so there is no per-client signature to invert. DP-FL attacks the information content of $g^{*}$: clipping bounds how much any one example can move the gradient and the added Gaussian noise blurs the fingerprint, so the gradient-matching minimizer no longer lands on the true $x$. The residual risk is real and must be stated plainly: secure aggregation still leaks the aggregate (a sufficiently small cohort or a colluding majority can re-expose individuals), and DP only bounds reconstruction in expectation under its $(\epsilon, \delta)$ budget, so a loose budget still permits partial leakage. This is the privacy-utility tradeoff in concrete form: tighter clipping and larger noise shrink the attacker's reconstruction fidelity but also degrade the gradient signal the global model learns from, which is why, as noted below, the tradeoff bites harder for high-dimensional LLMs than for small classifiers.

Differential Privacy with Federated Learning (DP-FL)

Combining differential privacy with FL adds formal privacy guarantees. Each client clips its update and adds noise before sending it to the server:

  1. Gradient clipping: Bound the L2 norm of each client's update to a maximum value $C$, limiting how much any single client (and therefore any single example it holds) can influence the aggregate
  2. Noise addition: Add calibrated Gaussian noise $\mathcal{N}(0, \sigma^2 C^2 I)$ to the clipped update
  3. Privacy accounting: Track the cumulative privacy budget $(\epsilon, \delta)$ across rounds using the moments accountant or Rényi DP
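Steps 1 and 2 can be sketched as follows. This is a minimal illustration on a flat Python list, with hypothetical function names and toy values for $C$ and $\sigma$; privacy accounting (step 3) is omitted because it depends on the chosen accountant.

```python
import math
import random


def clip_update(update, clip_norm):
    """Step 1: scale the update so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(u * u for u in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [u * scale for u in update]


def add_gaussian_noise(update, clip_norm, sigma, rng):
    """Step 2: add N(0, sigma^2 * C^2) noise independently to every coordinate."""
    std = sigma * clip_norm
    return [u + rng.gauss(0.0, std) for u in update]


rng = random.Random(0)  # fixed seed only so the sketch is repeatable
raw = [3.0, 4.0]        # L2 norm is 5
clipped = clip_update(raw, clip_norm=1.0)
noisy = add_gaussian_noise(clipped, clip_norm=1.0, sigma=0.5, rng=rng)
print([round(c, 6) for c in clipped])  # [0.6, 0.8]
```

Clipping rescales the whole vector rather than truncating individual coordinates, so the direction of the update is preserved and only its magnitude is bounded.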

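The privacy-utility tradeoff is more severe for LLMs than for smaller models. Noise is added to every coordinate, so its total norm grows with the square root of the number of trainable parameters while the clipped signal stays bounded by $C$; the signal-to-noise ratio therefore worsens as dimensionality rises. LLMs are also more sensitive to noisy gradients during fine-tuning.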

Secure Aggregation

Secure aggregation protocols ensure the server can compute the aggregate update without seeing any individual client's contribution. This is achieved through cryptographic techniques (secret sharing, homomorphic encryption) that allow addition over encrypted values. Google's implementation in Gboard uses secure aggregation for federated next-word prediction on mobile devices.

Under the Hood: Secure aggregation

Secure aggregation (Bonawitz et al., 2017) lets a server learn only the SUM of client updates, never an individual one. Each pair of clients agrees on a shared random seed via Diffie-Hellman key exchange and derives a pairwise mask; client i adds the mask for every pair (i,j) where j>i and subtracts it where j<i. When the server sums all masked vectors the paired masks cancel exactly, leaving the true aggregate. A secret-sharing layer (Shamir) lets the protocol recover the masks of clients who drop out mid-round, so stragglers do not corrupt the sum. Combined with DP noise it is what makes Gboard's federated next-word prediction private.
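The mask-cancellation idea can be sketched in isolation. The example below assumes updates are already quantized to integers modulo $2^{32}$ and replaces the Diffie-Hellman agreement with a hard-coded seed per pair; it also omits the Shamir secret-sharing layer for dropouts. All names and values are illustrative, and Python's `random` module is not a cryptographic generator, so this is a demonstration of the arithmetic only.

```python
import random

MODULUS = 2 ** 32  # updates are assumed quantized to integers mod 2^32


def pairwise_mask(seed, length):
    """Expand a shared seed into a mask vector; both clients in a pair derive the same one."""
    rng = random.Random(seed)
    return [rng.randrange(MODULUS) for _ in range(length)]


def mask_update(i, update, pair_seeds, num_clients):
    """Client i adds the pair mask for every j > i and subtracts it for every j < i."""
    masked = list(update)
    for j in range(num_clients):
        if j == i:
            continue
        mask = pairwise_mask(pair_seeds[(min(i, j), max(i, j))], len(update))
        sign = 1 if j > i else -1
        masked = [(m + sign * v) % MODULUS for m, v in zip(masked, mask)]
    return masked


updates = [[5, 1], [7, 2], [11, 3]]  # hypothetical quantized client updates
n = len(updates)
# Stand-in for the key exchange: one shared seed per unordered pair of clients.
pair_seeds = {(i, j): 1000 * i + j for i in range(n) for j in range(i + 1, n)}

masked = [mask_update(i, u, pair_seeds, n) for i, u in enumerate(updates)]
server_sum = [sum(col) % MODULUS for col in zip(*masked)]
print(server_sum)  # [23, 6]
```

Each masked vector looks random on its own, but every mask appears once with a plus sign and once with a minus sign, so the server's sum equals the true total.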

50.3.4 Applications and Use Cases

| Domain | Use Case | Why Federated? |
|---|---|---|
| Healthcare | Clinical NLP across hospital networks | HIPAA prohibits sharing patient records; each hospital trains locally on its EHR data |
| Finance | Fraud detection language models across banks | Regulatory silos prevent data sharing; FL enables collaborative model improvement |
| Mobile / Edge | On-device keyboard prediction | User typing data is highly personal; Google's Gboard pioneered this approach |
| Legal | Contract analysis across law firms | Attorney-client privilege prevents data pooling; FL trains on distributed corpora |
| Multi-national | Global models trained across data sovereignty boundaries | GDPR, China's PIPL, and other regulations restrict cross-border data transfer |
Table 50.1.1: Federated learning use cases for LLMs. The common thread is data that cannot be centralized due to legal, regulatory, or competitive constraints.

50.3.5 Frameworks and Tools

import torch
import flwr as fl
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig

# Client: each participant runs this on their local data.
# train_local is a placeholder for the participant's own training loop (not defined here).
class LLMClient(fl.client.NumPyClient):
    def __init__(self, model_name, local_dataset):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        base = AutoModelForCausalLM.from_pretrained(model_name)
        lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
        self.model = get_peft_model(base, lora_config)
        self.dataset = local_dataset

    def get_parameters(self, config):
        """Return only LoRA adapter parameters."""
        return [
            # detach first: .numpy() fails on tensors that require grad
            val.detach().cpu().numpy()
            for name, val in self.model.named_parameters()
            if "lora" in name
        ]

    def set_parameters(self, parameters):
        """Load adapter weights received from the server, in the same order as get_parameters."""
        lora_params = [
            (name, param) for name, param in self.model.named_parameters()
            if "lora" in name
        ]
        for (name, param), new_val in zip(lora_params, parameters):
            param.data = torch.tensor(new_val).to(param.device)

    def fit(self, parameters, config):
        """Train locally for E epochs, return updated adapter weights."""
        self.set_parameters(parameters)
        train_local(self.model, self.dataset, epochs=config.get("local_epochs", 3))
        return self.get_parameters(config), len(self.dataset), {}

# Server: aggregation strategy (runs in a separate process from the clients)
strategy = fl.server.strategy.FedAvg(
    min_fit_clients=3,
    min_available_clients=5,
    fraction_fit=0.5,  # Sample 50% of available clients per round
)

# Launch federated training
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=20),
    strategy=strategy,
)
Code Fragment 50.3.2: Federated LoRA fine-tuning with the Flower framework. Each client trains only LoRA adapter parameters and exchanges them with the server. The FedAvg strategy aggregates adapter weights across clients.

50.3.6 Challenges and Limitations

Warning: Common Misconception

Federated learning does not guarantee privacy by default. While raw data stays local, the model updates (gradients) shared during federated training can leak information about individual training examples through gradient inversion attacks. Federated learning must be combined with secure aggregation, differential privacy, or both to provide meaningful privacy guarantees. Treating federated learning as inherently private is a dangerous oversimplification.

Key Takeaways
Research Frontier

Federated instruction tuning and federated RLHF are active research areas. FedIT (Zhang et al., 2024) showed that federating the instruction-tuning stage, where multiple organizations contribute instruction-response pairs without sharing them, produces models competitive with centrally trained ones.

Combining FL with preference optimization (Chapter 18) enables collaborative alignment without sharing sensitive preference data.

Self-Check Questions
Why is federated LoRA more practical than federated full fine-tuning for LLMs?

Full fine-tuning requires exchanging all model parameters (e.g., 28 GB for a 7B model in FP32) each communication round, which is prohibitive for bandwidth. LoRA adapters are typically 1-10M parameters (about 4-40 MB in FP32), reducing communication cost by roughly three orders of magnitude while achieving comparable fine-tuning quality. This makes FL practical even over internet connections rather than requiring data center interconnects.

What is the non-IID problem in federated learning, and why does it affect LLMs?

Non-IID (non-independent and identically distributed) data means client datasets have different distributions: a medical institution sees clinical text, a legal firm sees contracts, a bank sees financial reports. When clients optimize on very different data distributions, their gradient updates point in different directions, causing the averaged global model to converge slowly or to a suboptimal point. For LLMs, this is especially acute because language varies enormously across domains, registers, and vocabularies.

Why is federated learning alone insufficient for privacy, and what additional mechanisms are needed?

Gradient updates can leak training data through gradient inversion attacks: an adversary can reconstruct training text from observed gradients. Additional mechanisms needed include differential privacy (adding noise to gradients before sharing) and secure aggregation (cryptographic protocols that prevent the server from seeing individual updates). Gradient compression reduces the information content of updates and can make inversion harder, but it offers no formal guarantee on its own. These mechanisms strengthen privacy at the price of higher computational cost and reduced model quality.

What Comes Next

With privacy attacks, differential privacy, and federated learning covered, the next section turns to Machine Unlearning: techniques for removing specific knowledge from a trained model after the fact, complementing the preventive measures presented here.

Further Reading

Foundational Papers

McMahan, B., Moore, E., Ramage, D., Hampson, S., & Aguera y Arcas, B. (2017). "Communication-Efficient Learning of Deep Networks from Decentralized Data." AISTATS 2017. arXiv:1602.05629. The original FedAvg paper that defined federated learning.
Kairouz, P., McMahan, H. B., et al. (2021). "Advances and Open Problems in Federated Learning." Foundations and Trends in ML. arXiv:1912.04977. Comprehensive survey of the federated-learning research agenda; the standard reference for system designers.
Bonawitz, K., Eichner, H., Grieskamp, W., et al. (2019). "Towards Federated Learning at Scale: System Design." SysML 2019. arXiv:1902.01046. Google's production FL system; the architecture reference for cross-device deployments.

Federated LLM Training

Zhang, Z., Yang, Y., Dai, Y., et al. (2024). "FedLLM: Communication-Efficient Federated Fine-Tuning of LLMs." arXiv:2404.06448. Federated LoRA fine-tuning that minimizes upload bandwidth.
Flower Labs (2024). "Flower: A Friendly Federated Learning Framework." flower.ai. The leading open-source FL framework; supports PyTorch, TensorFlow, and Hugging Face Transformers.

Privacy and Security

Dwork, C., & Roth, A. (2014). The Algorithmic Foundations of Differential Privacy. cis.upenn.edu/~aaroth/privacybook. The standard reference; FedAvg with differential privacy is the deployable baseline.