xlm.utils.model_loading
Unified model loading for inference across commands.
This module provides a single, consistent interface for loading models for inference tasks (generation, evaluation, push to hub, demos).
`load_model_for_inference(cfg, datamodule, tokenizer, *, config_prefix, manual_ema_restore=False, move_to_device=None, set_eval_mode=False, enable_hub_support=True, checkpoint_fallback_dir=None, allow_random_init=False)`
Load and prepare a model for inference tasks.
This function provides a unified interface for loading models across different commands (generate, eval, push_to_hub, cli_demo). It supports loading from:

- Full Lightning checkpoints (including optimizer state, etc.)
- Model-only checkpoints (just model weights)
- Hugging Face Hub repositories
- Fallback to `best.ckpt` or `last.ckpt`
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `DictConfig` | Hydra config containing model and checkpoint configuration. | *required* |
| `datamodule` | `Any` | Datamodule instance for model instantiation. | *required* |
| `tokenizer` | `Any` | Tokenizer instance for model instantiation. | *required* |
| `config_prefix` | `str` | Prefix for config keys: `"generation"` looks for `cfg.generation.ckpt_path`, `"eval"` looks for `cfg.eval.checkpoint_path`, and `""` (empty) looks for the top-level `cfg.hub_checkpoint_path`. | *required* |
| `manual_ema_restore` | `bool` | If True, pass `manual_ema_restore=True` to model loading. Use when you need to manually control EMA weight restoration. | `False` |
| `move_to_device` | `Optional[str]` | Device to move the model to (`"cuda"`, `"cpu"`, or None). If None, the model stays on the default device (the trainer handles placement for eval). | `None` |
| `set_eval_mode` | `bool` | If True, call `model.eval()` after loading. Set to False when using a Lightning Trainer (the trainer handles this). | `False` |
| `enable_hub_support` | `bool` | If True, support loading from `hub.repo_id`. Set to False for commands with different hub key structures. | `True` |
| `checkpoint_fallback_dir` | `Optional[str]` | Directory to search for `best.ckpt`/`last.ckpt` when no explicit checkpoint is provided. Used by the eval command. | `None` |
| `allow_random_init` | `bool` | If True, allow instantiating a randomly initialized model when no checkpoint is found. Defaults to False for safety. | `False` |
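The `config_prefix` lookup described in the table above can be sketched as follows. This is a minimal illustration, not the module's implementation; the key names are taken from the parameter description, and the helper name `resolve_ckpt_key` is hypothetical:

```python
def resolve_ckpt_key(config_prefix: str) -> str:
    """Illustrative sketch: map a config_prefix to the dotted config key
    where the checkpoint path is looked up (key names assumed from the
    parameter description; this helper is not part of the module)."""
    if config_prefix == "generation":
        return "generation.ckpt_path"
    if config_prefix == "eval":
        return "eval.checkpoint_path"
    # Empty prefix: top-level key, as used by push_to_hub.
    return "hub_checkpoint_path"
```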
Returns:
| Type | Description |
|---|---|
| `tuple[Harness, Optional[str]]` | Tuple of `(lightning_module, checkpoint_path)`: the loaded `Harness` module and the resolved checkpoint path (`Optional[str]`). |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If no checkpoint is found and `allow_random_init=False`. |
| `ValueError` | If the checkpoint file doesn't exist. |
Examples:
Generation: Load from checkpoint with HF Hub support
>>> module, _ = load_model_for_inference(
... cfg, datamodule, tokenizer,
... config_prefix="generation",
... manual_ema_restore=True,
... move_to_device="cuda",
... set_eval_mode=True,
... enable_hub_support=True,
... )
Evaluation: Return checkpoint path for trainer
>>> module, ckpt_path = load_model_for_inference(
... cfg, datamodule, tokenizer,
... config_prefix="eval",
... checkpoint_fallback_dir=cfg.checkpointing_dir,
... )
>>> trainer.validate(module, datamodule=datamodule, ckpt_path=ckpt_path)
Push to Hub: Top-level config keys
>>> module, _ = load_model_for_inference(
... cfg, datamodule, tokenizer,
... config_prefix="",
... manual_ema_restore=True,
... move_to_device="cuda",
... set_eval_mode=True,
... enable_hub_support=False,
... )
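The `best.ckpt`/`last.ckpt` fallback used with `checkpoint_fallback_dir` in the eval example can be sketched as below. The search order (`best.ckpt` preferred over `last.ckpt`) is an assumption based on the parameter description, and `find_fallback_checkpoint` is an illustrative helper, not part of the module:

```python
from pathlib import Path
from typing import Optional

def find_fallback_checkpoint(fallback_dir: str) -> Optional[str]:
    """Illustrative sketch of the fallback search: return the first of
    best.ckpt / last.ckpt that exists in fallback_dir, else None
    (order assumed; not the module's actual implementation)."""
    for name in ("best.ckpt", "last.ckpt"):
        candidate = Path(fallback_dir) / name
        if candidate.is_file():
            return str(candidate)
    return None
```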