xlm.utils.model_state_dict

Extract plain model weights from Lightning module state_dict payloads.

extract_model_only_from_lightning_state_dict(state_dict)

Strip model. / _orig_mod. prefixes from a LightningModule state_dict.

Keeps only tensor values suitable for safetensors / load_state_dict.
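The docstring does not show the implementation. The following is one plausible reading of it, not the module's actual code. It assumes that keys look like model.<name>, possibly with an _orig_mod. segment added by torch.compile. It also assumes that non-tensor values are dropped and that keys without a model. prefix are kept unchanged. The names extract_model_only_sketch and strip_lightning_key, and the is_tensor parameter, are hypothetical.

```python
from typing import Any, Callable, Dict, Mapping, Optional

MODEL_PREFIX = "model."        # assumed LightningModule attribute name
COMPILE_PREFIX = "_orig_mod."  # segment inserted by torch.compile's wrapper module


def strip_lightning_key(key: str) -> str:
    """Hypothetical helper: drop '_orig_mod.' segments, then a leading 'model.'."""
    key = key.replace(COMPILE_PREFIX, "")
    if key.startswith(MODEL_PREFIX):
        key = key[len(MODEL_PREFIX):]
    return key


def extract_model_only_sketch(
    state_dict: Mapping[str, Any],
    is_tensor: Optional[Callable[[Any], bool]] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for extract_model_only_from_lightning_state_dict."""
    if is_tensor is None:
        import torch  # assumed installed; the real module targets PyTorch
        is_tensor = torch.is_tensor
    out: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if not is_tensor(value):
            continue  # safetensors stores tensors only, so skip anything else
        out[strip_lightning_key(key)] = value
    return out
```

Removing _orig_mod. anywhere in the key, rather than only at the front, also covers the case where torch.compile wrapped an inner submodule instead of the top-level model. With PyTorch installed, the default is_tensor is torch.is_tensor, so a LightningModule's state_dict() can be passed directly.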

tensor_state_dict_from_checkpoint_dict(checkpoint)

Load a consolidated Lightning torch.save dict and return model-only tensors.

Expects a top-level state_dict key (standard Lightning checkpoint after consolidation).
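A similarly hedged sketch of the checkpoint-level helper follows. It assumes that the input is the dict produced by torch.load on a consolidated Lightning checkpoint and that a missing state_dict key is an error. The real function's error handling is not documented here. The key-stripping rule is an assumed reading of the model. / _orig_mod. description above. The names tensors_from_checkpoint_sketch and is_tensor are hypothetical.

```python
from typing import Any, Callable, Dict, Mapping, Optional


def _strip_key(key: str) -> str:
    # Assumed rule: drop torch.compile's '_orig_mod.' segments, then a leading 'model.'.
    key = key.replace("_orig_mod.", "")
    return key[len("model."):] if key.startswith("model.") else key


def tensors_from_checkpoint_sketch(
    checkpoint: Mapping[str, Any],
    is_tensor: Optional[Callable[[Any], bool]] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in for tensor_state_dict_from_checkpoint_dict."""
    if "state_dict" not in checkpoint:
        # Other top-level entries (epoch, optimizer states, ...) are ignored.
        raise KeyError("expected a top-level 'state_dict' key (consolidated Lightning checkpoint)")
    if is_tensor is None:
        import torch  # assumed installed
        is_tensor = torch.is_tensor
    return {
        _strip_key(k): v
        for k, v in checkpoint["state_dict"].items()
        if is_tensor(v)
    }
```

On a real file this would typically be called as tensors_from_checkpoint_sketch(torch.load(path, map_location="cpu")). Recent PyTorch releases default torch.load to weights_only=True, which can reject checkpoints that pickle extra Lightning metadata. Pass weights_only=False only for files you trust. The result can then be written with safetensors.torch.save_file(tensors, "model.safetensors") or passed to a plain nn.Module's load_state_dict.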