xlm.metrics

MetricWrapper

Bases: Module

Unified metric wrapper that works with both Lightning trainer and Fabric.

Sends the raw batch and the loss_dict output to the update_fn, which transforms them into a dict of kwargs for the metric. The update_fn can contain task-specific and model-specific logic.

With Lightning, use the log method to log metrics through the LightningModule. With Fabric, use the compute and get_log_dict methods for manual logging.
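The update_fn pattern described above can be sketched in a few lines. This is only an illustration under stated assumptions, not the xlm implementation: `ToyMean`, `ToyWrapper`, `loss_update_fn`, and the `"loss"` key are hypothetical names, and a plain Python class stands in for the real metric so the example runs without PyTorch.

```python
from typing import Any, Callable, Dict, Optional


class ToyMean:
    """Stand-in for a mean-style metric (not the xlm or torchmetrics class)."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count if self.count else 0.0


class ToyWrapper:
    """Hypothetical wrapper: update_fn maps (batch, loss_dict) to metric kwargs."""

    def __init__(self, metric: ToyMean, update_fn: Callable[..., Dict[str, Any]]) -> None:
        self.metric = metric
        self.update_fn = update_fn

    def update(self, batch: Dict[str, Any], loss_dict: Dict[str, Any],
               tokenizer: Optional[Any] = None) -> None:
        # The wrapper knows nothing about the task; it only forwards kwargs.
        kwargs = self.update_fn(batch, loss_dict, tokenizer)
        self.metric.update(**kwargs)


def loss_update_fn(batch, loss_dict, tokenizer=None):
    # Task-specific logic lives here. Assumption: loss_dict carries a "loss" entry.
    return {"value": float(loss_dict["loss"])}


wrapper = ToyWrapper(ToyMean(), loss_update_fn)
wrapper.update({"input_ids": [[1, 2]]}, {"loss": 2.0})
wrapper.update({"input_ids": [[3, 4]]}, {"loss": 4.0})
mean_loss = wrapper.metric.compute()
```

Keeping the task logic in the update_fn means the same wrapper can serve different models by swapping one function.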

full_name property

Get the full metric name with prefix.

update(batch, loss_dict, tokenizer=None)

Update the metric with the current batch and loss_dict.

log(pl_module, batch, metrics)

Log the metric using Lightning's logging mechanism.

compute()

Compute the current metric value. Useful for Fabric-based training.

get_log_dict()

Get a dictionary with the metric name and computed value for logging. Useful for Fabric-based training.

reset()

Reset the metric state.
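For manual (Fabric-style) logging, the methods above are typically called in the order update, get_log_dict, reset. The sketch below mirrors the documented method names on a hypothetical stand-in class, `ToyLoggedMetric`. The `"/"` separator between prefix and name is an assumption, since the actual format of full_name is not specified here.

```python
from typing import Dict


class ToyLoggedMetric:
    """Hypothetical stand-in that mirrors the documented method names."""

    def __init__(self, name: str, prefix: str) -> None:
        self.name = name
        self.prefix = prefix
        self.reset()

    @property
    def full_name(self) -> str:
        # Assumption: prefix and name are joined with "/".
        return f"{self.prefix}/{self.name}"

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count if self.count else 0.0

    def get_log_dict(self) -> Dict[str, float]:
        return {self.full_name: self.compute()}

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0


metric = ToyLoggedMetric(name="accuracy", prefix="val")
for batch_accuracy in (0.5, 1.0):
    metric.update(batch_accuracy)

log_dict = metric.get_log_dict()  # hand this dict to the logger of your choice
metric.reset()                    # clear state before the next epoch
```

Resetting after logging keeps values from one epoch from leaking into the next.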

ExactMatch

Bases: MeanMetric

update(pred, target, pred_length=None, target_length=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pred | Integer[Tensor, ' *batch seq_len'] | predicted tokens | required |
| target | Integer[Tensor, ' *batch seq_len'] | target tokens | required |
| pred_length | Optional[Integer[Tensor, ' *batch']] | length of the predicted tokens | None |
| target_length | Optional[Integer[Tensor, ' *batch']] | length of the target tokens | None |
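The following sketch shows one plausible reading of a sequence-level exact match with optional lengths. Plain Python lists stand in for tensors, and `exact_match_scores` is a hypothetical helper. How ExactMatch actually uses pred_length and target_length is not documented here, so the length handling is an assumption.

```python
from typing import List, Optional


def exact_match_scores(
    pred: List[List[int]],
    target: List[List[int]],
    pred_length: Optional[List[int]] = None,
    target_length: Optional[List[int]] = None,
) -> List[float]:
    """Return 1.0 for each sequence whose compared tokens all match, else 0.0."""
    scores = []
    for i, (p, t) in enumerate(zip(pred, target)):
        if pred_length is not None and target_length is not None:
            # Assumption: sequences of different length never match, and
            # only the first `length` tokens are compared (padding ignored).
            same_length = pred_length[i] == target_length[i]
            n = target_length[i]
            scores.append(float(same_length and p[:n] == t[:n]))
        else:
            # Without lengths, compare the full padded rows.
            scores.append(float(p == t))
    return scores


PAD = 0
pred = [[5, 6, 7, PAD], [5, 6, 9, PAD]]
target = [[5, 6, 7, PAD], [5, 6, 7, 8]]
scores = exact_match_scores(pred, target, pred_length=[3, 3], target_length=[3, 4])
exact_match = sum(scores) / len(scores)  # mean over sequences, as a MeanMetric would report
```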

TokenAccuracy

Bases: MeanMetric

update(pred, target, pred_mask=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pred | Integer[Tensor, ' *batch seq_len'] | predicted tokens | required |
| target | Integer[Tensor, ' *batch seq_len'] | target tokens | required |
| pred_mask | Optional[Integer[Tensor, ' *batch seq_len']] | True for positions that were predicted. | None |
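To illustrate what pred_mask changes, here is a minimal sketch of masked token accuracy using plain Python lists in place of tensors. `token_accuracy` is a hypothetical helper, and averaging over all selected tokens (rather than per sequence) is an assumption about the aggregation.

```python
from typing import List, Optional


def token_accuracy(
    pred: List[List[int]],
    target: List[List[int]],
    pred_mask: Optional[List[List[int]]] = None,
) -> float:
    """Fraction of matching tokens, restricted to masked-in positions if a mask is given."""
    correct = 0
    total = 0
    for i, (p_row, t_row) in enumerate(zip(pred, target)):
        for j, (p, t) in enumerate(zip(p_row, t_row)):
            if pred_mask is not None and not pred_mask[i][j]:
                continue  # skip positions that were not predicted
            correct += int(p == t)
            total += 1
    return correct / total if total else 0.0


pred = [[5, 6, 7, 0], [5, 9, 0, 0]]
target = [[5, 6, 8, 0], [5, 6, 0, 0]]
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]

masked_accuracy = token_accuracy(pred, target, pred_mask=mask)
unmasked_accuracy = token_accuracy(pred, target)
```

In this toy data the unmasked score is higher because trivially matching pad positions are counted as correct, which is the kind of inflation a mask avoids.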

exact_match_update_fn(batch, loss_dict, tokenizer=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Should contain the following key: "target_ids": Integer[TT, " *batch target_seq_len"] | required |
| loss_dict | Dict[str, Any] | Should contain the following key: "ids": Integer[TT, " *batch input_seq_len+target_seq_len"] | required |

seq2seq_exact_match_update_fn(batch, loss_dict, tokenizer=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Should contain the following keys: "target_ids": Integer[TT, " batch target_seq_len"]; "input_ids": Integer[TT, " batch input_seq_len"] | required |
| loss_dict | Dict[str, Any] | Should contain the following key: "ids": Integer[TT, " *batch input_seq_len+target_seq_len"] | required |

Note: We rely on target and pred having the same number of right pads, which may not be true for ARLM.
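The padding caveat is easier to see with a concrete sketch. The function below is a hypothetical illustration, not the library's code: it assumes the predicted continuation is the part of loss_dict["ids"] that follows the input tokens, and it uses plain lists instead of tensors.

```python
from typing import Any, Dict, List, Optional


def toy_seq2seq_update_fn(
    batch: Dict[str, Any],
    loss_dict: Dict[str, Any],
    tokenizer: Optional[Any] = None,
) -> Dict[str, List[List[int]]]:
    """Slice the target-side tokens out of the concatenated ids."""
    input_len = len(batch["input_ids"][0])
    # Assumption: each row of "ids" is laid out as [input tokens | predicted tokens].
    pred = [row[input_len:] for row in loss_dict["ids"]]
    return {"pred": pred, "target": batch["target_ids"]}


PAD = 0
batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6]],
    "target_ids": [[7, 8, PAD], [9, PAD, PAD]],
}
loss_dict = {
    "ids": [
        [1, 2, 3, 7, 8, PAD],  # same tokens and same right padding as the target
        [4, 5, 6, 9, 9, PAD],  # first token agrees, but the tail is not padding
    ]
}

kwargs = toy_seq2seq_update_fn(batch, loss_dict)
matches = [p == t for p, t in zip(kwargs["pred"], kwargs["target"])]
```

The second row agrees with the target on its only real token, yet a full-width comparison counts it as a miss because the positions after it do not line up with the target's padding.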

seq2seq_token_accuracy_update_fn(batch, loss_dict, tokenizer=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Should contain the following keys: "target_ids": Integer[TT, " batch target_seq_len"]; "input_ids": Integer[TT, " batch input_seq_len"] | required |
| loss_dict | Dict[str, Any] | Should contain the following key: "ids": Integer[TT, " *batch input_seq_len+target_seq_len"] | required |