
xlm.utils.hf_hub

Utilities for Hugging Face Hub integration.

repo_id_from_hf_path(path)

Extract repo_id from HF Hub path (URL or org/repo). Returns None if invalid.

download_model_weights(repo_id, revision='main', token=None)

Download model weights from Hugging Face Hub.

Tries `model.safetensors`, then sharded safetensors (`model.safetensors.index.json` plus shard files), then `pytorch_model.bin`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `repo_id` | `str` | Hugging Face repository ID (e.g., `"org/model"`). | *required* |
| `revision` | `str` | Git revision (branch, tag, or commit). | `'main'` |
| `token` | `str \| None` | HF token for private repos. Falls back to the `HF_HUB_KEY` environment variable if `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `str` | Path to the downloaded weights file, or to `model.safetensors.index.json` when weights are sharded (see `load_model_weights_into_model`). |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If no supported weight layout exists in the repo. |
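The fallback order above can be illustrated against a local snapshot directory. This sketch only demonstrates the priority chain; the real function resolves files on the Hub and downloads them via `huggingface_hub`, and the helper name `pick_weights_file` is an assumption:

```python
import tempfile
from pathlib import Path

def pick_weights_file(snapshot_dir: str) -> str:
    root = Path(snapshot_dir)
    # Documented priority: single-file safetensors, then sharded index,
    # then the legacy pickle checkpoint.
    for name in ("model.safetensors", "model.safetensors.index.json",
                 "pytorch_model.bin"):
        candidate = root / name
        if candidate.exists():
            return str(candidate)
    raise ValueError("No supported weight layout found in the repo.")

# Usage: a snapshot holding only a pickle checkpoint resolves to it.
with tempfile.TemporaryDirectory() as d:
    (Path(d) / "pytorch_model.bin").touch()
    print(Path(pick_weights_file(d)).name)  # pytorch_model.bin
```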

load_model_state_dict_from_file(checkpoint_path, map_location='cpu', weights_only=True, mmap=True)

Load model state dict from a checkpoint file (safetensors or pickle).

For `.safetensors`, uses `load_file` directly (no metadata validation). For `model.safetensors.index.json`, merges all referenced shards into one dict (high peak RAM for large models; prefer loading via `load_model_weights_into_model`). For `.bin`/`.pt`, uses `torch.load`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint_path` | `str` | Path to model weights, index JSON, or `pytorch_model.bin`. | *required* |
| `map_location` | `str` | Device to load tensors to. | `'cpu'` |
| `weights_only` | `bool` | If `True`, use `weights_only` for pickle (PyTorch >= 1.13). | `True` |
| `mmap` | `bool` | If `True`, memory-map the checkpoint while loading. Saves CPU RAM. | `True` |

Returns: State dict for model.load_state_dict().
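The format dispatch described above can be sketched as a pure suffix check. The loader names below are illustrative stand-ins (for `safetensors.torch.load_file`, shard merging, and `torch.load`), not the module's real internals:

```python
from pathlib import Path

def choose_loader(checkpoint_path: str) -> str:
    name = Path(checkpoint_path).name
    if name.endswith(".safetensors.index.json"):
        return "merge_safetensors_shards"  # high peak RAM: full merged dict
    if name.endswith(".safetensors"):
        return "safetensors_load_file"     # no metadata validation
    if name.endswith((".bin", ".pt")):
        return "torch_load"                # pickle path; honors weights_only
    raise ValueError(f"Unsupported checkpoint format: {name}")

print(choose_loader("model.safetensors.index.json"))  # merge_safetensors_shards
```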

load_model_weights_into_model(model, checkpoint_path, map_location='cpu', strict=True, weights_only=True)

Load weights from a checkpoint into the model. Aligns with the harness and `hub_mixin` loading paths.

For `.safetensors`, uses `safetensors.torch.load_model` (handles tensor sharing). For sharded safetensors (`model.safetensors.index.json`), loads one shard at a time to limit peak CPU memory (avoids holding a full merged state dict). For `.bin`/`.pt`, uses `model.load_state_dict(torch.load(...))`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The model to load weights into. | *required* |
| `checkpoint_path` | `str` | Path to the weights file, safetensors index, or `pytorch_model.bin`. | *required* |
| `map_location` | `str` | Device to load tensors to. | `'cpu'` |
| `strict` | `bool` | Whether to enforce an exact key match (safetensors) or pass through to `load_state_dict`. | `True` |
| `weights_only` | `bool` | If `True`, use `weights_only` for pickle (PyTorch >= 1.13). | `True` |
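The shard-at-a-time strategy is driven by the index file's `weight_map` (parameter name → shard filename). The sketch below uses JSON files and plain dicts as stand-ins for safetensors shards and tensors, so it only illustrates the memory-bounded iteration; the real function reads each shard with safetensors and applies it to the model before moving on:

```python
import json
import tempfile
from pathlib import Path

def load_shards_one_at_a_time(index_path, load_shard, apply_state_dict):
    index = json.loads(Path(index_path).read_text())
    shard_dir = Path(index_path).parent
    # Visit each shard file exactly once; only one shard is in RAM at a time.
    for shard_name in sorted(set(index["weight_map"].values())):
        shard = load_shard(shard_dir / shard_name)
        apply_state_dict(shard)  # shard can be freed before the next one

# Usage with JSON files standing in for safetensors shards (assumption).
with tempfile.TemporaryDirectory() as d:
    d = Path(d)
    (d / "s1.json").write_text(json.dumps({"w1": 1}))
    (d / "s2.json").write_text(json.dumps({"w2": 2}))
    (d / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"w1": "s1.json", "w2": "s2.json"}}))
    merged = {}
    load_shards_one_at_a_time(d / "model.safetensors.index.json",
                              lambda p: json.loads(Path(p).read_text()),
                              merged.update)
    print(merged)  # {'w1': 1, 'w2': 2}
```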