
xlm.harness

LRSchedulerWithConfig

Bases: TypedDict

We follow the same structure as the one in LightningModule:

```python
lr_scheduler_config = {
    # REQUIRED: The lr_scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the lr_scheduler's step size, could also be 'step'.
    # 'epoch' updates the lr_scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # lr_scheduler.step(). 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like ReduceLROnPlateau
    "monitor": "val_loss",
    # If set to True, will enforce that the value specified by 'monitor'
    # is available when the lr_scheduler is updated, thus stopping
    # training if not found. If set to False, it will only produce a warning.
    "strict": True,
}
```

Predictor

Bases: Generic[T_in, T_out_pred], Protocol

to_dict(batch, preds, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Create JSON lines from the predictions batch.

PredictorHistoryMixin

Mixin class for adding history tracking to predictors.

This mixin provides generic history tracking capabilities that can be used by any predictor that implements iterative generation. History is stored as a list of tuples: (decoded_text, confidence_score, step_number).

Usage

```python
class MyPredictor(torch.nn.Module, PredictorHistoryMixin, Predictor[...]):
    def __init__(self, ..., return_history: bool = False):
        super().__init__()
        self.init_history(return_history=return_history)
        ...

    def predict(self, batch):
        history = self.create_history(batch_size)
        for step in range(steps):
            # ... generation logic ...
            history = self.update_history_from_state(history, state, step)
        return {"text": ..., "history": history}
```

init_history(return_history=False, decode_fn=None)

Initialize history tracking.

Parameters:

Name Type Description Default
return_history bool

Whether to track history during generation.

False
decode_fn Optional[Callable]

Optional custom decode function. If None, will use self.decode.

None

create_history(batch_size)

Create empty history for a batch.

Parameters:

Name Type Description Default
batch_size int

Number of sequences in the batch.

required

Returns:

Type Description
List[List[Tuple[str, float, int]]]

Empty history list for each batch element.
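Given the documented return type, a minimal pure-Python sketch of the behavior (illustrative only, not the library implementation):

```python
from typing import List, Tuple

# Hypothetical sketch of create_history: one empty per-sample history per
# batch element, matching the return type List[List[Tuple[str, float, int]]].
def create_history(batch_size: int) -> List[List[Tuple[str, float, int]]]:
    # A separate list per sample, so later steps can append independently.
    return [[] for _ in range(batch_size)]
```

For example, `create_history(3)` yields `[[], [], []]`, one slot per sequence in the batch.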

update_history_from_state(history, state, step, confidence_key='confidence', active_mask_key=None)

Update history from a state dictionary.

Parameters:

Name Type Description Default
history List[List[Tuple[str, float, int]]]

Current history list.

required
state Dict[str, Any]

Dictionary containing the current state (must be decodable).

required
step int

Current step number.

required
confidence_key str

Key in state dict for confidence values (default: "confidence").

'confidence'
active_mask_key Optional[str]

Optional key for mask indicating which samples are still active.

None

Returns:

Type Description
List[List[Tuple[str, float, int]]]

Updated history.

update_history_explicit(history, texts, confidences, step, active_mask=None)

Update history with explicit values.

Parameters:

Name Type Description Default
history List[List[Tuple[str, float, int]]]

Current history list.

required
texts List[str]

Decoded text for each batch element.

required
confidences Union[List[float], Tensor]

Confidence/score for each batch element.

required
step int

Current step number.

required
active_mask Optional[Union[List[bool], Tensor]]

Optional mask indicating which samples are still active.

None

Returns:

Type Description
List[List[Tuple[str, float, int]]]

Updated history.
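Based on the parameter descriptions above, a hedged sketch of what this update plausibly does (names and details are illustrative, not the actual implementation):

```python
from typing import List, Optional, Tuple

# Illustrative sketch: append one (text, confidence, step) tuple per batch
# element, skipping samples the active mask marks as finished.
def update_history_explicit(
    history: List[List[Tuple[str, float, int]]],
    texts: List[str],
    confidences: List[float],
    step: int,
    active_mask: Optional[List[bool]] = None,
) -> List[List[Tuple[str, float, int]]]:
    for i, (text, conf) in enumerate(zip(texts, confidences)):
        # Samples that have already finished generating keep their history as-is.
        if active_mask is not None and not active_mask[i]:
            continue
        history[i].append((text, float(conf), step))
    return history
```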

format_history_for_output(history, round_precision=4)

Format history for output in to_dict methods.

Parameters:

Name Type Description Default
history List[List[Tuple[str, float, int]]]

Raw history list.

required
round_precision int

Number of decimal places to round confidence values.

4

Returns:

Type Description
List[List[List[Any]]]

Formatted history with rounded confidence values.
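The return type implies the tuples are converted to JSON-friendly lists with rounded scores. A minimal sketch under that assumption (illustrative, not the library code):

```python
from typing import Any, List, Tuple

# Hypothetical sketch: convert each (text, confidence, step) tuple to a list
# and round the confidence, matching the List[List[List[Any]]] return type.
def format_history_for_output(
    history: List[List[Tuple[str, float, int]]],
    round_precision: int = 4,
) -> List[List[List[Any]]]:
    return [
        [[text, round(conf, round_precision), step] for text, conf, step in sample]
        for sample in history
    ]
```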

Harness

Bases: LightningModule, PyTorchModelHubMixin

Main module that provides the scaffolding for the codebase.

tokenizer instance-attribute

Task metrics usually consist of two types of metrics:

1. diagnostic metrics: These are typically different for different models as well as different tasks.
2. reported metrics: These are the same for all the models but different for different tasks.

What we want to do is avoid a full-blown (task x model) setup whenever we can, but provide it as a last resort. The best-case scenario is complete decoupling, which happens when all the models adhere to the same output signature. But this never works for diagnostic metrics. In some cases, different tasks can share base metrics of both types; there, we can use inheritance to avoid some code duplication. We would still have (task x model) number of classes, though.

__init__(cfg, tokenizer=None, datamodule=None, write_per_sample_metrics=False, **kwargs)

Initialize the Harness module.

Parameters:

Name Type Description Default
cfg DictConfig

Configuration dictionary.

required
tokenizer Optional[Tokenizer]

Optional tokenizer instance.

None
datamodule Optional[BaseDataModule]

Optional datamodule instance.

None
write_per_sample_metrics bool

Whether to write per-sample metrics.

False
**kwargs Any

Additional keyword arguments.

{}

setup_metrics(cfg)

Attach metrics as modules.

setup_post_hoc_evaluator(cfg)

Setup post-hoc evaluator. Can be used for tasks like molecule generation.

The post-hoc evaluator computes metrics on logged predictions at epoch end, enabling global metric computation (e.g., diversity on full generated set).

Parameters:

Name Type Description Default
cfg DictConfig

Configuration dictionary

required

create_lr_scheduler(optimizer, name, num_warmup_steps=None, fraction_warmup_steps=None, num_training_steps=None, interval='step', frequency=1, monitor='train_loss', strict=True, **kwargs) staticmethod

Creates a learning rate scheduler with the given configuration.

Parameters:

Name Type Description Default
name str

Huggingface name of the learning rate scheduler. https://huggingface.co/docs/transformers/en/main_classes/optimizer_schedules#transformers.get_scheduler

required
optimizer Optimizer

The optimizer to use with the scheduler.

required
num_training_steps Optional[int]

The total number of training steps.

None
num_warmup_steps Optional[int]

The number of warmup steps.

None
fraction_warmup_steps Optional[float]

The fraction of training steps to use for warmup.

None
interval Literal['step', 'epoch']

The interval at which to update the learning rate.

'step'
frequency int

The frequency of the learning rate updates.

1
monitor Optional[str]

The metric to monitor for the learning rate scheduler.

'train_loss'
strict bool

Whether to strictly follow the learning rate schedule.

True
**kwargs Any

Additional keyword arguments to pass to the learning rate scheduler.

{}

Returns:

Name Type Description
LRSchedulerWithConfig LRSchedulerWithConfig

The configured learning rate scheduler.
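Since both num_warmup_steps and fraction_warmup_steps are accepted, the warmup length presumably has to be resolved to an absolute step count before the Hugging Face get_scheduler factory is called. A hedged sketch of that resolution logic (the helper name and exact error handling are illustrative, not the actual implementation):

```python
from typing import Optional

# Illustrative helper: turn either an absolute count or a fraction of the
# total training steps into the num_warmup_steps value get_scheduler expects.
def resolve_warmup_steps(
    num_warmup_steps: Optional[int],
    fraction_warmup_steps: Optional[float],
    num_training_steps: Optional[int],
) -> Optional[int]:
    if num_warmup_steps is not None and fraction_warmup_steps is not None:
        raise ValueError("Specify num_warmup_steps or fraction_warmup_steps, not both.")
    if fraction_warmup_steps is not None:
        if num_training_steps is None:
            raise ValueError("fraction_warmup_steps requires num_training_steps.")
        return int(fraction_warmup_steps * num_training_steps)
    return num_warmup_steps
```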

prepare_batch_for_prediction(batch)

We need this for some tasks even if we have a task-specific collator, mainly because we want to clone some elements of the batch that are useful for computing metrics. TODO: Get rid of this method by cloning in the collator itself.

compute_loss(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Computes loss based on the dataloader name.

For 'lm', the loss function is applied. For 'prediction', the predictor's predict_step is used.
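The dispatch described above can be sketched as a plain function (the loss_fn and predict_step callables here are hypothetical stand-ins for the real Harness internals):

```python
# Illustrative dispatch on the dataloader name, per the description above.
def compute_loss(batch, dataloader_name, loss_fn, predict_step):
    if dataloader_name == "lm":
        return loss_fn(batch)
    if dataloader_name == "prediction":
        return predict_step(batch)
    raise ValueError(f"Unknown dataloader name: {dataloader_name}")
```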

compute_post_hoc_metrics(split, dataloader_name, epoch, step, update_logged_predictions=True)

Compute post-hoc metrics on logged predictions.

Similar to compute_generative_perplexity, but for arbitrary post-hoc metrics. Loads predictions from jsonl, computes per-sample and global metrics, and logs aggregated results.

Parameters:

Name Type Description Default
split Literal['train', 'val', 'test', 'predict']

train/val/test/predict

required
dataloader_name str

Name of the dataloader

required
epoch int

Current epoch

required
step int

Current step

required
update_logged_predictions bool

If True, update predictions jsonl with per-sample metrics

True

Returns:

Type Description
Optional[Dict[str, Any]]

Dictionary of aggregated metrics, or None if no evaluator

extract_model_weights()

Extract current model state dict.

Returns:

Type Description
Dict[str, Any]

Model state dict (self.model.state_dict())

save_model_weights(path, overwrite=False)

Save current model weights to local file.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to save the model weights

required
overwrite bool

Whether to overwrite existing file

False

Raises:

Type Description
ValueError

If file exists and overwrite is False
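The documented overwrite semantics can be sketched as follows; this is a hypothetical stand-in where the real method would serialize self.model.state_dict() rather than raw bytes:

```python
from pathlib import Path
from typing import Union

# Illustrative sketch of the documented overwrite check.
def save_weights(path: Union[str, Path], payload: bytes, overwrite: bool = False) -> None:
    path = Path(path)
    # Refuse to clobber an existing file unless overwrite=True.
    if path.exists() and not overwrite:
        raise ValueError(f"{path} already exists; pass overwrite=True to replace it.")
    path.write_bytes(payload)
```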

load_model_weights(path, strict=True)

Load model weights from local file into self.model.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to the model weights file

required
strict bool

Whether to strictly enforce that the keys match

True

load_model_from_hub(repo_id, revision=None, cache_dir=None, force_download=False, token=None, strict=True, **kwargs)

Download and load model weights from HuggingFace Hub into self.model.

This method downloads the model weights from the hub and loads them into the existing model. It does NOT reconstruct a new Harness instance.

Parameters:

Name Type Description Default
repo_id str

HuggingFace Hub repository ID (e.g., "username/model")

required
revision Optional[str]

Git revision (branch, tag, or commit)

None
cache_dir Optional[Union[str, Path]]

Directory to cache downloaded files

None
force_download bool

Force re-download even if cached

False
token Optional[Union[str, bool]]

HuggingFace Hub token for private repos

None
strict bool

Whether to strictly enforce that the keys match

True
**kwargs

Additional arguments for hf_hub_download

{}

from_checkpoint(checkpoint_path, cfg=None, tokenizer=None, datamodule=None, apply_ema=False, map_location='cpu', **kwargs) classmethod

Load Harness from Lightning checkpoint with optional EMA application.

This is the ONLY method that can apply EMA weights, as it has direct access to the checkpoint file containing the EMA state.

Parameters:

Name Type Description Default
checkpoint_path Union[str, Path]

Path to the Lightning checkpoint file

required
cfg Optional[DictConfig]

Optional config to override checkpoint config

None
tokenizer Optional[Tokenizer]

Optional tokenizer instance

None
datamodule Optional[BaseDataModule]

Optional datamodule instance

None
apply_ema bool

Whether to apply EMA weights from checkpoint

False
map_location str

Device to load checkpoint to

'cpu'
**kwargs

Additional arguments for load_from_checkpoint

{}

Returns:

Type Description
Harness

Harness instance with loaded weights (and EMA applied if requested)

Example

```python
# Load with EMA weights applied
harness = Harness.from_checkpoint(
    "checkpoint.ckpt",
    apply_ema=True,
    cfg=cfg,
    tokenizer=tokenizer,
    datamodule=datamodule,
)
```