mlm.datamodule_mlm
DefaultMLMCollator
Bases: Collator
Used for MLM pre-training with padded-truncated sequences.
Batch
- input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
- attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
- target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as-is and masks are replaced with the correct token.
Padding
- Padding is done on the right.
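As a rough illustration of the batch layout above, here is a minimal collator sketch using plain Python lists (the 15% masking rate, the helper name, and the absence of tensor conversion are assumptions for illustration, not details from this module):

```python
import random
from typing import Dict, List

def mlm_collate_sketch(
    sequences: List[List[int]],
    pad_token_id: int,
    mask_token_id: int,
    max_seq_len: int,
    mask_prob: float = 0.15,  # illustrative masking rate, not from this module
) -> Dict[str, List[List[int]]]:
    """Right-pad/truncate sequences and randomly mask tokens for MLM."""
    input_ids, attention_mask, target_ids = [], [], []
    for seq in sequences:
        seq = seq[:max_seq_len]                       # truncate on the right
        pad = max_seq_len - len(seq)
        masked = [
            mask_token_id if random.random() < mask_prob else tok
            for tok in seq
        ]
        input_ids.append(masked + [pad_token_id] * pad)    # masked input, right-padded
        attention_mask.append([1] * len(seq) + [0] * pad)  # 1 on non-padding tokens
        target_ids.append(seq + [pad_token_id] * pad)      # original tokens at every position
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "target_ids": target_ids,
    }
```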
MLMSeq2SeqTrainCollator
Bases: Collator
MLM training for seq2seq tasks.
Batch
- input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
- attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
- target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as-is and masks are replaced with the correct token.
Padding
- Padding is done on the right.
MLMSeq2SeqCollator
Bases: Collator
MLM training for seq2seq tasks.
Batch
- input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
- attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
- target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as-is and masks are replaced with the correct token.
Padding
- Padding is applied on both sides so that all prefixes end at the same position. TODO (efficiency): this is not ideal for seq2seq training, since many tokens are wasted on padding; for training we should pad on one side only.
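The two-sided layout can be sketched for a single example as follows (`pad_both_sides` is a hypothetical helper, not part of this module):

```python
from typing import List, Tuple

def pad_both_sides(
    prefix: List[int],
    suffix: List[int],
    pad_token_id: int,
    prefix_width: int,
    suffix_width: int,
) -> Tuple[List[int], List[int]]:
    """Left-pad the prefix and right-pad the suffix so every prefix ends at the same position."""
    left = [pad_token_id] * (prefix_width - len(prefix))
    right = [pad_token_id] * (suffix_width - len(suffix))
    ids = left + prefix + suffix + right
    mask = [0] * len(left) + [1] * (len(prefix) + len(suffix)) + [0] * len(right)
    return ids, mask
```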
MLMSeq2SeqPredCollator
Bases: MLMSeq2SeqCollator
The input contains only the prefix, and target_ids contains only the suffix, if present.
MLMInfillWithExactTargetPredCollator
Bases: DefaultMLMCollator
Identical to DefaultMLMCollator but expects the prompt_ids to already contain masks.
prepare_prefix_ids(prefix_ids, pad_token_id, max_seq_len=None, truncate='block')
Prepare prefix ids for seq2seq tasks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `prefix_ids` | `List[List[int]]` | The prefix token ids, one list per example. | *required* |
| `pad_token_id` | `int` | Id of the padding token. | *required* |
| `max_seq_len` | `Optional[int]` | Maximum sequence length. | `None` |
| `truncate` | `Literal['max', 'block', None]` | Truncation strategy. | `'block'` |
Note: If truncated, prefixes are truncated from the left.

Returns:
Dict[str, TT]:
- input_ids: Integer[TT, " batch seq_len"]
- attention_mask: Integer[TT, " batch seq_len"]
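A rough sketch of the behavior described above, using plain Python lists (left-side padding is an assumption inferred from the MLMSeq2SeqCollator note that all prefixes end at the same position, and the `'block'`/`'max'` truncation modes are collapsed into a simple length cap):

```python
from typing import Dict, List, Optional

def prepare_prefix_ids_sketch(
    prefix_ids: List[List[int]],
    pad_token_id: int,
    max_seq_len: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """Left-truncate prefixes and left-pad so every prefix ends at the same position."""
    if max_seq_len is not None:
        # Truncate from the left: keep the rightmost max_seq_len tokens.
        prefix_ids = [p[-max_seq_len:] for p in prefix_ids]
    width = max(len(p) for p in prefix_ids)
    input_ids = [[pad_token_id] * (width - len(p)) + p for p in prefix_ids]
    attention_mask = [[0] * (width - len(p)) + [1] * len(p) for p in prefix_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```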
prepare_prefix_suffix_ids(prefix_ids, suffix_ids, pad_token_id, mask_token_id, eos_token_id=None, bos_token_id=None, max_seq_len=None, truncate='block', loss_on_padding=True)
Prepare concatenated prefix and suffix ids for seq2seq tasks, with padding on the right only.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss_on_padding` | `bool` | If `True`, the pad token is treated as a normal token: it receives attention and is predicted as a target token. If `False`, it receives no attention and is excluded from the loss (target set to `-100`). | `True` |
print_batch_mlm(batch, split, tokenizer, dataloader_name='')
Print batch information for debugging MLM batches.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | The batch to print. | *required* |
| `split` | `Literal['train', 'val', 'test', 'predict']` | The split name. | *required* |
| `tokenizer` | `Tokenizer` | The tokenizer to decode tokens. | *required* |
| `dataloader_name` | `str` | Name of the dataloader. | `''` |
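A hypothetical sketch of what such a debug helper does (`format_batch_for_debug` and the `decode` callable stand in for the real function and `tokenizer.decode`; the real function prints rather than returning lines):

```python
from typing import Any, Callable, Dict, List

def format_batch_for_debug(
    batch: Dict[str, Any],
    split: str,
    decode: Callable[[List[int]], str],
    dataloader_name: str = "",
) -> List[str]:
    """Decode each example, skipping padding, and collect one line per example."""
    lines = [f"[{split}] {dataloader_name}".rstrip()]
    for ids, mask in zip(batch["input_ids"], batch["attention_mask"]):
        real = [t for t, m in zip(ids, mask) if m == 1]  # drop padding positions
        lines.append(decode(real))
    return lines
```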