mlm.datamodule_mlm

DefaultMLMCollator

Bases: Collator

Used for MLM pre-training with padded-truncated sequences.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as-is and masks are replaced with the correct tokens.
Padding
  • Padding is done on the right.
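The batch layout above can be sketched in plain Python (a minimal illustration, not the actual collator: the token ids, the masking strategy, and the `MASK_ID`/`PAD_ID` constants are all assumptions made for the example):

```python
# Hypothetical special-token ids for illustration only.
MASK_ID = 103
PAD_ID = 0

def collate_mlm(originals, mask_positions, max_len):
    """Right-pad each sequence and build the three batch fields as lists:
    input_ids carries the masks, attention_mask is 1 on non-padding tokens,
    and target_ids copies the input with masks replaced by the correct tokens."""
    batch = {"input_ids": [], "attention_mask": [], "target_ids": []}
    for seq, masked_at in zip(originals, mask_positions):
        inp = [MASK_ID if i in masked_at else t for i, t in enumerate(seq)]
        pad = max_len - len(seq)
        batch["input_ids"].append(inp + [PAD_ID] * pad)
        batch["attention_mask"].append([1] * len(seq) + [0] * pad)
        batch["target_ids"].append(list(seq) + [PAD_ID] * pad)
    return batch
```

For a single sequence [5, 6, 7] with position 1 masked and max_len=4, this yields input_ids [[5, 103, 7, 0]], attention_mask [[1, 1, 1, 0]], and target_ids [[5, 6, 7, 0]].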

MLMSeq2SeqTrainCollator

Bases: Collator

MLM training for seq2seq tasks.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as-is and masks are replaced with the correct tokens.
Padding
  • Padding is done on the right.

MLMSeq2SeqCollator

Bases: Collator

MLM training for seq2seq tasks.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as-is and masks are replaced with the correct tokens.
Padding
  • Padding is applied on both sides so that all prefixes end at the same position. TODO (efficiency): this is not ideal for seq2seq training, since many tokens are wasted on padding; for training we should pad on one side only.
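The both-sided padding can be sketched as follows (a hypothetical helper, not part of the module): prefixes are padded on the left so they all end at the same position, and suffixes are padded on the right.

```python
def pad_prefixes_aligned(prefixes, suffixes, pad_id):
    """Left-pad each prefix to the longest prefix length (so all prefixes end
    at the same position), then right-pad the suffix side to a common length."""
    pre_len = max(len(p) for p in prefixes)
    suf_len = max(len(s) for s in suffixes)
    return [
        [pad_id] * (pre_len - len(p)) + p + s + [pad_id] * (suf_len - len(s))
        for p, s in zip(prefixes, suffixes)
    ]
```

For example, pad_prefixes_aligned([[1, 2], [3]], [[4], [5, 6]], 0) yields [[1, 2, 4, 0], [0, 3, 5, 6]]: pad tokens appear on both sides, and both prefixes end at position 1.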

MLMSeq2SeqPredCollator

Bases: MLMSeq2SeqCollator

The input contains only the prefix; target_ids contains only the suffix, if present.

MLMInfillWithExactTargetPredCollator

Bases: DefaultMLMCollator

Identical to DefaultMLMCollator but expects the prompt_ids to already contain masks.

prepare_prefix_ids(prefix_ids, pad_token_id, max_seq_len=None, truncate='block')

Prepare prefix ids for seq2seq tasks.

Parameters:

  • prefix_ids (List[List[int]], required)
  • pad_token_id (int, required)
  • max_seq_len (Optional[int], default None)
  • truncate (Literal['max', 'block', None], default 'block'):
      • "max": Truncate to max(max_seq_len, max_in_batch); when max_seq_len is not provided, this is the max in the batch.
      • "block": Pad-truncate to max_seq_len.
      • None: Pad to the max in the batch.

Note: Prefixes, if truncated, are truncated from the left.

Returns:
  Dict[str, TT]:
    • input_ids: Integer[TT, " batch seq_len"]
    • attention_mask: Integer[TT, " batch seq_len"]
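A rough sketch of the truncation and padding logic documented above (an illustration of the three modes, not the real implementation; left-padding of prefixes is assumed from the collator note that all prefixes end at the same position):

```python
def prepare_prefix(prefix_ids, pad_token_id, max_seq_len=None, truncate="block"):
    """Left-truncate and left-pad prefixes according to the documented modes."""
    batch_max = max(len(p) for p in prefix_ids)
    if truncate == "block" and max_seq_len is not None:
        target = max_seq_len
    elif truncate == "max" and max_seq_len is not None:
        target = max(max_seq_len, batch_max)
    else:  # truncate is None, or no max_seq_len given: pad to the batch max
        target = batch_max
    out = {"input_ids": [], "attention_mask": []}
    for p in prefix_ids:
        p = p[-target:]  # prefixes are truncated from the left
        pad = target - len(p)
        out["input_ids"].append([pad_token_id] * pad + list(p))
        out["attention_mask"].append([0] * pad + [1] * len(p))
    return out
```

With truncate="block" and max_seq_len=2, the batch [[1, 2, 3], [4]] becomes input_ids [[2, 3], [0, 4]] and attention_mask [[1, 1], [0, 1]].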

prepare_prefix_suffix_ids(prefix_ids, suffix_ids, pad_token_id, mask_token_id, eos_token_id=None, bos_token_id=None, max_seq_len=None, truncate='block', loss_on_padding=True)

Prepare concatenated prefix and suffix ids for seq2seq tasks, with padding on the right only.

Parameters:

  • loss_on_padding (bool, default True):
      • If true, the pad token is treated as a normal token: it receives attention and is predicted as a target token.
      • If false, it receives no attention and is not predicted as a target token (its target is set to -100).
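The loss_on_padding behaviour can be illustrated with a small sketch (a hypothetical helper, not this module's code; -100 is the conventional ignore index for cross-entropy losses):

```python
IGNORE_INDEX = -100  # targets with this value are excluded from the loss

def pad_targets(targets, pad_token_id, total_len, loss_on_padding=True):
    """Right-pad one target row. With loss_on_padding, pad is a normal token
    (attended to and predicted); otherwise pad positions are dropped from
    attention and their targets are set to the ignore index."""
    pad = total_len - len(targets)
    if loss_on_padding:
        return targets + [pad_token_id] * pad, [1] * total_len
    return targets + [IGNORE_INDEX] * pad, [1] * len(targets) + [0] * pad
```

For targets [7, 8] padded to length 4 with pad id 0: loss_on_padding=True gives ([7, 8, 0, 0], [1, 1, 1, 1]), while loss_on_padding=False gives ([7, 8, -100, -100], [1, 1, 0, 0]).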

print_batch_mlm(batch, split, tokenizer, dataloader_name='')

Print batch information for debugging MLM batches.

Parameters:

  • batch (Dict[str, Any], required): The batch to print.
  • split (Literal['train', 'val', 'test', 'predict'], required): The split name.
  • tokenizer (Tokenizer, required): The tokenizer to decode tokens.
  • dataloader_name (str, default ''): Name of the dataloader.