This page contains information about how to train large models using FSDP.
Setup
xlm-core trains large models with PyTorch's Fully Sharded Data Parallel (FSDP) via Lightning's FSDPStrategy. The configuration is split across three layers so that the parts that change rarely (sharding strategy) are separated from the parts that depend on the model (which class to wrap, which dtypes to use):
- Base strategy config (trainer_strategy/fsdp.yaml, in xlm-core): sharding strategy, CPU offload, use_orig_params, and state_dict_type. Model-agnostic.
- Experiment config: model-specific options (auto_wrap_policy, activation_checkpointing_policy, mixed_precision). Set in your project's configs/experiment/*.yaml (example below).
- CLI / launcher: selects the strategy at run time with trainer.strategy=fsdp (and composes the experiment YAML).
1. The base trainer_strategy/fsdp.yaml
See trainer_strategy/fsdp.yaml:
```yaml
# @package trainer.strategy
# Model-specific options (auto_wrap_policy, mixed_precision, etc.) belong in experiment YAML.
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: FULL_SHARD
cpu_offload: false
use_orig_params: false
state_dict_type: sharded
```
Notes:
- sharding_strategy: FULL_SHARD shards parameters, gradients, and optimizer state across all ranks (ZeRO-3 equivalent). Use SHARD_GRAD_OP (ZeRO-2) or NO_SHARD (DDP) if memory is not the bottleneck.
- state_dict_type: sharded writes one shard per rank instead of consolidating to a single full state dict. This is the only practical option for 7B+ models: a full state dict has to materialize the unsharded weights on rank 0, which is what we are trying to avoid in the first place.
- use_orig_params: false is the default; flip to true only if you need parameter-group-aware optimizers or torch.compile over the wrapped model.
2. The model-specific experiment YAML
Layered on top of the base strategy, the experiment YAML supplies the three knobs FSDP needs to actually shard a specific model:
```yaml
# @package _global_
trainer:
  strategy:
    auto_wrap_policy:
      _target_: xlm.utils.fsdp_grouping.make_layer_wrap_policy
      _args_:
        - my_package.modeling.MyDecoderLayer  # dotted path to your transformer block class
    activation_checkpointing_policy:
      _target_: xlm.utils.fsdp_grouping.make_layer_wrap_policy
      _args_:
        - my_package.modeling.MyDecoderLayer
    mixed_precision:
      _target_: xlm.utils.fsdp_grouping.fsdp_bf16_mixed_precision
```
Walking through each block:
auto_wrap_policy
Tells FSDP which submodule class to treat as a sharding unit. Each instance of the class becomes its own FSDP unit. Its parameters are all-gathered just before its forward (and, under FULL_SHARD, again for its backward), and its gradients are reduce-scattered after the backward. Between uses, each rank holds only its own shard of the unit's parameters. For a transformer, this should be the decoder/encoder block class. Wrapping at the block level, rather than the whole model or individual nn.Linears, is what gives FSDP its memory savings without flooding the network with tiny collectives.
xlm.utils.fsdp_grouping.make_layer_wrap_policy simply imports the dotted class paths you pass and returns the set of classes that Lightning's FSDPStrategy expects. Pass multiple classes if your model mixes block types:
```python
from xlm.utils.fsdp_grouping import make_layer_wrap_policy

policy = make_layer_wrap_policy(
    "my_package.modeling.MyDecoderLayer",
    "my_package.modeling.SomeOtherBlock",
)
```
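Resolving dotted paths to classes takes little more than importlib. The sketch below is a hypothetical stand-in: resolve_layer_classes is not the real xlm-core function, whose source is not shown here. It turns dotted paths into the kind of class set Lightning's FSDPStrategy accepts for auto_wrap_policy. Standard-library classes stand in for transformer blocks so the example runs without a model package.

```python
import importlib


def resolve_layer_classes(*dotted_paths: str) -> set:
    """Hypothetical sketch: import each "package.module.ClassName" path and
    return the set of classes (the form Lightning's FSDPStrategy accepts)."""
    classes = set()
    for path in dotted_paths:
        module_name, _, class_name = path.rpartition(".")
        if not module_name:
            raise ValueError(f"expected a dotted path like 'pkg.mod.Class', got {path!r}")
        module = importlib.import_module(module_name)
        cls = getattr(module, class_name)  # raises AttributeError on a typo in the class name
        if not isinstance(cls, type):
            raise TypeError(f"{path!r} does not name a class")
        classes.add(cls)
    return classes


# Standard-library classes stand in for transformer block classes here.
policy = resolve_layer_classes("collections.OrderedDict", "pathlib.PurePath")
```

Failing loudly at import time is useful here. A policy that silently matches nothing is exactly the "fsdp_units is 0 or 1" symptom described under Diagnostics.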
activation_checkpointing_policy
Selects which submodules get activation checkpointing (recompute activations in the backward pass instead of storing them). For large models at long seq_len, activation memory dominates, so we checkpoint at the same granularity as the FSDP unit. Reusing make_layer_wrap_policy keeps the two policies aligned.
If you set auto_wrap_policy but not activation_checkpointing_policy, you get sharding without recompute — fine for smaller models but typically not enough at 7B.
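A back-of-the-envelope model shows why block-level checkpointing matters at long seq_len. The sketch makes two assumptions. First, each block keeps about acts_per_block tensors of shape [batch, seq_len, hidden] alive for the backward; acts_per_block is a hypothetical, model-dependent factor, and 16 is only a placeholder. Second, attention does not materialize a seq_len × seq_len score matrix. With checkpointing, only each block's input is stored, plus one block's internals while that block is being recomputed.

```python
def activation_bytes(
    num_layers: int,
    batch: int,
    seq_len: int,
    hidden: int,
    bytes_per_value: int = 2,  # bf16 activations
    acts_per_block: int = 16,  # assumed placeholder; depends on the block implementation
    checkpointed: bool = False,
) -> int:
    """Rough activation-memory estimate for a stack of identical blocks."""
    one_tensor = batch * seq_len * hidden * bytes_per_value
    if not checkpointed:
        # Every block keeps all of its intermediates until the backward pass.
        return num_layers * acts_per_block * one_tensor
    # Only each block's input is saved; one block's intermediates are
    # rematerialized at a time during the backward.
    return num_layers * one_tensor + acts_per_block * one_tensor


GiB = 2**30
full = activation_bytes(32, 1, 8192, 4096) / GiB
ckpt = activation_bytes(32, 1, 8192, 4096, checkpointed=True) / GiB
```

Under these illustrative assumptions, the 32-layer example drops from 32 GiB to 3 GiB of activations. The price is roughly one extra forward pass per step.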
mixed_precision
From fsdp_grouping.py:
```python
def fsdp_bf16_mixed_precision():
    """Default FSDP mixed precision: bf16 params, fp32 reductions."""
    import torch
    from torch.distributed.fsdp import MixedPrecision

    return MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )
```
This is FSDP-native mixed precision. Parameters are cast to bf16 for the all-gather and for compute; the local shards themselves stay in their original dtype (typically fp32), and that full-precision copy is what the optimizer updates. The gradient all-reduce / reduce-scatter happens in fp32 to avoid bf16 numerical drift. Buffers (e.g. RoPE caches, attention masks) stay in fp32 because they are read but never reduced. These three dtypes are passed straight to torch.distributed.fsdp.MixedPrecision and govern which dtypes FSDP uses for compute and for collectives.
Trainer(precision="bf16-mixed") is complementary, not redundant, to this block — it controls a different thing. Lightning's FSDPStrategy.mixed_precision_config resolves the FSDP MixedPrecision like this:
```python
@property
def mixed_precision_config(self):
    if self.mixed_precision is not None:
        return self.mixed_precision  # explicit strategy arg WINS
    plugin = self.precision_plugin
    if isinstance(plugin, FSDPPrecision):
        return plugin.mixed_precision_config
    return None
```
So once you set mixed_precision on the strategy, Trainer(precision=...) cannot override the FSDP-internal dtypes — the strategy's bf16/fp32/fp32 config is what FSDP actually applies. What Trainer(precision="bf16-mixed") still does, even when overridden, is wrap the forward in torch.autocast("cuda", dtype=torch.bfloat16) via FSDPPrecision.forward_context. With param_dtype=bf16 already, autocast is mostly a no-op (matmuls run in bf16 either way) but it gives a small amount of op-level numerical protection (e.g. cross-entropy intermediates) and matches DreamOn's reference setup, which wraps its forward in an explicit torch.autocast(bf16) block.
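The param_dtype / reduce_dtype choice also sets how many bytes cross the network each step. The rough sketch below makes two assumptions. First, collectives are ring-style, so each rank receives about (W-1)/W of the full buffer per all-gather or reduce-scatter. Second, FULL_SHARD re-gathers parameters for the backward. Under those assumptions, each parameter costs two all-gathers in param_dtype plus one reduce-scatter in reduce_dtype per step:

```python
def bytes_per_param_per_step(world_size: int, param_bytes: int, reduce_bytes: int) -> float:
    """Approximate bytes each rank receives per parameter per step under FULL_SHARD.

    Assumes ring-style collectives: forward all-gather + backward all-gather
    (in param_dtype) + gradient reduce-scatter (in reduce_dtype).
    """
    ring_factor = (world_size - 1) / world_size
    return ring_factor * (2 * param_bytes + reduce_bytes)


bf16_fp32 = bytes_per_param_per_step(8, param_bytes=2, reduce_bytes=4)  # bf16 params, fp32 reductions
all_fp32 = bytes_per_param_per_step(8, param_bytes=4, reduce_bytes=4)   # no mixed precision
```

With 8 ranks this gives 7 vs 10.5 bytes per parameter per step. The bf16/fp32 default halves all-gather traffic and keeps gradient reductions in full precision.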
For a different precision regime (e.g. fp16 with a loss scaler, or pure bf16 with reduce_dtype=torch.bfloat16), define your own factory and point _target_ at it.
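For example, a "pure bf16" factory with bf16 gradient reductions could live in your own package. The module path my_package/fsdp_precision.py and the function name are hypothetical. The body uses only the torch.distributed.fsdp.MixedPrecision fields shown above, and it keeps buffers in fp32 for the reason given earlier. Reducing in bf16 halves reduce-scatter traffic at the cost of numerical headroom in the gradient sum.

```python
# my_package/fsdp_precision.py  (hypothetical module path)
import torch
from torch.distributed.fsdp import MixedPrecision


def fsdp_pure_bf16_mixed_precision() -> MixedPrecision:
    """bf16 compute and bf16 gradient reductions; buffers stay in fp32."""
    return MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.float32,
    )
```

In the experiment YAML, you would then set trainer.strategy.mixed_precision._target_ to my_package.fsdp_precision.fsdp_pure_bf16_mixed_precision.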
3. CLI invocation
The base lightning_train/config.yaml defaults to trainer_strategy: single_device. To switch on FSDP at launch time, add trainer.strategy=fsdp and compose your FSDP experiment overlay in the experiment list:
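```shell
xlm \
  job_type=train \
  job_name=my_fsdp_run \
  'experiment=[my_experiment,fsdp_args]' \
  per_device_batch_size=1 trainer.devices=8 trainer.num_nodes=1 \
  trainer.strategy=fsdp compile=false \
  ++trainer.precision=bf16-mixed
```

The experiment list is quoted so that shells which treat [...] as a glob pattern (zsh errors out on an unmatched glob) pass it to Hydra unchanged.

You may also need to drop callbacks that don't play well with sharded checkpoints during smoke runs by appending these overrides:

```shell
~callbacks.checkpoint_monitor ~callbacks.on_exception_checkpoint
```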
Three things to call out about this command line:
- experiment=[my_experiment,fsdp_args] is Hydra list-composition: the second entry overlays on the first, so the FSDP strategy block lands inside the existing trainer: config rather than replacing it.
- trainer.strategy=fsdp is the Hydra group override for trainer_strategy. The base YAML uses # @package trainer.strategy, so its keys (sharding_strategy, cpu_offload, …) merge underneath the experiment's _target_: FSDPStrategy and the model-specific keys above.
- ++trainer.precision=bf16-mixed is intentionally added on top of the strategy's explicit mixed_precision block. It does not change FSDP's parameter/reduction/buffer dtypes (the explicit mixed_precision wins; see the section above). Its only effect here is adding the torch.autocast(bf16) wrapper around the forward, mirroring DreamOn's reference setup. Drop it if you want a strict no-autocast forward.
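The overlay behavior can be modeled with plain dicts. This is a simplified model, not Hydra/OmegaConf itself; real OmegaConf merges also handle interpolation, typed configs, and list replacement, which this ignores. The config keys and values below are illustrative.

```python
def overlay(base: dict, override: dict) -> dict:
    """Simplified model of experiment=[base, override]: nested dicts merge
    key by key; any other value in the override replaces the base value."""
    merged = dict(base)  # shallow copy; base is not mutated
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = overlay(merged[key], value)
        else:
            merged[key] = value
    return merged


my_experiment = {"trainer": {"devices": 8, "max_steps": 100_000}}
fsdp_args = {
    "trainer": {
        "strategy": {
            "auto_wrap_policy": {"_target_": "xlm.utils.fsdp_grouping.make_layer_wrap_policy"},
        }
    }
}
config = overlay(my_experiment, fsdp_args)
```

The resulting trainer block keeps devices and max_steps from the first entry and gains strategy from the second. It is not replaced wholesale.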
4. Diagnostics
When FSDP misbehaves, three things go wrong most often: the wrap policy did not match any modules (so the model is not actually sharded), the dtype config did not propagate (so memory is double what you expect), or activation memory dominates a particular phase. The FSDPDiagnosticsCallback covers all three.
```yaml
callbacks:
  fsdp_diagnostics:
    _target_: xlm.utils.fsdp_diagnostics_callback.FSDPDiagnosticsCallback
    num_logged_batches: 3
    log_module_tree_top_k: 5
    log_to_logger: true
```
What it reports (rank 0, prefixed [FSDPDiagnostics …] in the log):
- Resolved strategy settings at setup: sharding_strategy, cpu_offload, use_orig_params, auto_wrap_policy, activation_checkpointing_policy, mixed_precision, state_dict_type, precision_plugin, and precision_plugin.mixed_precision_config. Use this to confirm the YAML actually merged the way you expected.

  You will normally see two MixedPrecision lines here when both mixed_precision (on the strategy) and Trainer(precision=...) are set, and they may print contradictory dtypes (e.g. mixed_precision=MixedPrecision(param_dtype=bf16, reduce_dtype=fp32, ...) and precision_plugin.mixed_precision_config=MixedPrecision(param_dtype=fp32, reduce_dtype=bf16, ...)). This is intentional. The first is what FSDP actually uses; the second is what Lightning would have built from Trainer(precision=...) if you had not set mixed_precision. Per the resolver in FSDPStrategy.mixed_precision_config, the explicit one wins and the plugin's value is shadowed.
- Post-wrap module tree at on_fit_start: number of FullyShardedDataParallel units, number of CheckpointWrappers, sample names, and per-rank local_trainable_param_MiB. If fsdp_units is 1 or 0, your auto_wrap_policy did not match the layer class. Likely causes are a typo, a wrong dotted path, or the layer sitting inside something opaque, such as an nn.Sequential, that gets in the way. Per-rank shard sizes should be roughly total_params × bytes_per_param / world_size. Note the count covers per-layer wrappers only: PyTorch always adds one outer "root" FSDP wrap, so the actual FullyShardedDataParallel instance count is fsdp_units + 1.
- Per-phase peak GPU memory for the first num_logged_batches batches: batch_start, after_backward, before_optimizer_step, batch_end. Forward-heavy peaks point at activations (more aggressive checkpointing or a smaller seq_len). Optimizer-step peaks point at parameter / state shards (consider cpu_offload: true or a smaller per_device_batch_size). On large-vocab models, the after_backward peak is often dominated by the [B, S, V] logits and their fp32 gradient, not by activations.
5. Checkpointing and resuming
FSDP checkpointing is meaningfully different from the usual single-GPU / DDP path, and the xlm-core defaults are tuned for the FSDP variant. The two things to know up front:
- A "checkpoint" with FSDP is not necessarily a single .ckpt file: under state_dict_type: sharded (the default in trainer_strategy/fsdp.yaml) it is a directory.
- Both saving and loading are collective operations: every rank must reach the same point at the same time. A failure on one rank during the save can leave the entire process group in a bad state.
Two state_dict_type modes
| Mode | Filesystem layout | Memory at save time | Use case |
|---|---|---|---|
| sharded (xlm-core default) | Directory containing one __N_M.distcp shard per rank plus a single meta.pt written by rank 0. | Each rank writes its own shard; no rank gathers the full weights. | Training runs, especially anything that cannot fit the full model on one GPU. |
| full | Single .ckpt file (rank 0 gathers everything). | Rank 0 must hold the full unsharded weights + optimizer state in memory. Prohibitive at 7B+ with fp32 master copies. | Final export, hub upload, single-GPU eval. |
A sharded-checkpoint directory looks like this on disk:
```text
checkpoints/last.ckpt/
├── meta.pt          # rank-0 only: trainer/callback state, hyperparameters, global_step, ...
├── __0_0.distcp     # rank 0 model + optimizer shard
├── __1_0.distcp     # rank 1 ...
├── __2_0.distcp
└── ...              (one .distcp per rank)
```
Lightning's heuristic for distinguishing the two formats is exactly this:
```python
def _is_sharded_checkpoint(path):
    return path.is_dir() and (path / "meta.pt").is_file()


def _is_full_checkpoint(path):
    return path.is_file()
```
So last.ckpt may be a file (full) or a directory (sharded) depending on the strategy. Anything that uses os.path.isfile to detect a checkpoint will silently miss sharded ones — this matters for the auto-resume path below.
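A check that accepts both layouts can replace os.path.isfile-style tests. The helper below is a minimal sketch, and checkpoint_exists is hypothetical (it is not part of Lightning or xlm-core). It applies the same meta.pt rule as Lightning's heuristic, and it builds a throwaway sharded directory to show the difference:

```python
import os
import tempfile
from pathlib import Path


def checkpoint_exists(path) -> bool:
    """True for a full checkpoint file or a sharded checkpoint directory
    (a directory containing meta.pt, matching Lightning's heuristic)."""
    p = Path(path)
    return p.is_file() or (p.is_dir() and (p / "meta.pt").is_file())


# Build a fake sharded checkpoint to show where os.path.isfile goes wrong.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "last.ckpt"
    ckpt.mkdir()
    (ckpt / "meta.pt").touch()
    (ckpt / "__0_0.distcp").touch()
    isfile_says = os.path.isfile(ckpt)     # False: last.ckpt is a directory
    helper_says = checkpoint_exists(ckpt)  # True: directory with meta.pt
```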
Saving is a collective; the on-exception save can hang
FSDPStrategy.save_checkpoint calls _distributed_checkpoint_save (sharded) or FSDP.summon_full_params + torch.save (full). Both require every rank to enter the call together. The practical consequences:
- All ranks must successfully complete the validation / training step before the save fires. If one rank errors during the train or validation hook, the others will block on the next collective. The save then either hangs until the NCCL watchdog timeout or throws DistBackendError. The OOM/checkpoint-hang debugging session in this repo hit exactly this failure mode during validation-triggered checkpointing.
- OnExceptionCheckpoint is risky under FSDP. When Lightning's exception path triggers a second trainer.save_checkpoint(...) after a failure, it issues another full set of collectives over a process group that may already be poisoned. The typical symptom is the run flushing one rank-0 INFO line ("Saving checkpoint on exception ...") and then hanging silently until the heartbeat timer fires. Drop callbacks.on_exception_checkpoint for FSDP runs unless you have a specific reason to keep it; you can do this on the CLI:

  ```shell
  ~callbacks.on_exception_checkpoint
  ```
- Sharded ModelCheckpoint with save_top_k > 0 is fine, but expect each save to take longer than the equivalent DDP save, because it is a write barrier across the world. Setting every_n_train_steps low (e.g. every 100 steps at 7B) can dominate wall-clock time.
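Simple arithmetic shows how quickly a blocking save eats into wall-clock time as every_n_train_steps shrinks. The sketch assumes every save stalls all ranks for its full duration. The timings are purely illustrative, not measurements.

```python
def checkpoint_overhead(step_seconds: float, save_seconds: float, every_n_train_steps: int) -> float:
    """Fraction of wall-clock time spent blocked in checkpoint saves,
    assuming each save is a barrier that stalls every rank."""
    training = step_seconds * every_n_train_steps
    return save_seconds / (training + save_seconds)


# Illustrative numbers: 2 s per step, 120 s per sharded save.
every_100 = checkpoint_overhead(2.0, 120.0, 100)    # 0.375 of wall-clock time
every_1000 = checkpoint_overhead(2.0, 120.0, 1000)  # under 6%
```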
Loading happens after FSDP wrap, not before
Two non-obvious flags on FSDPStrategy change the lifecycle of resume-from-checkpoint:
- restore_checkpoint_after_setup = True: Lightning loads the checkpoint after _setup_model has wrapped the module in FSDP. This is the opposite of the standard non-FSDP flow, where weights are loaded into the bare nn.Module first and the strategy then handles distribution. The wrap must therefore succeed before any weights are read; if your auto_wrap_policy is broken, the load also fails. This also means skip_init_weights: true is safe to use with FSDP resume: the random-init step is skipped, and FSDP loads the real weights into already-sharded FlatParameters.
- lightning_restore_optimizer = False: Lightning's normal optimizer-state-restore code path is bypassed under FSDP; FSDP's own optim_state_dict_to_load re-flattens the saved optimizer state into the per-rank FlatParameter layout. The upshot is that a checkpoint saved at world size W can typically be loaded at world size W' (the torch.distributed.checkpoint API handles re-sharding for sharded checkpoints), but the model architecture and FSDP wrap policy must match between save and load.
For sharded checkpoints, load_checkpoint calls _distributed_checkpoint_load (a collective), then loads optimizer state per optimizer via FSDP.optim_state_dict_to_load, then calls torch.load on the rank-0 meta.pt for the trainer/callback metadata. All of that happens after the wrap, so the very first thing it requires is a healthy NCCL process group.
Resuming and extracting weights in xlm-core
Training resume (lightning_train.py) supports the same checkpoint paths Lightning’s trainer accepts:
- A regular .ckpt file, or
- A sharded checkpoint directory (e.g. last.ckpt/ or on_exception.ckpt/) that contains at least one *.distcp shard (see layout above).
Resolution and validation live in xlm.utils.checkpoint_paths: explicit resume_checkpoint_path must exist and be either a file or a distcp shard directory; auto-pickup checks checkpointing_dir/on_exception.ckpt first, then last.ckpt, using the same rules.
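Those rules are sketched below. The function names are hypothetical; they are not the actual API of xlm.utils.checkpoint_paths. Unlike Lightning's meta.pt heuristic, this rule accepts a directory only if it holds at least one *.distcp shard.

```python
import tempfile
from pathlib import Path
from typing import Optional, Union

PathLike = Union[str, Path]


def is_resumable(path: Path) -> bool:
    """A regular checkpoint file, or a directory with at least one *.distcp shard."""
    return path.is_file() or (path.is_dir() and any(path.glob("*.distcp")))


def resolve_resume_path(checkpointing_dir: PathLike, explicit: Optional[PathLike] = None) -> Optional[Path]:
    if explicit is not None:
        candidate = Path(explicit)
        if not is_resumable(candidate):
            raise FileNotFoundError(f"not a checkpoint file or distcp directory: {candidate}")
        return candidate
    # Auto-pickup order: on_exception.ckpt first, then last.ckpt.
    for name in ("on_exception.ckpt", "last.ckpt"):
        candidate = Path(checkpointing_dir) / name
        if is_resumable(candidate):
            return candidate
    return None


with tempfile.TemporaryDirectory() as tmp:
    last = Path(tmp) / "last.ckpt"
    last.mkdir()
    (last / "__0_0.distcp").touch()
    picked_sharded = resolve_resume_path(tmp)
    explicit_ok = resolve_resume_path(tmp, explicit=last)
    # An on_exception.ckpt/ directory without shards is skipped...
    exc = Path(tmp) / "on_exception.ckpt"
    exc.mkdir()
    picked_with_empty_exception_dir = resolve_resume_path(tmp)
    # ...but once it holds a shard, it takes priority over last.ckpt.
    (exc / "__0_0.distcp").touch()
    picked_exception_first = resolve_resume_path(tmp)
```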
Exporting model-only weights (Hub / inference):
Use xlm.utils.consolidate_model_checkpoint.consolidate_model_checkpoint on a sharded checkpoint folder that contains *.distcp shards and meta.pt (the same layout Lightning uses for state_dict_type: sharded). It loads the distributed checkpoint (PyTorch ≥ 2.3), applies Lightning’s checkpoint reformatting, strips model. / _orig_mod. prefixes, and writes model-only weights:
- Single file (default): pass a target path (e.g. …/model.safetensors). Suitable for model_only_checkpoint_path and small checkpoints.
- HF multi-shard safetensors: pass max_shard_size= (e.g. "5GB") and an output directory; produces model.safetensors.index.json plus model-….safetensors shards compatible with download_model_weights / load_model_weights_into_model. For local inference, set model_only_checkpoint_path to the index JSON file path, not just the folder.
push_to_hub still loads the module (e.g. via model_only_checkpoint_path pointing at that .safetensors or index file) and then uploads from memory via PyTorchModelHubMixin as a single model.safetensors; it does not upload a pre-built multi-shard folder. Models above the Hub single-file limit need a multi-shard folder uploaded separately (for example via HfApi.upload_folder after consolidating with max_shard_size).
- When checkpoint_path points at a sharded directory, extract_checkpoint runs consolidate_model_checkpoint to write model-only .safetensors (optional HF multi-shard via post_training.max_shard_size). You must pass apply_ema=false; EMA is not applied on this path. See extract-checkpoint.md.
The recommended pattern at 7B is: train with state_dict_type: sharded, then consolidate offline to safetensors (single or sharded) on a machine with enough CPU RAM for the full checkpoint (Lightning’s loader materializes the full state in memory).
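How much CPU RAM counts as "enough" is simple arithmetic. The sketch below rests on several assumptions, none of them statements about consolidate_model_checkpoint's internals. It assumes fp32 weights in the checkpoint, Adam-style optimizer state in case the loader materializes that too, and bf16 output shards. It also assumes max_shard_size="5GB" means 5 × 10^9 bytes.

```python
import math

GiB = 2**30


def consolidation_ram_gib(num_params: int, param_bytes: int = 4,
                          optimizer_states: int = 0, state_bytes: int = 4) -> float:
    """Rough CPU RAM to hold a fully materialized checkpoint.

    optimizer_states=2 models Adam's two moment buffers per parameter.
    """
    return num_params * (param_bytes + optimizer_states * state_bytes) / GiB


def min_safetensors_shards(num_params: int, param_bytes: int, max_shard_bytes: int) -> int:
    """Lower bound on the number of shard files for a given max_shard_size."""
    return math.ceil(num_params * param_bytes / max_shard_bytes)


weights_only = consolidation_ram_gib(7_000_000_000)                   # about 26 GiB of fp32 weights
with_adam = consolidation_ram_gib(7_000_000_000, optimizer_states=2)  # about 78 GiB with Adam moments
shards = min_safetensors_shards(7_000_000_000, param_bytes=2, max_shard_bytes=5 * 10**9)
```

Real shard counts can be higher than this lower bound, because shards are filled with whole tensors.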
6. Gotchas
A few things that have bitten us in practice:
- Norm clipping and grad-norm logging under FSDP. With FSDPStrategy, Lightning's built-in gradient_clip_algorithm: norm routes through FSDPPrecision, which raises MisconfigurationException, because torch.nn.utils.clip_grad_norm_() is wrong for sharded FlatParameter gradients. Separately, Harness.on_before_optimizer_step uses lightning.pytorch.utilities.grad_norm(self, ...), which only sums gradients visible on that rank. Under FSDP that is a local shard norm, not the true global L2 norm, so W&B can show huge, nearly flat curves that do not match training stability.
Fix: subclass Harness, detect the FSDP root on self.trainer.strategy.model, and override two hooks (configure_gradient_clipping and on_before_optimizer_step). Set lightning_module._target_ to your subclass and trainer.gradient_clip_algorithm: norm.
Example pattern:
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Harness is xlm-core's base LightningModule; import it from xlm-core
# (its module path is not shown on this page). The subclass name is an example.


class FSDPAwareHarness(Harness):
    def configure_gradient_clipping(
        self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None
    ):
        if gradient_clip_val is None:
            return
        root = self.trainer.strategy.model
        if isinstance(root, FSDP):
            # FSDP's own clip computes the total norm across all ranks' shards.
            root.clip_grad_norm_(max_norm=float(gradient_clip_val), norm_type=2.0)
            return
        return super().configure_gradient_clipping(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )

    def on_before_optimizer_step(self, optimizer):
        root = self.trainer.strategy.model
        if not isinstance(root, FSDP):
            return super().on_before_optimizer_step(optimizer)
        # Sum of squared gradient entries over this rank's local shards.
        local_sq = torch.zeros((), device=self.device, dtype=torch.float32)
        for p in root.parameters():
            if p.grad is not None:
                local_sq += p.grad.detach().float().pow(2).sum()
        # Summing across ranks yields the squared global L2 norm.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
        global_norm = local_sq.sqrt()
        self.log(
            "Total gradient (norm)",
            global_norm,
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            sync_dist=False,
            rank_zero_only=True,
            logger=True,
            add_dataloader_idx=False,
        )
```
With these overrides, trainer.gradient_clip_algorithm: norm is the right choice; gradient_clip_algorithm: value remains fine for non-FSDP experiments (e.g. single-GPU debug) that still use base Harness.
- Don't double-register the model on the predictor. In Harness.instantiate_predictor, the model is attached to the predictor via object.__setattr__(self.predictor, "model", self.model) rather than plain =. A normal assignment would register the same nn.Module as a submodule of both the harness and the predictor, and FSDP would walk those FlatParameters twice, roughly doubling GPU memory usage. If you write your own predictor, copy this pattern.
- trainer.precision does not override the strategy's mixed_precision. A common worry is that setting both mixed_precision on FSDPStrategy and Trainer(precision="bf16-mixed") will conflict or double-cast. It does not. The two control different things. The strategy's mixed_precision is what FSDP actually uses for the compute dtype of gathered parameters and for collectives (FSDPStrategy.mixed_precision_config short-circuits to it before consulting the precision plugin). Trainer(precision="bf16-mixed") only adds the torch.autocast(bf16) wrapper around the forward via FSDPPrecision.forward_context. Choosing the Trainer precision:
  - Trainer(precision="bf16-mixed") (default for new FSDP runs in xlm-core): autocast on, matches DreamOn's reference setup, and gives op-level numerical protection (e.g. cross-entropy upcasts intermediates to fp32 inside autocast). Near-no-op cost when params are already bf16.
  - Trainer(precision="32-true") (the Lightning default): no autocast. Slightly more memory-friendly because there are no fp32 intermediates from autocast, but loses the numerical-stability cushion. Pick this only if you have a reason.
- Per-rank logs. FSDPDiagnosticsCallback's on_fit_start log line is per-rank on purpose: every rank prints its own local_trainable_param_MiB so you can spot uneven shards. The setup-time strategy dump and per-phase memory logs are rank-0 only.
- Loading large checkpoints. Use skip_init_weights: true when resuming training or loading model-only weights before trainer.fit, so Lightning does not spend time on a full random init that is immediately overwritten. For dtype, rely on Trainer / strategy precision and FSDP mixed_precision; init_dtype is only read by load_model_for_inference (eval / Hub inference), not lightning_train.py. See the LLM eval notes.
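The arithmetic behind the norm-clipping and grad-norm gotcha (the first item above) can be checked without GPUs. This stdlib-only sketch splits a stand-in gradient vector into equal per-rank shards, roughly the way a FlatParameter gradient is sharded (padding ignored). It then compares the per-shard norms a rank-local grad_norm would report with the global norm obtained by summing squared local norms (the all_reduce SUM) and taking the square root:

```python
import math
import random


def l2(values):
    """Euclidean (L2) norm of a sequence of floats."""
    return math.sqrt(sum(v * v for v in values))


random.seed(0)
full_grad = [random.gauss(0.0, 1.0) for _ in range(1000)]  # stand-in for all gradients
world_size = 4
shard_len = len(full_grad) // world_size
shards = [full_grad[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]

# What a rank-local grad_norm sees: only that rank's shard.
local_norms = [l2(s) for s in shards]
# What the override computes: sum of squared local norms across ranks, then sqrt.
global_norm = math.sqrt(sum(n * n for n in local_norms))
```

Each local norm understates the true norm, and the all-reduced value matches the norm of the unsharded gradient. That is why the override sums squares before taking the square root; averaging or summing the local norms would not give the same result.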