mlm.metrics_mlm
exact_match_update_fn(batch, loss_dict, tokenizer=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Should contain the key `"target_ids"`: `Integer[TT, " *batch target_seq_len"]`. | required |
| `loss_dict` | `Dict[str, Any]` | Should contain the key `"ids"`: `Integer[TT, " *batch input_seq_len+target_seq_len"]`. | required |
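The following is a minimal sketch of the sequence-level comparison that the parameters above imply. It is not the module's implementation. It assumes that `TT` is `torch.Tensor` and that the predicted target occupies the last `target_seq_len` positions of `loss_dict["ids"]`. The function name `exact_match_sketch` and the returned keys `matches`/`total` are hypothetical. Because `batch` supplies only `target_ids`, the target segment is located by counting back from the end of `ids`.

```python
import torch


def exact_match_sketch(batch: dict, loss_dict: dict) -> dict:
    """Hypothetical sketch: per-sequence exact match of predicted vs. target ids."""
    target = batch["target_ids"]  # (*batch, target_seq_len)
    ids = loss_dict["ids"]  # (*batch, input_seq_len + target_seq_len)
    target_len = target.shape[-1]
    # Assumption: the predicted target is the final target_len positions of ids.
    pred = ids[..., -target_len:]
    # A sequence counts as a match only if every target position agrees.
    matches = (pred == target).all(dim=-1)
    # Hypothetical return keys: raw counts, so they can be summed across batches.
    return {"matches": matches.sum().item(), "total": matches.numel()}


batch = {"target_ids": torch.tensor([[5, 6], [7, 8]])}
loss_dict = {"ids": torch.tensor([[1, 2, 5, 6], [3, 4, 7, 9]])}
print(exact_match_sketch(batch, loss_dict))  # {'matches': 1, 'total': 2}
```

Returning counts rather than a ratio lets a caller combine batches of different sizes without bias.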
infill_token_accuracy_update_fn(batch, loss_dict, tokenizer=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Should contain the keys `"target_ids"`: `Integer[TT, " batch target_seq_len"]` and `"input_ids"`: `Integer[TT, " batch input_seq_len"]`. | required |
| `loss_dict` | `Dict[str, Any]` | Should contain the key `"ids"`: `Integer[TT, " *batch input_seq_len+target_seq_len"]`. | required |
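The following is a minimal sketch of token-level accuracy on the target segment. This reference does not document how the real function selects infill positions, so scoring all non-padding target positions is an assumption. `PAD_ID`, `token_accuracy_sketch`, and the `correct`/`total` keys are all hypothetical. In practice the pad id might come from the `tokenizer` argument. The sketch splits `ids` at `input_ids.shape[-1]` and assumes `ids` is the input segment followed by the predicted target.

```python
import torch

PAD_ID = 0  # hypothetical padding id; a real value would come from the tokenizer


def token_accuracy_sketch(batch: dict, loss_dict: dict, pad_id: int = PAD_ID) -> dict:
    """Hypothetical sketch: count non-pad target tokens predicted correctly."""
    input_len = batch["input_ids"].shape[-1]
    target = batch["target_ids"]  # (batch, target_seq_len)
    # Assumption: ids = [input segment | predicted target segment] on the last dim.
    pred = loss_dict["ids"][..., input_len:]
    valid = target != pad_id  # score only real (non-pad) target positions
    correct = ((pred == target) & valid).sum().item()
    total = valid.sum().item()
    return {"correct": correct, "total": total}


batch = {
    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6]]),
    "target_ids": torch.tensor([[7, 8, 0], [9, 10, 11]]),
}
loss_dict = {"ids": torch.tensor([[1, 2, 3, 7, 2, 5], [4, 5, 6, 9, 10, 11]])}
print(token_accuracy_sketch(batch, loss_dict))  # {'correct': 4, 'total': 5}
```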
seq2seq_exact_match_update_fn(batch, loss_dict, tokenizer=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Should contain the keys `"target_ids"`: `Integer[TT, " batch target_seq_len"]` and `"input_ids"`: `Integer[TT, " batch input_seq_len"]`. | required |
| `loss_dict` | `Dict[str, Any]` | Should contain the key `"ids"`: `Integer[TT, " *batch input_seq_len+target_seq_len"]`. | required |
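The following is a minimal sketch of a padding-aware exact match for the seq2seq case. The target segment is located using the length of `input_ids`, and a sequence counts as matched when every non-padding target token agrees. It assumes `ids` is the input segment followed by the predicted target. It also assumes a hypothetical `PAD_ID`. `seq2seq_exact_match_sketch` and its return keys are illustrative names, not the module's API.

```python
import torch

PAD_ID = 0  # hypothetical padding id


def seq2seq_exact_match_sketch(batch: dict, loss_dict: dict, pad_id: int = PAD_ID) -> dict:
    """Hypothetical sketch: match if every non-pad target token agrees."""
    input_len = batch["input_ids"].shape[-1]
    target = batch["target_ids"]  # (batch, target_seq_len)
    # Assumption: ids = [input segment | predicted target segment] on the last dim.
    pred = loss_dict["ids"][..., input_len:]
    # Padded target positions are treated as automatically correct.
    ok = (pred == target) | (target == pad_id)
    matches = ok.all(dim=-1)
    return {"matches": matches.sum().item(), "total": matches.numel()}


batch = {
    "input_ids": torch.tensor([[1, 2], [3, 4]]),
    "target_ids": torch.tensor([[5, 0], [6, 7]]),
}
loss_dict = {"ids": torch.tensor([[1, 2, 5, 9], [3, 4, 6, 8]])}
print(seq2seq_exact_match_sketch(batch, loss_dict))  # {'matches': 1, 'total': 2}
```

In the first row, the prediction after the padded position (`9`) does not break the match. This reflects the design choice of scoring only the real target tokens.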
seq2seq_token_accuracy_update_fn(batch, loss_dict, tokenizer=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Should contain the keys `"target_ids"`: `Integer[TT, " batch target_seq_len"]` and `"input_ids"`: `Integer[TT, " batch input_seq_len"]`. | required |
| `loss_dict` | `Dict[str, Any]` | Should contain the key `"ids"`: `Integer[TT, " *batch input_seq_len+target_seq_len"]`. | required |
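Token accuracy across an epoch is usually computed from summed counts rather than by averaging per-batch accuracies. Otherwise, small batches weigh as much as large ones. The following is a minimal sketch of such an accumulator. It assumes each update yields hypothetical `correct`/`total` token counts, which may not match what `seq2seq_token_accuracy_update_fn` actually returns. `TokenAccuracyAccumulator` is a hypothetical name.

```python
from dataclasses import dataclass


@dataclass
class TokenAccuracyAccumulator:
    """Hypothetical accumulator for per-batch (correct, total) token counts."""

    correct: int = 0
    total: int = 0

    def update(self, counts: dict) -> None:
        # Sum raw counts so every token carries equal weight across batches.
        self.correct += counts["correct"]
        self.total += counts["total"]

    def compute(self) -> float:
        # Guard against division by zero before any tokens have been seen.
        return self.correct / self.total if self.total else 0.0


acc = TokenAccuracyAccumulator()
acc.update({"correct": 9, "total": 10})  # batch 1: 90% of 10 tokens
acc.update({"correct": 10, "total": 100})  # batch 2: 10% of 100 tokens
print(acc.compute())  # 19 / 110, about 0.173 (a mean of batch accuracies would give 0.5)
```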
mean_metric_update_fn(batch, loss_dict, tokenizer=None)
Update function for mean loss metric.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch. | required |
| `loss_dict` | `Dict[str, Any]` | Loss dictionary containing the loss. The loss is read from here because `batch_loss` is not used for ARLM. | required |
Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary with the mean loss value. |
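The following is a minimal sketch of an update function with this signature. It assumes the loss is stored under a hypothetical `"loss"` key in `loss_dict` and that the returned dictionary uses a hypothetical `"value"` key. Neither key name is confirmed by this reference. `mean_metric_update_sketch` is an illustrative name. Detaching the loss keeps the downstream mean metric from holding a reference to the autograd graph.

```python
import torch


def mean_metric_update_sketch(batch: dict, loss_dict: dict, tokenizer=None) -> dict:
    """Hypothetical sketch: forward the scalar loss to a mean metric."""
    # Assumption: the loss lives under "loss"; batch and tokenizer are unused here.
    return {"value": loss_dict["loss"].detach()}


out = mean_metric_update_sketch({}, {"loss": torch.tensor(2.5, requires_grad=True)})
print(out["value"].item())  # 2.5
```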