xlm.utils.model_loading

Unified model loading for inference across commands.

This module provides a single, consistent interface for loading models for inference tasks (generation, evaluation, push to hub, demos).

load_model_for_inference(cfg, datamodule, tokenizer, *, config_prefix, manual_ema_restore=False, move_to_device=None, set_eval_mode=False, enable_hub_support=True, checkpoint_fallback_dir=None, allow_random_init=False)

Load and prepare a model for inference tasks.

This function provides a unified interface for loading models across different commands (generate, eval, push_to_hub, cli_demo). It supports loading from:

- Full Lightning checkpoints (includes optimizer state, etc.)
- Model-only checkpoints (just model weights)
- Hugging Face Hub repositories
- Fallback to best.ckpt or last.ckpt

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cfg` | `DictConfig` | Hydra config containing model and checkpoint configuration. | *required* |
| `datamodule` | `Any` | Datamodule instance for model instantiation. | *required* |
| `tokenizer` | `Any` | Tokenizer instance for model instantiation. | *required* |
| `config_prefix` | `str` | Prefix for config keys: `"generation"` looks for `cfg.generation.ckpt_path`; `"eval"` looks for `cfg.eval.checkpoint_path`; `""` (empty) looks for the top-level `cfg.hub_checkpoint_path`. | *required* |
| `manual_ema_restore` | `bool` | If True, pass `manual_ema_restore=True` to model loading. Used when you need to manually control EMA weight restoration. | `False` |
| `move_to_device` | `Optional[str]` | Device to move the model to (`"cuda"`, `"cpu"`, or `None`). If `None`, the model stays on the default device (the trainer handles this for eval). | `None` |
| `set_eval_mode` | `bool` | If True, call `model.eval()` after loading. Set to False when using a Lightning Trainer (the trainer handles this). | `False` |
| `enable_hub_support` | `bool` | If True, support loading from `hub.repo_id`. Set to False for commands with different hub key structures. | `True` |
| `checkpoint_fallback_dir` | `Optional[str]` | Directory to search for `best.ckpt`/`last.ckpt` if no explicit checkpoint is provided. Used by the eval command. | `None` |
| `allow_random_init` | `bool` | If True, allow instantiating a randomly initialized model when no checkpoint is found. Defaults to False for safety. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `tuple[Harness, Optional[str]]` | Tuple of `(lightning_module, checkpoint_path)`, where `lightning_module` is the loaded Harness model ready for inference, and `checkpoint_path` is the path to the full checkpoint (or `None` if using a model-only checkpoint or random init); used by eval to pass to the trainer. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If no checkpoint is found and `allow_random_init=False`. |
| `ValueError` | If the checkpoint file doesn't exist. |

Examples:

Generation: Load from checkpoint with HF Hub support

>>> module, _ = load_model_for_inference(
...     cfg, datamodule, tokenizer,
...     config_prefix="generation",
...     manual_ema_restore=True,
...     move_to_device="cuda",
...     set_eval_mode=True,
...     enable_hub_support=True,
... )

Evaluation: Return checkpoint path for trainer

>>> module, ckpt_path = load_model_for_inference(
...     cfg, datamodule, tokenizer,
...     config_prefix="eval",
...     checkpoint_fallback_dir=cfg.checkpointing_dir,
... )
>>> trainer.validate(module, datamodule, ckpt_path=ckpt_path)

Push to Hub: Top-level config keys

>>> module, _ = load_model_for_inference(
...     cfg, datamodule, tokenizer,
...     config_prefix="",
...     manual_ema_restore=True,
...     move_to_device="cuda",
...     set_eval_mode=True,
...     enable_hub_support=False,
... )