arlm.metrics_arlm
seq2seq_exact_match_update_fn(batch, loss_dict, tokenizer=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Should contain the following keys: `"target_ids"`: `Integer[TT, " batch target_seq_len"]`; `"input_ids"`: `Integer[TT, " batch input_seq_len"]` | required |
| `loss_dict` | `Dict[str, Any]` | Should contain the following keys: `"ids"`: `Integer[TT, " *batch input_seq_len+target_seq_len"]` | required |
Note: This function relies on the target and the prediction having the same number of right-padding tokens, which may not hold for ARLM.
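
The padding caveat can be illustrated with a minimal pure-Python sketch. This is not the library's implementation: the helpers `exact_match` and `strip_right_pads`, the pad id `0`, and the use of plain lists in place of tensors are all assumptions made for illustration. The sketch compares the target-length tail of each predicted sequence with the target, so a prediction that carries a different number of right pads is scored as a mismatch even when its non-pad tokens agree.

```python
from typing import List

PAD_ID = 0  # hypothetical pad token id, chosen only for this example


def exact_match(pred_ids: List[List[int]], target_ids: List[List[int]]) -> List[bool]:
    """Compare the last target_seq_len positions of each prediction with its target."""
    matches = []
    for pred, target in zip(pred_ids, target_ids):
        # Assumption: "ids" holds input followed by target, so keep the target-length tail.
        tail = pred[-len(target):]
        matches.append(tail == target)
    return matches


def strip_right_pads(seq: List[int]) -> List[int]:
    """Drop trailing pad tokens so sequences can be compared independently of padding."""
    end = len(seq)
    while end > 0 and seq[end - 1] == PAD_ID:
        end -= 1
    return seq[:end]


target = [[5, 6, 7, PAD_ID]]
aligned = [[1, 2, 5, 6, 7, PAD_ID]]       # same number of right pads as the target
shifted = [[1, 5, 6, 7, PAD_ID, PAD_ID]]  # one extra right pad shifts the tail

aligned_result = exact_match(aligned, target)  # tail lines up with the target
shifted_result = exact_match(shifted, target)  # tail is offset by one position

# Padding-insensitive comparison of the same "shifted" prediction.
stripped_target = strip_right_pads(target[0])
stripped_match = strip_right_pads(shifted[0])[-len(stripped_target):] == stripped_target
```

Here `aligned_result` is `[True]` and `shifted_result` is `[False]`, while `stripped_match` is `True`: stripping pads before comparing is one possible way to avoid the dependency on pad counts.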
seq2seq_token_accuracy_update_fn(batch, loss_dict, tokenizer=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Should contain the following keys: `"target_ids"`: `Integer[TT, " batch target_seq_len"]`; `"input_ids"`: `Integer[TT, " batch input_seq_len"]` | required |
| `loss_dict` | `Dict[str, Any]` | Should contain the following keys: `"ids"`: `Integer[TT, " *batch input_seq_len+target_seq_len"]` | required |
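
As a rough illustration of what a token-level accuracy over these inputs could look like, the following pure-Python sketch counts matching tokens in the target-length tail of each prediction. The helper `token_accuracy_counts`, the pad id `0`, and the decision to skip pad positions in the target are assumptions for illustration, not the library's documented behavior.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical pad token id, chosen only for this example


def token_accuracy_counts(
    pred_ids: List[List[int]], target_ids: List[List[int]]
) -> Tuple[int, int]:
    """Return (correct, total) over non-pad target positions."""
    correct = 0
    total = 0
    for pred, target in zip(pred_ids, target_ids):
        # Assumption: predictions hold input followed by target tokens.
        tail = pred[-len(target):]
        for p, t in zip(tail, target):
            if t == PAD_ID:
                continue  # skip padded target positions
            total += 1
            correct += int(p == t)
    return correct, total


correct, total = token_accuracy_counts(
    pred_ids=[[1, 2, 5, 9, 7, PAD_ID]],
    target_ids=[[5, 6, 7, PAD_ID]],
)
accuracy = correct / total  # 2 of 3 non-pad target tokens match
```

Returning the counts rather than the ratio keeps the result easy to accumulate across batches before dividing.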
mean_metric_update_fn(batch, loss_dict, tokenizer=None)
Update function for mean loss metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch. | required |
| `loss_dict` | `Dict[str, Any]` | Loss dictionary containing `loss` (since we don't use `batch_loss` for ARLM). | required |
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with mean loss value. |
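
All update functions on this page share the signature `(batch, loss_dict, tokenizer=None)`. The sketch below shows a hypothetical function with that shape. The name `example_mean_update_fn` and the returned key `"value"` are assumptions for illustration; only the `"loss"` key in `loss_dict` comes from the parameter description above.

```python
from typing import Any, Dict, Optional


def example_mean_update_fn(
    batch: Dict[str, Any],
    loss_dict: Dict[str, Any],
    tokenizer: Optional[Any] = None,
) -> Dict[str, Any]:
    """Hypothetical update function: forward the scalar loss to a mean metric."""
    # The batch and tokenizer are unused here; they are kept so the signature
    # matches the other update functions.
    return {"value": loss_dict["loss"]}  # "value" is an assumed key name


update_kwargs = example_mean_update_fn(batch={}, loss_dict={"loss": 1.25})
```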
perplexity_metric_update_fn(batch, loss_dict, tokenizer=None)
Update function for perplexity metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch. | required |
| `loss_dict` | `Dict[str, Any]` | Loss dictionary containing `nlls`. | required |
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with perplexity value. |
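
For reference, perplexity is conventionally the exponential of the mean per-token negative log-likelihood. The pure-Python sketch below applies that definition to per-token NLLs with a validity mask. The helper `perplexity_from_nlls` and the explicit `mask` argument are assumptions for illustration; how this module aggregates `nlls` (per token or per sequence) is not stated here.

```python
import math
from typing import List


def perplexity_from_nlls(nlls: List[List[float]], mask: List[List[bool]]) -> float:
    """exp(mean NLL) over positions where mask is True."""
    total_nll = 0.0
    n_tokens = 0
    for row_nll, row_mask in zip(nlls, mask):
        for nll, keep in zip(row_nll, row_mask):
            if keep:  # ignore padded or otherwise invalid positions
                total_nll += nll
                n_tokens += 1
    return math.exp(total_nll / n_tokens)


# A model that assigns probability 1/4 to every valid token has perplexity 4.
uniform_nll = math.log(4.0)
ppl = perplexity_from_nlls(
    nlls=[[uniform_nll, uniform_nll, 0.0]],
    mask=[[True, True, False]],
)
```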
token_nll_metric_update_fn(batch, loss_dict, tokenizer=None)
Update function for token-level negative log likelihood metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch. | required |
| `loss_dict` | `Dict[str, Any]` | Loss dictionary containing `nlls`. | required |
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with token-level NLL values. |
sequence_length_metric_update_fn(batch, loss_dict, tokenizer=None)
Update function for sequence length metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch. | required |
| `loss_dict` | `Dict[str, Any]` | Loss dictionary. | required |
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with sequence length values. |
valid_tokens_metric_update_fn(batch, loss_dict, tokenizer=None)
Update function for valid tokens count metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch. | required |
| `loss_dict` | `Dict[str, Any]` | Loss dictionary. | required |
Returns:
| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with valid tokens count. |
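
The sequence length and valid tokens count metrics both reduce to counting non-padded positions. A minimal pure-Python sketch of that idea follows. The 0/1 `attention_mask` input and the helper names `sequence_lengths` and `count_valid_tokens` are assumptions for illustration; which key of `batch` or `loss_dict` the library actually reads is not documented here.

```python
from typing import List


def sequence_lengths(attention_mask: List[List[int]]) -> List[int]:
    """Number of valid (mask == 1) positions in each sequence."""
    return [sum(row) for row in attention_mask]


def count_valid_tokens(attention_mask: List[List[int]]) -> int:
    """Total number of valid positions across the batch."""
    return sum(sequence_lengths(attention_mask))


mask = [
    [1, 1, 1, 0],  # three real tokens, one pad
    [1, 1, 0, 0],  # two real tokens, two pads
]
lengths = sequence_lengths(mask)
n_valid = count_valid_tokens(mask)
```

With this mask, `lengths` is `[3, 2]` and `n_valid` is `5`.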