xlm.metrics
MetricWrapper
Bases: Module
Unified metric wrapper that works with both Lightning trainer and Fabric.
Sends the raw batch and loss_dict output to the update_fn, which transforms them into a dict of kwargs for the metric. The update_fn can contain task-specific and model-specific logic.
For Lightning: Use the log method to log metrics via LightningModule.
For Fabric: Use compute and get_log_dict methods for manual logging.
full_name
property
Get the full metric name with prefix.
update(batch, loss_dict, tokenizer=None)
Update the metric with the current batch and loss_dict.
log(pl_module, batch, metrics)
Log the metric using Lightning's logging mechanism.
compute()
Compute the current metric value. Useful for Fabric-based training.
get_log_dict()
Get a dictionary with the metric name and computed value for logging. Useful for Fabric-based training.
reset()
Reset the metric state.
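The Fabric-side flow above can be sketched in pure Python. Everything except the method names documented on this page (update, compute, get_log_dict, reset, full_name) is a hypothetical stand-in, not the actual implementation:

```python
# Hypothetical sketch of the MetricWrapper contract. MeanMetricSketch and
# MetricWrapperSketch are illustrative stand-ins, not xlm.metrics classes.
class MeanMetricSketch:
    """Tiny stand-in for a torchmetrics-style MeanMetric."""
    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value, weight=1):
        self.total += value * weight
        self.count += weight

    def compute(self):
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total, self.count = 0.0, 0


class MetricWrapperSketch:
    """The wrapper contract: update_fn maps (batch, loss_dict) to kwargs
    for the underlying metric; compute/get_log_dict serve manual logging."""
    def __init__(self, metric, update_fn, prefix=""):
        self.metric, self.update_fn, self.prefix = metric, update_fn, prefix

    @property
    def full_name(self):
        # Assumed naming scheme: prefix + a fixed metric name.
        return f"{self.prefix}mean"

    def update(self, batch, loss_dict, tokenizer=None):
        kwargs = self.update_fn(batch, loss_dict, tokenizer)
        self.metric.update(**kwargs)

    def compute(self):
        return self.metric.compute()

    def get_log_dict(self):
        return {self.full_name: self.compute()}

    def reset(self):
        self.metric.reset()


# Fabric-style manual logging: no Trainer, so call get_log_dict yourself.
wrapper = MetricWrapperSketch(
    MeanMetricSketch(),
    update_fn=lambda batch, loss_dict, tok: {"value": loss_dict["loss"]},
    prefix="val/",
)
wrapper.update({"input_ids": [1, 2]}, {"loss": 0.5})
wrapper.update({"input_ids": [3, 4]}, {"loss": 1.5})
print(wrapper.get_log_dict())  # → {'val/mean': 1.0}
```

With Lightning, the same wrapper would instead hand results to `log(pl_module, batch, metrics)` and let the LightningModule handle logging.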
ExactMatch
Bases: MeanMetric
update(pred, target, pred_length=None, target_length=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pred | Integer[Tensor, ' *batch seq_len'] | predicted tokens | required |
| target | Integer[Tensor, ' *batch seq_len'] | target tokens | required |
| pred_length | Optional[Integer[Tensor, ' *batch']] | length of the predicted tokens | None |
| target_length | Optional[Integer[Tensor, ' *batch']] | length of the target tokens | None |
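A pure-Python sketch of the exact-match semantics, assuming the length arguments trim right padding before comparison (the actual metric operates on tensors):

```python
# Hypothetical sketch of ExactMatch semantics: a sample scores 1 only when
# pred equals target over the compared region; lengths (assumed to trim
# right padding) must themselves agree.
def exact_match(pred, target, pred_length=None, target_length=None):
    """pred/target: lists of token-id sequences; returns mean exact match."""
    hits = 0
    for i, (p, t) in enumerate(zip(pred, target)):
        pl = pred_length[i] if pred_length is not None else len(p)
        tl = target_length[i] if target_length is not None else len(t)
        hits += int(pl == tl and p[:pl] == t[:tl])
    return hits / len(pred)

preds   = [[5, 6, 7, 0], [5, 6, 9, 0]]
targets = [[5, 6, 7, 0], [5, 6, 7, 0]]
print(exact_match(preds, targets, pred_length=[3, 3], target_length=[3, 3]))
# → 0.5 (first sequence matches exactly, second differs at position 2)
```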
TokenAccuracy
Bases: MeanMetric
update(pred, target, pred_mask=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pred | Integer[Tensor, ' *batch seq_len'] | predicted tokens | required |
| target | Integer[Tensor, ' *batch seq_len'] | target tokens | required |
| pred_mask | Optional[Integer[Tensor, ' *batch seq_len']] | True for positions that were predicted | None |
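The token-accuracy semantics can be sketched in pure Python (the mask handling shown here is an assumption about how predicted positions are counted, not the tensor implementation):

```python
# Hypothetical sketch of TokenAccuracy semantics: fraction of positions where
# pred equals target, restricted to positions where pred_mask is True.
def token_accuracy(pred, target, pred_mask=None):
    correct = total = 0
    for i, (p, t) in enumerate(zip(pred, target)):
        for j, (pj, tj) in enumerate(zip(p, t)):
            if pred_mask is not None and not pred_mask[i][j]:
                continue  # position was not predicted; skip it
            correct += int(pj == tj)
            total += 1
    return correct / total if total else 0.0

pred   = [[5, 6, 9, 0]]
target = [[5, 6, 7, 0]]
mask   = [[True, True, True, False]]  # exclude the trailing pad position
print(token_accuracy(pred, target, pred_mask=mask))  # → 2/3
```

Without the mask, the trailing pad would count as a correct token and inflate accuracy to 3/4.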
exact_match_update_fn(batch, loss_dict, tokenizer=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Dict[str, Any] | Should contain the key "target_ids": Integer[TT, " *batch target_seq_len"] | required |
| loss_dict | Dict[str, Any] | Should contain the key "ids": Integer[TT, " *batch input_seq_len+target_seq_len"] | required |
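The update_fn contract can be sketched as follows. Per the table above, "ids" holds input tokens followed by target-length tokens; the slicing shown here is an assumption about how the target region is recovered, not the actual implementation:

```python
# Hypothetical sketch of an exact-match update_fn: slice the generated ids
# down to the target region and return ExactMatch-style kwargs for the wrapper.
def exact_match_update_fn_sketch(batch, loss_dict, tokenizer=None):
    target_ids = batch["target_ids"]           # [*batch, target_seq_len]
    target_len = len(target_ids[0])
    ids = loss_dict["ids"]                     # [*batch, input+target seq_len]
    pred = [row[-target_len:] for row in ids]  # keep only the target region
    return {"pred": pred, "target": target_ids}

batch = {"target_ids": [[7, 8, 9]]}
loss_dict = {"ids": [[1, 2, 7, 8, 9]]}
print(exact_match_update_fn_sketch(batch, loss_dict))
# → {'pred': [[7, 8, 9]], 'target': [[7, 8, 9]]}
```

The returned dict is exactly what MetricWrapper passes as kwargs to the underlying metric's update.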
seq2seq_exact_match_update_fn(batch, loss_dict, tokenizer=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Dict[str, Any] | Should contain the keys "target_ids": Integer[TT, " batch target_seq_len"] and "input_ids": Integer[TT, " batch input_seq_len"] | required |
| loss_dict | Dict[str, Any] | Should contain the key "ids": Integer[TT, " *batch input_seq_len+target_seq_len"] | required |
Note: we rely on the prediction and target having the same number of right pads, which may not hold for ARLM.
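The note's caveat can be illustrated with a small example (the pad id and sequences are assumptions): when the prediction carries fewer right pads than the target, a position-wise comparison misaligns even though the unpadded content agrees.

```python
# Illustration of the right-pad caveat: pred and target agree on content
# [7, 8, 9], but an ARLM emitting an extra token before padding breaks the
# position-wise comparison at index 3.
PAD = 0
target = [7, 8, 9, PAD, PAD]  # two right pads
pred   = [7, 8, 9, 5, PAD]    # one right pad: extra token generated
aligned = [p == t for p, t in zip(pred, target)]
print(aligned)  # → [True, True, True, False, True] -- spurious mismatch
```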
seq2seq_token_accuracy_update_fn(batch, loss_dict, tokenizer=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Dict[str, Any] | Should contain the keys "target_ids": Integer[TT, " batch target_seq_len"] and "input_ids": Integer[TT, " batch input_seq_len"] | required |
| loss_dict | Dict[str, Any] | Should contain the key "ids": Integer[TT, " *batch input_seq_len+target_seq_len"] | required |
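For seq2seq token accuracy, a pred_mask can be derived from the non-pad positions of target_ids so that padding does not inflate accuracy. The pad id and the masking rule below are assumptions for illustration, not taken from the implementation:

```python
# Hypothetical mask derivation: mark every non-pad target position as a
# position that should be scored by TokenAccuracy.
PAD_ID = 0

def make_pred_mask(target_ids):
    return [[tok != PAD_ID for tok in row] for row in target_ids]

print(make_pred_mask([[7, 8, 9, 0, 0]]))
# → [[True, True, True, False, False]]
```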