xlm.harness
LRSchedulerWithConfig
Bases: TypedDict
We follow the same structure as the one in LightningModule.
```python
lr_scheduler_config = {
    # REQUIRED: The lr_scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the lr_scheduler's step size, could also be 'step'.
    # 'epoch' updates the lr_scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # lr_scheduler.step(). 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like ReduceLROnPlateau
    "monitor": "val_loss",
    # If set to True, will enforce that the value specified in 'monitor'
    # is available when the lr_scheduler is updated, thus stopping
    # training if not found. If set to False, it will only produce a warning.
    "strict": True,
}
```
Predictor
Bases: Generic[T_in, T_out_pred], Protocol
to_dict(batch, preds, batch_idx=None, dataloader_idx=None, dataloader_name=None)
Create json lines from the predictions batch.
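A minimal sketch of a conforming `to_dict` follows; the record keys (`input`, `pred`, `dataloader`) and the per-sample layout are illustrative choices, not mandated by the protocol:

```python
import json
from typing import Any, Dict, List, Optional


def to_dict(
    batch: List[str],
    preds: List[str],
    batch_idx: Optional[int] = None,
    dataloader_idx: Optional[int] = None,
    dataloader_name: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Illustrative to_dict: one JSON-serializable record per sample."""
    return [
        {"input": x, "pred": p, "batch_idx": batch_idx, "dataloader": dataloader_name}
        for x, p in zip(batch, preds)
    ]


records = to_dict(["CCO"], ["CCN"], batch_idx=0, dataloader_name="prediction")
print(json.dumps(records[0]))  # one jsonl line per sample
```

Each returned dict becomes one line in the predictions jsonl, which is what the post-hoc evaluator later reads back.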
PredictorHistoryMixin
Mixin class for adding history tracking to predictors.
This mixin provides generic history tracking capabilities that can be used by any predictor that implements iterative generation. History is stored as a list of tuples: (decoded_text, confidence_score, step_number).
Usage
```python
class MyPredictor(torch.nn.Module, PredictorHistoryMixin, Predictor[...]):
    def __init__(self, ..., return_history: bool = False):
        super().__init__()
        self.init_history(return_history=return_history)
        ...

    def predict(self, batch):
        history = self.create_history(batch_size)
        for step in range(steps):
            # ... generation logic ...
            history = self.update_history_from_state(history, state, step)
        return {"text": ..., "history": history}
```
init_history(return_history=False, decode_fn=None)
Initialize history tracking.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `return_history` | `bool` | Whether to track history during generation. | `False` |
| `decode_fn` | `Optional[Callable]` | Optional custom decode function. If None, will use `self.decode`. | `None` |
create_history(batch_size)
Create empty history for a batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch_size` | `int` | Number of sequences in the batch. | required |
Returns:
| Type | Description |
|---|---|
| `List[List[Tuple[str, float, int]]]` | Empty history list for each batch element. |
update_history_from_state(history, state, step, confidence_key='confidence', active_mask_key=None)
Update history from a state dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `history` | `List[List[Tuple[str, float, int]]]` | Current history list. | required |
| `state` | `Dict[str, Any]` | Dictionary containing the current state (must be decodable). | required |
| `step` | `int` | Current step number. | required |
| `confidence_key` | `str` | Key in state dict for confidence values. | `'confidence'` |
| `active_mask_key` | `Optional[str]` | Optional key for mask indicating which samples are still active. | `None` |
Returns:
| Type | Description |
|---|---|
| `List[List[Tuple[str, float, int]]]` | Updated history. |
update_history_explicit(history, texts, confidences, step, active_mask=None)
Update history with explicit values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `history` | `List[List[Tuple[str, float, int]]]` | Current history list. | required |
| `texts` | `List[str]` | Decoded text for each batch element. | required |
| `confidences` | `Union[List[float], Tensor]` | Confidence/score for each batch element. | required |
| `step` | `int` | Current step number. | required |
| `active_mask` | `Optional[Union[List[bool], Tensor]]` | Optional mask indicating which samples are still active. | `None` |
Returns:
| Type | Description |
|---|---|
| `List[List[Tuple[str, float, int]]]` | Updated history. |
format_history_for_output(history, round_precision=4)
Format history for output in to_dict methods.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `history` | `List[List[Tuple[str, float, int]]]` | Raw history list. | required |
| `round_precision` | `int` | Number of decimal places to round confidence values. | `4` |
Returns:
| Type | Description |
|---|---|
| `List[List[List[Any]]]` | Formatted history with rounded confidence values. |
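Taken together, these methods describe a simple lifecycle: create an empty per-sample history, append `(text, confidence, step)` tuples as generation proceeds, then round confidences for output. The snippet below is a minimal, dependency-free reimplementation of that lifecycle (a sketch of the semantics, not the mixin itself), assuming confidences arrive as plain floats:

```python
from typing import Any, List, Optional, Tuple

History = List[List[Tuple[str, float, int]]]


def create_history(batch_size: int) -> History:
    # One empty track per batch element.
    return [[] for _ in range(batch_size)]


def update_history_explicit(
    history: History,
    texts: List[str],
    confidences: List[float],
    step: int,
    active_mask: Optional[List[bool]] = None,
) -> History:
    # Append (text, confidence, step) for every still-active sample.
    for i, (text, conf) in enumerate(zip(texts, confidences)):
        if active_mask is None or active_mask[i]:
            history[i].append((text, conf, step))
    return history


def format_history_for_output(
    history: History, round_precision: int = 4
) -> List[List[List[Any]]]:
    # Tuples become JSON-friendly lists with rounded confidences.
    return [
        [[text, round(conf, round_precision), step] for text, conf, step in track]
        for track in history
    ]


h = create_history(2)
h = update_history_explicit(h, ["a", "b"], [0.123456, 0.9], step=0)
h = update_history_explicit(h, ["ab", "bc"], [0.5, 0.8], step=1, active_mask=[True, False])
print(format_history_for_output(h))
# [[['a', 0.1235, 0], ['ab', 0.5, 1]], [['b', 0.9, 0]]]
```

Note how the `active_mask` freezes the second sample's track at step 0, which mirrors how finished sequences stop accumulating history.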
Harness
Bases: LightningModule, PyTorchModelHubMixin
Main module that provides the scaffolding for the codebase.
tokenizer
instance-attribute
Task metrics usually consist of two types:

1. Diagnostic metrics: typically different for different models as well as for different tasks.
2. Reported metrics: the same for all models but different for different tasks.

What we want to do is avoid a full-blown (task x model) setup whenever we can, but provide it as a last resort. The best-case scenario is complete decoupling, which happens when all the models adhere to the same output signature; this never works for diagnostic metrics, however. In some cases, different tasks can share base metrics of both types. There, we can use inheritance to avoid some code duplication, though we would still end up with (task x model) classes.
__init__(cfg, tokenizer=None, datamodule=None, write_per_sample_metrics=False, **kwargs)
Initialize the Harness module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `DictConfig` | Configuration dictionary. | required |
| `tokenizer` | `Optional[Tokenizer]` | Optional tokenizer instance. | `None` |
| `datamodule` | `Optional[BaseDataModule]` | Optional datamodule instance. | `None` |
| `write_per_sample_metrics` | `bool` | Whether to write per-sample metrics. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments. | `{}` |
setup_metrics(cfg)
Attach metrics as modules.
setup_post_hoc_evaluator(cfg)
Set up the post-hoc evaluator. Can be used for tasks like molecule generation.
The post-hoc evaluator computes metrics on logged predictions at epoch end, enabling global metric computation (e.g., diversity on full generated set).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `DictConfig` | Configuration dictionary. | required |
create_lr_scheduler(optimizer, name, num_warmup_steps=None, fraction_warmup_steps=None, num_training_steps=None, interval='step', frequency=1, monitor='train_loss', strict=True, **kwargs)
staticmethod
Creates a learning rate scheduler with the given configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Hugging Face name of the learning rate scheduler (see [transformers.get_scheduler](https://huggingface.co/docs/transformers/en/main_classes/optimizer_schedules#transformers.get_scheduler)). | required |
| `optimizer` | `Optimizer` | The optimizer to use with the scheduler. | required |
| `num_training_steps` | `Optional[int]` | The total number of training steps. | `None` |
| `num_warmup_steps` | `Optional[int]` | The number of warmup steps. | `None` |
| `fraction_warmup_steps` | `Optional[float]` | The fraction of training steps to use for warmup. | `None` |
| `interval` | `Literal['step', 'epoch']` | The interval at which to update the learning rate. | `'step'` |
| `frequency` | `int` | The frequency of the learning rate updates. | `1` |
| `monitor` | `Optional[str]` | The metric to monitor for the scheduler. | `'train_loss'` |
| `strict` | `bool` | Whether to enforce that the monitored metric is available when the scheduler is updated. | `True` |
| `**kwargs` | `Any` | Additional keyword arguments to pass to the scheduler. | `{}` |
Returns:
| Name | Type | Description |
|---|---|---|
| `LRSchedulerWithConfig` | `LRSchedulerWithConfig` | The configured learning rate scheduler. |
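The signature accepts either an absolute `num_warmup_steps` or a relative `fraction_warmup_steps`. How the two might be reconciled into the single warmup count that `transformers.get_scheduler` expects can be sketched as follows; `resolve_warmup_steps` is a hypothetical helper, not part of the Harness API, and it assumes exactly one of the two warmup arguments is given:

```python
from typing import Optional


def resolve_warmup_steps(
    num_training_steps: int,
    num_warmup_steps: Optional[int] = None,
    fraction_warmup_steps: Optional[float] = None,
) -> int:
    """Resolve an absolute warmup step count from either argument.

    Hypothetical sketch mirroring create_lr_scheduler's signature:
    exactly one of num_warmup_steps / fraction_warmup_steps must be set.
    """
    if (num_warmup_steps is None) == (fraction_warmup_steps is None):
        raise ValueError(
            "Provide exactly one of num_warmup_steps or fraction_warmup_steps."
        )
    if num_warmup_steps is not None:
        return num_warmup_steps
    return round(num_training_steps * fraction_warmup_steps)


print(resolve_warmup_steps(10_000, fraction_warmup_steps=0.1))  # 1000
```

The fractional form is convenient when `num_training_steps` is only known at setup time (e.g., derived from the datamodule and the number of epochs).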
prepare_batch_for_prediction(batch)
We need this for some tasks even if we have a task-specific collator, mainly because we want to clone some elements of the batch that are useful for computing metrics. TODO: Get rid of this method by cloning in the collator itself.
compute_loss(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)
Computes loss based on the dataloader name.
For 'lm', the loss function is applied. For 'prediction', the predictor's predict_step is used.
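The dispatch-on-dataloader-name pattern can be sketched in isolation; the function below is illustrative (the real method takes `batch_idx`/`dataloader_idx` as well and calls into the configured loss and predictor modules):

```python
from typing import Any, Callable, List


def compute_loss(
    batch: List[int],
    dataloader_name: str,
    loss_fn: Callable[[List[int]], Any],
    predict_step: Callable[[List[int]], Any],
) -> Any:
    """Illustrative dispatch: route the batch by dataloader name."""
    if dataloader_name == "lm":
        return loss_fn(batch)
    if dataloader_name == "prediction":
        return predict_step(batch)
    raise ValueError(f"Unknown dataloader: {dataloader_name}")


# Stand-ins: sum as a "loss", len as a "predict_step".
print(compute_loss([1, 2], "lm", loss_fn=sum, predict_step=len))  # 3
```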
compute_post_hoc_metrics(split, dataloader_name, epoch, step, update_logged_predictions=True)
Compute post-hoc metrics on logged predictions.
Similar to compute_generative_perplexity, but for arbitrary post-hoc metrics. Loads predictions from jsonl, computes per-sample and global metrics, and logs aggregated results.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `split` | `Literal['train', 'val', 'test', 'predict']` | train/val/test/predict. | required |
| `dataloader_name` | `str` | Name of the dataloader. | required |
| `epoch` | `int` | Current epoch. | required |
| `step` | `int` | Current step. | required |
| `update_logged_predictions` | `bool` | If True, update the predictions jsonl with per-sample metrics. | `True` |
Returns:
| Type | Description |
|---|---|
| `Optional[Dict[str, Any]]` | Dictionary of aggregated metrics, or None if no evaluator. |
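The overall flow — read logged predictions back from jsonl, attach a per-sample metric, aggregate — can be sketched without the Harness machinery. Everything here is illustrative (the real evaluator is configured via `cfg` and also supports global metrics such as diversity over the full generated set):

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict


def post_hoc_metrics(jsonl_path: Path, per_sample_metric: Callable[[str], float]) -> Dict[str, Any]:
    """Illustrative sketch: per-sample metric over logged predictions, then aggregate."""
    records = [json.loads(line) for line in jsonl_path.read_text().splitlines()]
    for rec in records:
        # Attach the per-sample value (update_logged_predictions analogue).
        rec["metric"] = per_sample_metric(rec["text"])
    return {"metric_mean": sum(r["metric"] for r in records) / len(records)}


with TemporaryDirectory() as tmp:
    path = Path(tmp) / "predictions.jsonl"
    path.write_text('{"text": "abc"}\n{"text": "abcde"}\n')
    print(post_hoc_metrics(path, len))  # {'metric_mean': 4.0}
```

Computing metrics post hoc, rather than per batch, is what makes set-level quantities (diversity, uniqueness) possible: they only exist once all predictions are available.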
extract_model_weights()
Extract current model state dict.
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Model state dict (`self.model.state_dict()`). |
save_model_weights(path, overwrite=False)
Save current model weights to local file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Union[str, Path]` | Path to save the model weights. | required |
| `overwrite` | `bool` | Whether to overwrite an existing file. | `False` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If the file exists and `overwrite` is False. |
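The overwrite guard can be sketched as below. This is illustrative only: the real method serializes `self.model.state_dict()` (so the payload is a tensor dict, not raw bytes), but the guard semantics are the part documented here:

```python
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Union


def save_weights(state: bytes, path: Union[str, Path], overwrite: bool = False) -> None:
    """Illustrative sketch of save_model_weights' overwrite guard."""
    path = Path(path)
    if path.exists() and not overwrite:
        raise ValueError(f"{path} exists; pass overwrite=True to replace it.")
    path.write_bytes(state)


with TemporaryDirectory() as tmp:
    target = Path(tmp) / "weights.bin"
    save_weights(b"v1", target)
    try:
        save_weights(b"v2", target)          # second save without overwrite
    except ValueError:
        print("refused to overwrite")        # guard fired
    save_weights(b"v2", target, overwrite=True)
```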
load_model_weights(path, strict=True)
Load model weights from local file into self.model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Union[str, Path]` | Path to the model weights file. | required |
| `strict` | `bool` | Whether to strictly enforce that the keys match. | `True` |
load_model_from_hub(repo_id, revision=None, cache_dir=None, force_download=False, token=None, strict=True, **kwargs)
Download and load model weights from HuggingFace Hub into self.model.
This method downloads the model weights from the hub and loads them into the existing model. It does NOT reconstruct a new Harness instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `repo_id` | `str` | HuggingFace Hub repository ID (e.g., `"username/model"`). | required |
| `revision` | `Optional[str]` | Git revision (branch, tag, or commit). | `None` |
| `cache_dir` | `Optional[Union[str, Path]]` | Directory to cache downloaded files. | `None` |
| `force_download` | `bool` | Force re-download even if cached. | `False` |
| `token` | `Optional[Union[str, bool]]` | HuggingFace Hub token for private repos. | `None` |
| `strict` | `bool` | Whether to strictly enforce that the keys match. | `True` |
| `**kwargs` | | Additional arguments for `hf_hub_download`. | `{}` |
from_checkpoint(checkpoint_path, cfg=None, tokenizer=None, datamodule=None, apply_ema=False, map_location='cpu', **kwargs)
classmethod
Load Harness from Lightning checkpoint with optional EMA application.
This is the ONLY method that can apply EMA weights, as it has direct access to the checkpoint file containing the EMA state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `checkpoint_path` | `Union[str, Path]` | Path to the Lightning checkpoint file. | required |
| `cfg` | `Optional[DictConfig]` | Optional config to override the checkpoint config. | `None` |
| `tokenizer` | `Optional[Tokenizer]` | Optional tokenizer instance. | `None` |
| `datamodule` | `Optional[BaseDataModule]` | Optional datamodule instance. | `None` |
| `apply_ema` | `bool` | Whether to apply EMA weights from the checkpoint. | `False` |
| `map_location` | `str` | Device to load the checkpoint to. | `'cpu'` |
| `**kwargs` | | Additional arguments for `load_from_checkpoint`. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Harness` | Harness instance with loaded weights (and EMA applied if requested). |
Example
Load with EMA weights applied:

```python
harness = Harness.from_checkpoint(
    "checkpoint.ckpt",
    apply_ema=True,
    cfg=cfg,
    tokenizer=tokenizer,
    datamodule=datamodule,
)
```