arlm.metrics_arlm

seq2seq_exact_match_update_fn(batch, loss_dict, tokenizer=None)

Parameters:

Name Type Description Default
batch Dict[str, Any]

Should contain the following keys:
- "target_ids": Integer[TT, " batch target_seq_len"]
- "input_ids": Integer[TT, " batch input_seq_len"]

required
loss_dict Dict[str, Any]

Should contain the following keys:
- "ids": Integer[TT, " *batch input_seq_len+target_seq_len"]

required

Note: This relies on the target and the prediction having the same number of right-padding tokens, which may not hold for ARLM.
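
The position-wise comparison that the note describes can be sketched as follows. This is a minimal illustration with plain Python lists rather than tensors, not the library's implementation: the helper name `exact_match_sketch`, the pad id `PAD = 0`, and the assumption that the prediction occupies the trailing `target_seq_len` positions of `ids` are all hypothetical.

```python
from typing import Any, Dict, List

PAD = 0  # assumed pad token id, for this example only


def exact_match_sketch(batch: Dict[str, Any], loss_dict: Dict[str, Any]) -> List[bool]:
    """Hypothetical illustration: compare predicted target tokens to target_ids.

    Assumes each row of loss_dict["ids"] is the input followed by the
    prediction, so the last target_seq_len positions are the prediction.
    """
    target_len = len(batch["target_ids"][0])
    matches = []
    for ids_row, target_row in zip(loss_dict["ids"], batch["target_ids"]):
        pred_row = ids_row[-target_len:]  # trailing slice = predicted target
        # Whole-row equality: pad positions are compared too, so the
        # prediction and the target must carry the same right padding.
        matches.append(pred_row == target_row)
    return matches


batch = {
    "input_ids": [[5, 6], [7, 8]],
    "target_ids": [[1, 2, PAD], [3, 4, PAD]],
}
loss_dict = {
    "ids": [
        [5, 6, 1, 2, PAD],  # same tokens and same right padding as the target
        [7, 8, 3, 4, 4],    # a token where the target has a pad
    ]
}
result = exact_match_sketch(batch, loss_dict)
```

In this sketch the second row counts as a mismatch only because its final position is not a pad, which is the sensitivity the note warns about.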

seq2seq_token_accuracy_update_fn(batch, loss_dict, tokenizer=None)

Parameters:

Name Type Description Default
batch Dict[str, Any]

Should contain the following keys:
- "target_ids": Integer[TT, " batch target_seq_len"]
- "input_ids": Integer[TT, " batch input_seq_len"]

required
loss_dict Dict[str, Any]

Should contain the following keys:
- "ids": Integer[TT, " *batch input_seq_len+target_seq_len"]

required
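
As a rough illustration of what a token-level accuracy over these keys might compute, the sketch below counts matching positions between the prediction and `target_ids`. The helper name `token_accuracy_sketch`, the `pad_id` parameter, and the choice to skip pad positions in the target are assumptions made for the example; the source does not state how padding is handled here.

```python
from typing import Any, Dict, Tuple


def token_accuracy_sketch(
    batch: Dict[str, Any], loss_dict: Dict[str, Any], pad_id: int = 0
) -> Tuple[int, int]:
    """Hypothetical: return (correct, total) over non-pad target positions."""
    target_len = len(batch["target_ids"][0])
    correct = 0
    total = 0
    for ids_row, target_row in zip(loss_dict["ids"], batch["target_ids"]):
        pred_row = ids_row[-target_len:]  # assumed: prediction is the tail of ids
        for pred_tok, target_tok in zip(pred_row, target_row):
            if target_tok == pad_id:
                continue  # assumed: padded target positions are not scored
            total += 1
            correct += int(pred_tok == target_tok)
    return correct, total


batch = {"input_ids": [[9, 9], [9, 9]], "target_ids": [[1, 2, 0], [3, 4, 5]]}
loss_dict = {"ids": [[9, 9, 1, 2, 0], [9, 9, 3, 7, 5]]}
correct, total = token_accuracy_sketch(batch, loss_dict)
accuracy = correct / total
```

Returning the two counts separately, rather than a ratio, lets the counts be summed across batches before dividing.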

mean_metric_update_fn(batch, loss_dict, tokenizer=None)

Update function for the mean loss metric.

Parameters:

Name Type Description Default
batch Dict[str, Any]

Input batch.

required
loss_dict Dict[str, Any]

Loss dictionary containing the loss (batch_loss is not used for ARLM).

required

Returns:

Type Description
Dict[str, Any]

Dictionary with the mean loss value.

perplexity_metric_update_fn(batch, loss_dict, tokenizer=None)

Update function for the perplexity metric.

Parameters:

Name Type Description Default
batch Dict[str, Any]

Input batch.

required
loss_dict Dict[str, Any]

Loss dictionary containing the negative log-likelihoods (nlls).

required

Returns:

Type Description
Dict[str, Any]

Dictionary with the perplexity value.
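
For reference, perplexity is conventionally the exponential of the mean per-token negative log-likelihood. The sketch below shows that relationship on a flat list of values. The function name `perplexity_sketch` is hypothetical, and it assumes the NLLs are in natural-log units; the source does not specify the key name or shape under which `loss_dict` stores them.

```python
import math
from typing import Sequence


def perplexity_sketch(nlls: Sequence[float]) -> float:
    """Hypothetical: exp of the mean per-token NLL (natural log assumed)."""
    if not nlls:
        raise ValueError("need at least one token NLL")
    return math.exp(sum(nlls) / len(nlls))


# A model that assigns probability 1/4 to every token has NLL log(4) per token.
uniform_over_4 = [math.log(4.0)] * 3
ppl = perplexity_sketch(uniform_over_4)
```

With a uniform distribution over four choices, the result is approximately 4, matching the usual reading of perplexity as an effective branching factor.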

token_nll_metric_update_fn(batch, loss_dict, tokenizer=None)

Update function for the token-level negative log-likelihood metric.

Parameters:

Name Type Description Default
batch Dict[str, Any]

Input batch.

required
loss_dict Dict[str, Any]

Loss dictionary containing the negative log-likelihoods (nlls).

required

Returns:

Type Description
Dict[str, Any]

Dictionary with token-level NLL values.

sequence_length_metric_update_fn(batch, loss_dict, tokenizer=None)

Update function for the sequence length metric.

Parameters:

Name Type Description Default
batch Dict[str, Any]

Input batch.

required
loss_dict Dict[str, Any]

Loss dictionary.

required

Returns:

Type Description
Dict[str, Any]

Dictionary with sequence length values.

valid_tokens_metric_update_fn(batch, loss_dict, tokenizer=None)

Update function for the valid-token count metric.

Parameters:

Name Type Description Default
batch Dict[str, Any]

Input batch.

required
loss_dict Dict[str, Any]

Loss dictionary.

required

Returns:

Type Description
Dict[str, Any]

Dictionary with the valid-token count.
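
One plausible reading of "valid tokens" is the number of non-padding positions, sketched below together with per-sequence lengths. This is an assumption for illustration only: the source does not say which field or mask the function reads, and `count_valid_tokens_sketch` and `pad_id` are hypothetical names.

```python
from typing import List, Tuple


def count_valid_tokens_sketch(
    ids: List[List[int]], pad_id: int = 0
) -> Tuple[List[int], int]:
    """Hypothetical: per-row non-pad lengths and their total."""
    lengths = [sum(1 for tok in row if tok != pad_id) for row in ids]
    return lengths, sum(lengths)


lengths, n_valid = count_valid_tokens_sketch([[5, 6, 1, 2, 0], [7, 8, 3, 0, 0]])
```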