xlm.utils.hf_hub
Utilities for Hugging Face Hub integration.
repo_id_from_hf_path(path)
Extract repo_id from HF Hub path (URL or org/repo). Returns None if invalid.
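The extraction described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the function name `repo_id_sketch`, the accepted hostnames, and the character pattern are all assumptions, and URLs with a `datasets/` or `spaces/` prefix are not handled.

```python
import re
from typing import Optional
from urllib.parse import urlparse

# Hypothetical pattern: "org/repo" built from word characters, '.', and '-'.
_REPO_ID = re.compile(r"^[\w.-]+/[\w.-]+$")


def repo_id_sketch(path: str) -> Optional[str]:
    """Return "org/repo" from a Hub URL or a bare ID, or None if it does not match."""
    candidate = path.strip()
    if "://" in candidate:
        parsed = urlparse(candidate)
        # Assumption: only huggingface.co URLs count as Hub paths.
        if parsed.netloc not in ("huggingface.co", "www.huggingface.co"):
            return None
        # Keep the first two path segments; drop suffixes such as /tree/main.
        parts = [p for p in parsed.path.split("/") if p]
        if len(parts) < 2:
            return None
        candidate = "/".join(parts[:2])
    return candidate if _REPO_ID.match(candidate) else None
```

Under these assumptions, both `"org/model"` and `"https://huggingface.co/org/model/tree/main"` yield `"org/model"`, while a string without an `org/repo` shape yields `None`.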
download_model_weights(repo_id, revision='main', token=None)
Download model weights from Hugging Face Hub.
Tries model.safetensors, then sharded safetensors (model.safetensors.index.json
plus shard files), then pytorch_model.bin.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| repo_id | str | Hugging Face repository ID (e.g., "org/model"). | required |
| revision | str | Git revision (branch, tag, or commit). Defaults to "main". | 'main' |
| token | str \| None | HF token for private repos. Uses HF_HUB_KEY env if None. | None |
Returns:
| Type | Description |
|---|---|
| str | Path to the downloaded weights file, or to […] weights are sharded (see :func: […]). |
Raises:
| Type | Description |
|---|---|
| ValueError | If no supported weight layout exists in the repo. |
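The lookup order for download_model_weights (single safetensors file, then the sharded index, then pytorch_model.bin) can be illustrated with a small selection function. This is a sketch under stated assumptions: `choose_weights_file` is a hypothetical helper, and it assumes a listing of the repo's filenames is already available (for example from `huggingface_hub.list_repo_files`). The source does not say whether the real function lists files or simply attempts each download in turn.

```python
from typing import Iterable

# Candidate filenames, in the order the documentation lists them.
SINGLE_SAFETENSORS = "model.safetensors"
SAFETENSORS_INDEX = "model.safetensors.index.json"
PICKLE_WEIGHTS = "pytorch_model.bin"


def choose_weights_file(repo_files: Iterable[str]) -> str:
    """Pick the first supported weights entry point present in repo_files."""
    available = set(repo_files)
    for name in (SINGLE_SAFETENSORS, SAFETENSORS_INDEX, PICKLE_WEIGHTS):
        if name in available:
            return name
    # Mirrors the documented ValueError for repos with no supported layout.
    raise ValueError("no supported weight layout found in repo")
```

For a sharded repo the chosen entry is the index JSON; the shard files it references would still need to be fetched alongside it.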
load_model_state_dict_from_file(checkpoint_path, map_location='cpu', weights_only=True, mmap=True)
Load model state dict from a checkpoint file (safetensors or pickle).
For .safetensors uses load_file directly (no metadata validation).
For model.safetensors.index.json merges all referenced shards into one dict
(high peak RAM for large models; prefer loading via load_model_weights_into_model).
For .bin/.pt uses torch.load.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint_path | str | Path to model weights, index JSON, or pytorch_model.bin. | required |
| map_location | str | Device to load tensors to. | 'cpu' |
| weights_only | bool | If True, use weights_only for pickle (PyTorch >= 1.13). | True |
| mmap | bool | If True, use mmap for loading the checkpoint. Saves CPU RAM. | True |
Returns: State dict for model.load_state_dict().
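The sharded case of load_model_state_dict_from_file, where every shard named in the index is merged into one dict, might look roughly like the sketch below. It relies on the Hugging Face index convention of a top-level `"weight_map"` that maps tensor names to shard filenames. The function name `merge_sharded_state_dict` and the injected `load_shard` callable are assumptions made so the example stays self-contained; in practice `load_shard` could be `safetensors.torch.load_file`.

```python
import json
import os
from typing import Any, Callable, Dict


def merge_sharded_state_dict(
    index_path: str,
    load_shard: Callable[[str], Dict[str, Any]],
) -> Dict[str, Any]:
    """Merge every shard referenced by a safetensors index into one dict."""
    with open(index_path, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]
    shard_dir = os.path.dirname(index_path)
    merged: Dict[str, Any] = {}
    # Many tensors share a shard, so visit each shard file once, in stable order.
    for shard_name in sorted(set(weight_map.values())):
        merged.update(load_shard(os.path.join(shard_dir, shard_name)))
    return merged
```

Because `merged` holds every tensor at once, peak memory grows with the full model size, which is the cost the description above warns about.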
load_model_weights_into_model(model, checkpoint_path, map_location='cpu', strict=True, weights_only=True)
Load weights from checkpoint into model. Aligns with harness and hub_mixin.
For .safetensors uses safetensors.torch.load_model (handles tensor sharing).
For sharded safetensors (model.safetensors.index.json) loads one shard at a
time to limit peak CPU memory (avoids holding a full merged state dict).
For .bin/.pt uses model.load_state_dict(torch.load(...)).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The model to load weights into. | required |
| checkpoint_path | str | Path to weights file, safetensors index, or pytorch_model.bin. | required |
| map_location | str | Device to load tensors to. | 'cpu' |
| strict | bool | Whether to enforce exact key match (safetensors) or pass to load_state_dict. | True |
| weights_only | bool | If True, use weights_only for pickle (PyTorch >= 1.13). | True |
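The shard-at-a-time strategy described for load_model_weights_into_model can be sketched as below. This is an illustration only, not the library's code: `load_shards_into_model` and the `load_shard` callable are hypothetical names, the model only needs `state_dict()` and `load_state_dict(state, strict=False)` in the style of `torch.nn.Module`, and the final key check is one possible way to honor `strict` when no single call sees all keys.

```python
import json
import os
from typing import Any, Callable, Dict, Set


def load_shards_into_model(
    model: Any,
    index_path: str,
    load_shard: Callable[[str], Dict[str, Any]],
    strict: bool = True,
) -> None:
    """Load a sharded checkpoint one shard at a time."""
    with open(index_path, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]
    shard_dir = os.path.dirname(index_path)
    expected: Set[str] = set(model.state_dict().keys())
    seen: Set[str] = set()
    for shard_name in sorted(set(weight_map.values())):
        shard = load_shard(os.path.join(shard_dir, shard_name))
        # Each shard covers only part of the model, so per-shard loading
        # must be non-strict; completeness is checked after the loop.
        model.load_state_dict(shard, strict=False)
        seen.update(shard.keys())
        del shard  # drop the reference before reading the next shard
    if strict and seen != expected:
        missing = sorted(expected - seen)
        unexpected = sorted(seen - expected)
        raise RuntimeError(
            f"missing keys: {missing}; unexpected keys: {unexpected}"
        )
```

Only one shard's tensors are referenced by this function at any time, so peak memory is bounded by the largest shard plus the model itself rather than by a full merged state dict.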