xlm.utils.fsdp_diagnostics_callback

Diagnostics callback for FSDP runs.

Reports, on rank 0:

1. Resolved FSDP strategy settings (sharding, mixed precision, policies).
2. Post-wrap module-tree statistics (FSDP units, CheckpointWrappers, top names).
3. Per-phase peak GPU memory (forward / backward / optimizer) for the first num_logged_batches training steps.

These three signals together let us distinguish OOM root causes:
  • if (1) shows the wrong policy or dtype, the configuration is wrong (e.g. a YAML override did not merge as expected);
  • if (2) shows only one FSDP unit or no CheckpointWrappers, the auto-wrap or activation-checkpointing policy did not fire on the target layer class (so the model is effectively un-sharded or fully materialized);
  • if (3) shows the peak in forward, activation memory dominates; if in optimizer, parameter / state shards dominate (a measurement sketch follows this list).
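The per-phase peaks in (3) can be captured with the standard torch.cuda peak counters, reset at each phase boundary. Below is a minimal sketch of that pattern using Lightning callback hooks; the class name and bookkeeping are illustrative, not the callback's actual implementation, and it assumes automatic optimization without gradient accumulation.

```python
# Minimal sketch (illustrative names): per-phase peak GPU memory around
# forward / backward / optimizer, using torch.cuda peak-memory counters.
import torch
from lightning.pytorch.callbacks import Callback


class PhasePeakMemorySketch(Callback):
    def __init__(self, num_logged_batches: int = 3):
        self.num_logged_batches = num_logged_batches
        self._measuring = False
        self._peaks = {}  # phase name -> peak bytes, for the current batch

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        self._measuring = torch.cuda.is_available() and batch_idx < self.num_logged_batches
        if self._measuring:
            torch.cuda.reset_peak_memory_stats()  # forward phase starts here

    def on_before_backward(self, trainer, pl_module, loss):
        if self._measuring:
            self._peaks["forward"] = torch.cuda.max_memory_allocated()
            torch.cuda.reset_peak_memory_stats()  # backward phase starts here

    def on_after_backward(self, trainer, pl_module):
        if self._measuring:
            self._peaks["backward"] = torch.cuda.max_memory_allocated()
            torch.cuda.reset_peak_memory_stats()  # optimizer phase starts here

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self._measuring:
            return
        # With automatic optimization, optimizer.step() has run by this point.
        self._peaks["optimizer"] = torch.cuda.max_memory_allocated()
        if trainer.is_global_zero:  # report on rank 0 only
            summary = ", ".join(
                f"{phase}: {peak / 2**20:.0f} MiB" for phase, peak in self._peaks.items()
            )
            print(f"[batch {batch_idx}] peak GPU memory -> {summary}")
```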

FSDPDiagnosticsCallback

Bases: Callback

Lightning Callback that surfaces FSDP wrap and per-phase memory stats.
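The wrap statistics amount to walking the post-wrap module tree and counting FSDP units and activation-checkpoint wrappers. A sketch of that count is below; the helper name and returned keys are illustrative, not part of this module's API.

```python
# Illustrative sketch: count FSDP units and CheckpointWrappers after wrapping.
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def summarize_wrapping(model: nn.Module, top_k: int = 5) -> dict:
    fsdp_units, ckpt_wrappers = [], []
    for name, module in model.named_modules():
        if isinstance(module, FSDP):
            fsdp_units.append(name)
        if isinstance(module, CheckpointWrapper):
            ckpt_wrappers.append(name)
    # One FSDP unit or zero CheckpointWrappers usually means the wrap /
    # checkpointing policy never matched the intended layer class.
    return {
        "num_fsdp_units": len(fsdp_units),
        "num_checkpoint_wrappers": len(ckpt_wrappers),
        "sample_fsdp_units": fsdp_units[:top_k],
        "sample_checkpoint_wrappers": ckpt_wrappers[:top_k],
    }
```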

__init__(num_logged_batches=3, log_module_tree_top_k=5, log_to_logger=True)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_logged_batches | int | How many of the first training batches to instrument with peak-memory measurements. Set to 0 to disable per-batch logging. | 3 |
| log_module_tree_top_k | int | How many sample module names of each kind (FSDP unit, CheckpointWrapper) to print. | 5 |
| log_to_logger | bool | If True and a Trainer logger is configured, also push memory metrics through trainer.logger.log_metrics. | True |
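
A typical way to attach the callback to a Lightning Trainer is sketched below; the FSDP strategy configuration and model are placeholders, and only the callback arguments mirror the documented defaults.

```python
# Usage sketch: register FSDPDiagnosticsCallback on a Trainer running FSDP.
import lightning.pytorch as pl
from lightning.pytorch.strategies import FSDPStrategy

from xlm.utils.fsdp_diagnostics_callback import FSDPDiagnosticsCallback

callback = FSDPDiagnosticsCallback(
    num_logged_batches=3,      # instrument the first 3 training batches
    log_module_tree_top_k=5,   # print up to 5 sample module names per kind
    log_to_logger=True,        # also push metrics via trainer.logger.log_metrics
)

trainer = pl.Trainer(
    strategy=FSDPStrategy(),   # placeholder FSDP config; use your real policies
    devices=4,
    callbacks=[callback],
)
# trainer.fit(model, datamodule=datamodule)  # model / datamodule are user-provided
```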