
xlm.utils.ema

EMACallback

Bases: Callback

__init__(decay, use_num_updates=True, apply_ema_at_train_end=True)

decay: The exponential decay rate.
use_num_updates: Whether to take the number of updates performed so far into account when computing the averages.
apply_ema_at_train_end: If True, applies the EMA weights to the model at the end of training, so the final checkpoint contains the EMA-averaged weights in the model parameters.
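Below is a minimal, framework-free sketch of the update rule these parameters control. It assumes the warm-up convention used by the torch_ema package, where use_num_updates=True caps the decay at (1 + n) / (10 + n) on the n-th update. Whether xlm.utils.ema uses this exact formula is an assumption. The function names effective_decay and ema_update are hypothetical, and plain Python floats stand in for tensors.

```python
from typing import Dict


def effective_decay(decay: float, num_updates: int, use_num_updates: bool = True) -> float:
    """Decay applied on the num_updates-th update (1-based).

    Assumption: torch_ema-style warm-up, so early updates use a smaller
    decay and the average is not dominated by the initial weights.
    """
    if use_num_updates:
        return min(decay, (1 + num_updates) / (10 + num_updates))
    return decay


def ema_update(shadow: Dict[str, float], params: Dict[str, float], decay: float) -> None:
    """In-place EMA step: shadow <- decay * shadow + (1 - decay) * param."""
    for name, value in params.items():
        shadow[name] -= (1.0 - decay) * (shadow[name] - value)


if __name__ == "__main__":
    shadow = {"w": 0.0}       # EMA copy, initialised from the starting weight
    num_updates = 0
    for _ in range(5):
        params = {"w": 1.0}   # pretend the optimizer keeps the weight at 1.0
        num_updates += 1      # counter is incremented before computing the decay
        ema_update(shadow, params, effective_decay(0.999, num_updates))
    print(num_updates, shadow["w"])
```

With decay=0.999 and no warm-up, the shadow weight in this sketch would still sit near its starting value after five steps (1 - 0.999^5 is about 0.005). The warm-up lets the average track the trained weights much sooner.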

on_train_end(trainer, pl_module)

Apply EMA weights to the model at the end of training.

This makes the final checkpoint contain EMA-averaged weights directly in the model parameters, eliminating the need for manual EMA application during checkpoint extraction or inference.

Note: After this is called, the model's parameters hold the EMA-averaged weights and the original training weights are overwritten. A copy of them survives in collected_params only if store() was called beforehand.
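To make the note concrete, here is a hypothetical, framework-free sketch of what applying EMA weights at the end of training does to the parameters and to collected_params. ToyEMA, its update/copy_to/restore methods, and on_train_end_sketch are illustrative names modelled on the torch_ema convention, not the xlm API; only store() and collected_params come from the text above.

```python
from typing import Dict, Optional


class ToyEMA:
    """Hypothetical stand-in for an EMA helper; floats replace tensors."""

    def __init__(self, params: Dict[str, float], decay: float) -> None:
        self.decay = decay
        # Shadow (EMA) copy, initialised from the current weights.
        self.shadow: Dict[str, float] = dict(params)
        # Backup of the raw training weights, filled only by store().
        self.collected_params: Optional[Dict[str, float]] = None

    def update(self, params: Dict[str, float]) -> None:
        # shadow <- decay * shadow + (1 - decay) * param
        for name, value in params.items():
            self.shadow[name] -= (1.0 - self.decay) * (self.shadow[name] - value)

    def store(self, params: Dict[str, float]) -> None:
        # Keep a copy of the current (training) weights.
        self.collected_params = dict(params)

    def copy_to(self, params: Dict[str, float]) -> None:
        # Overwrite the model weights with the EMA weights, in place.
        for name, value in self.shadow.items():
            params[name] = value

    def restore(self, params: Dict[str, float]) -> None:
        # Put the stored training weights back; requires a prior store().
        if self.collected_params is None:
            raise RuntimeError("restore() called without a prior store()")
        params.update(self.collected_params)


def on_train_end_sketch(ema: ToyEMA, params: Dict[str, float], keep_copy: bool) -> None:
    """Hypothetical analogue of applying EMA weights when training ends."""
    if keep_copy:
        ema.store(params)  # without this, the training weights are simply overwritten
    ema.copy_to(params)


if __name__ == "__main__":
    model = {"w": 1.0}
    ema = ToyEMA(model, decay=0.5)
    model["w"] = 3.0          # pretend an optimizer step moved the weight
    ema.update(model)         # shadow: 0.5 * 1.0 + 0.5 * 3.0 = 2.0
    on_train_end_sketch(ema, model, keep_copy=True)
    print(model, ema.collected_params)  # {'w': 2.0} {'w': 3.0}
    ema.restore(model)
    print(model)                        # {'w': 3.0}
```

In the sketch, restore() can bring the raw training weights back only because store() ran first; with keep_copy=False, collected_params stays None and restore() raises.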