xlm.utils.ema
EMACallback
Bases: Callback
__init__(decay, use_num_updates=True, apply_ema_at_train_end=True)
decay: The exponential decay rate used for the moving average.
use_num_updates: Whether to use the number of updates when computing averages.
apply_ema_at_train_end: If True, applies EMA weights to the model at the end of training, so the final checkpoint contains EMA-averaged weights in the model parameters.
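The update rule behind these parameters can be sketched in plain Python. This is a minimal illustration, assuming the conventional EMA formulation (as in torch_ema / torch.optim.swa_utils): the shadow weights track `decay * shadow + (1 - decay) * param`, and `use_num_updates` warms up the effective decay early in training so the average is not dominated by the random initial weights. The helper names here are illustrative, not part of the xlm API.

```python
def effective_decay(decay, num_updates=None):
    """Warmed-up decay (assumption: the common min(decay, (1+n)/(10+n)) schedule
    used when use_num_updates=True). With num_updates=None, decay is used as-is."""
    if num_updates is None:
        return decay
    return min(decay, (1 + num_updates) / (10 + num_updates))


def ema_update(shadow, params, decay):
    """One EMA step per element: shadow <- decay * shadow + (1 - decay) * param."""
    return [decay * s + (1 - decay) * p for s, p in zip(shadow, params)]


# Example: early steps use a much smaller effective decay than the nominal 0.999,
# so the shadow weights move quickly toward the live parameters at first.
shadow = [0.0]
for step in range(1, 4):
    d = effective_decay(0.999, num_updates=step)
    shadow = ema_update(shadow, [1.0], d)
```

With `use_num_updates=False` (i.e. `num_updates=None` above), a decay of 0.999 would leave the shadow weights almost unchanged for the first few steps; the warm-up schedule avoids that cold-start bias.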
on_train_end(trainer, pl_module)
Apply EMA weights to the model at the end of training.
This makes the final checkpoint contain EMA-averaged weights directly in the model parameters, eliminating the need for manual EMA application during checkpoint extraction or inference.
Note: After this is called, the model's parameters contain EMA-averaged weights, and the original training weights are lost from the model itself (though they remain available in collected_params if store() was called beforehand).
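The store()/collected_params interplay mentioned above can be illustrated with a standalone sketch. This is a hypothetical minimal class following the torch_ema-style naming the note uses (store, collected_params, copy_to); it is not the xlm implementation, just a demonstration of how the original weights can be preserved before the EMA weights overwrite them at train end.

```python
import copy


class TinyEMA:
    """Illustrative EMA holder (assumed API, modeled on torch_ema naming)."""

    def __init__(self, params, decay):
        self.decay = decay
        self.shadow = [copy.deepcopy(p) for p in params]  # EMA-averaged weights
        self.collected_params = None  # filled by store()

    def update(self, params):
        # Standard EMA step: shadow <- decay * shadow + (1 - decay) * param.
        self.shadow = [self.decay * s + (1 - self.decay) * p
                       for s, p in zip(self.shadow, params)]

    def store(self, params):
        # Save a copy of the current training weights so they survive copy_to().
        self.collected_params = [copy.deepcopy(p) for p in params]

    def copy_to(self):
        # What on_train_end effectively does: hand back the EMA weights
        # to be written into the model parameters.
        return list(self.shadow)

    def restore(self):
        # Recover the training weights saved by store(); without a prior
        # store() call, the originals are gone.
        if self.collected_params is None:
            raise RuntimeError("store() was never called")
        return list(self.collected_params)
```

Usage mirrors the note: calling `store(params)` before applying `copy_to()` is what keeps the pre-EMA training weights recoverable via `restore()`.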