# xlm.utils.fsdp_diagnostics_callback
Diagnostics callback for FSDP runs.
Reports, on rank 0:

1. Resolved FSDP strategy settings (sharding, mixed precision, policies).
2. Post-wrap module-tree statistics (FSDP units, CheckpointWrappers, top names); see the sketch after this list.
3. Per-phase peak GPU memory (forward / backward / optimizer) for the first `num_logged_batches` training steps.
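For illustration, the module-tree statistics in (2) can be gathered roughly as below. `summarize_wrapped_modules` is a hypothetical helper, not part of this module's API, and the `CheckpointWrapper` import is a private PyTorch path:

```python
# A minimal sketch of how signal (2) could be gathered after wrapping.
# `summarize_wrapped_modules` is a hypothetical helper, not part of this
# module's API; the CheckpointWrapper import is a private PyTorch path.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
)


def summarize_wrapped_modules(model: nn.Module, top_k: int = 5) -> dict:
    """Count FSDP units and CheckpointWrappers in the post-wrap module tree."""
    fsdp_units = [name for name, m in model.named_modules() if isinstance(m, FSDP)]
    ckpt_wrapped = [
        name for name, m in model.named_modules() if isinstance(m, CheckpointWrapper)
    ]
    return {
        "num_fsdp_units": len(fsdp_units),
        "num_checkpoint_wrappers": len(ckpt_wrapped),
        "sample_fsdp_units": fsdp_units[:top_k],  # sample names to print
        "sample_checkpoint_wrappers": ckpt_wrapped[:top_k],
    }
```

A healthy run typically shows one FSDP unit per wrapped layer instance plus the root; a count of one with zero CheckpointWrappers is the failure mode described in (2).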
These three signals together let us distinguish OOM root causes:

- if (1) shows the wrong policy or dtype, the config is wrong (e.g. the YAML did not merge correctly);
- if (2) shows only one FSDP unit or no CheckpointWrappers, the auto-wrap or activation-checkpointing policy did not fire on the target layer class (so the model is effectively un-sharded or fully materialized);
- if (3) shows the peak in forward, activation memory dominates; if in optimizer, parameter and optimizer-state shards dominate (a measurement sketch follows this list).
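As a rough illustration of how the per-phase peaks in (3) can be measured, here is a minimal standalone sketch using standard Lightning 2.x hooks and `torch.cuda` peak-memory counters. The hook placement and the `PhasePeakMemory` name are assumptions, not the actual internals of `FSDPDiagnosticsCallback`:

```python
# A minimal standalone sketch of signal (3), assuming a CUDA device and
# Lightning 2.x; hook placement and the `PhasePeakMemory` name are
# illustrative, not the actual internals of FSDPDiagnosticsCallback.
import torch
from lightning.pytorch.callbacks import Callback


class PhasePeakMemory(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        torch.cuda.reset_peak_memory_stats()

    def on_before_backward(self, trainer, pl_module, loss):
        # Peak since batch start covers the forward pass.
        self.fwd_peak = torch.cuda.max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        # Peak since the last reset covers the backward pass.
        self.bwd_peak = torch.cuda.max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Remaining peak covers the optimizer step. The real callback would
        # restrict output to rank 0 and to the first few batches.
        opt_peak = torch.cuda.max_memory_allocated()
        gib = 2**30
        print(
            f"batch {batch_idx}: fwd={self.fwd_peak / gib:.2f} GiB, "
            f"bwd={self.bwd_peak / gib:.2f} GiB, opt={opt_peak / gib:.2f} GiB"
        )
```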
## FSDPDiagnosticsCallback

Bases: `Callback`
Lightning Callback that surfaces FSDP wrap and per-phase memory stats.
### `__init__(num_logged_batches=3, log_module_tree_top_k=5, log_to_logger=True)`
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_logged_batches` | `int` | How many of the first training batches to instrument with peak-memory measurements. Set to `0` to disable per-batch logging. | `3` |
| `log_module_tree_top_k` | `int` | How many sample module names of each kind (FSDP unit, CheckpointWrapper) to print. | `5` |
| `log_to_logger` | `bool` | If `True` and a Trainer logger is configured, also push memory metrics through it. | `True` |
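The callback plugs into a Lightning `Trainer` like any other. A minimal usage sketch, assuming Lightning 2.x imports and an FSDP strategy (the trainer settings shown are illustrative):

```python
# A hypothetical usage sketch; the Trainer configuration is illustrative.
from lightning.pytorch import Trainer
from xlm.utils.fsdp_diagnostics_callback import FSDPDiagnosticsCallback

diag = FSDPDiagnosticsCallback(
    num_logged_batches=3,      # instrument the first three training steps
    log_module_tree_top_k=5,   # print up to five sample names per wrapper kind
    log_to_logger=True,        # also forward memory metrics to the Trainer logger
)

trainer = Trainer(strategy="fsdp", accelerator="gpu", devices=2, callbacks=[diag])
```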