
mdlm.metrics_mdlm

seq2seq_exact_match_update_fn(batch, loss_dict, tokenizer=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Should contain the keys "target_ids": Integer[TT, " batch target_seq_len"] and "input_ids": Integer[TT, " batch input_seq_len"]. | required |
| loss_dict | Dict[str, Any] | Should contain the key "ids": Integer[TT, " *batch input_seq_len+target_seq_len"]. | required |
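The shapes above suggest that "ids" holds the input tokens followed by the predicted target tokens. The following is a minimal sketch of an exact-match update under that assumption. Plain nested lists stand in for the TT tensors. The function name exact_match_sketch and the returned key "exact_match" are hypothetical and are not this module's API.

```python
from typing import Any, Dict, List


def exact_match_sketch(batch: Dict[str, Any], loss_dict: Dict[str, Any]) -> Dict[str, List[float]]:
    """Hypothetical per-sequence exact match; nested lists stand in for tensors."""
    input_len = len(batch["input_ids"][0])  # input_seq_len (rectangular batch assumed)
    scores = []
    for ids, target in zip(loss_dict["ids"], batch["target_ids"]):
        # Assumption: the predicted target is everything after the input prompt.
        predicted = ids[input_len:]
        # 1.0 only if every target token matches, else 0.0.
        scores.append(1.0 if list(predicted) == list(target) else 0.0)
    return {"exact_match": scores}


batch = {"input_ids": [[5, 6], [5, 7]], "target_ids": [[1, 2, 3], [1, 2, 3]]}
loss_dict = {"ids": [[5, 6, 1, 2, 3], [5, 7, 1, 9, 3]]}
print(exact_match_sketch(batch, loss_dict))  # {'exact_match': [1.0, 0.0]}
```

A single wrong token (9 instead of 2 in the second row) makes the whole sequence count as a miss. This all-or-nothing behavior is what distinguishes exact match from token-level accuracy.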

seq2seq_token_accuracy_update_fn(batch, loss_dict, tokenizer=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Should contain the keys "target_ids": Integer[TT, " batch target_seq_len"] and "input_ids": Integer[TT, " batch input_seq_len"]. | required |
| loss_dict | Dict[str, Any] | Should contain the key "ids": Integer[TT, " *batch input_seq_len+target_seq_len"]. | required |
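The following is a minimal sketch of a token-accuracy update. It assumes, as above, that the predicted target is the suffix of "ids" after input_seq_len and that no padding positions need masking. The name token_accuracy_sketch and its returned keys are hypothetical, and plain lists stand in for tensors.

```python
from typing import Any, Dict


def token_accuracy_sketch(batch: Dict[str, Any], loss_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical token-level accuracy; nested lists stand in for tensors."""
    input_len = len(batch["input_ids"][0])  # input_seq_len
    correct = 0
    total = 0
    for ids, target in zip(loss_dict["ids"], batch["target_ids"]):
        predicted = ids[input_len:]  # assumption: suffix after the prompt
        # No padding mask: every target position counts (assumption).
        correct += sum(int(p == t) for p, t in zip(predicted, target))
        total += len(target)
    accuracy = correct / total if total else 0.0
    return {"correct": correct, "total": total, "token_accuracy": accuracy}


batch = {"input_ids": [[5, 6], [5, 7]], "target_ids": [[1, 2, 3], [1, 2, 3]]}
loss_dict = {"ids": [[5, 6, 1, 2, 3], [5, 7, 1, 9, 3]]}
result = token_accuracy_sketch(batch, loss_dict)
print(result["correct"], result["total"])  # 5 6
```

Returning the raw counts alongside the ratio lets results from several batches be summed exactly. Averaging per-batch ratios would over-weight small batches.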

mean_metric_update_fn(batch, loss_dict, tokenizer=None)

Update function for the mean loss metric.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Input batch. | required |
| loss_dict | Dict[str, Any] | Loss dictionary containing the loss (batch_loss is not used for ARLM). | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Any] | Dictionary with the mean loss value. |
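The following is a minimal sketch of how an update function of this shape could feed a running mean across batches. It assumes the scalar loss is stored under the key "loss" and that the returned dictionary uses the key "value". Both keys, the function name mean_loss_update_sketch, and the RunningMean class are hypothetical. The sketch also weights every batch equally.

```python
from typing import Any, Dict


def mean_loss_update_sketch(batch: Dict[str, Any], loss_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Assumption: the scalar loss lives under loss_dict["loss"].
    return {"value": float(loss_dict["loss"])}


class RunningMean:
    """Hypothetical accumulator: unweighted mean of the values it receives."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count if self.count else float("nan")


metric = RunningMean()
for loss in (2.0, 1.0, 3.0):
    metric.update(mean_loss_update_sketch({}, {"loss": loss})["value"])
print(metric.compute())  # 2.0
```

If batches differ in size, a weighted variant would pass the batch size alongside each value and divide by the summed weights instead of the batch count.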