mlm
MLM - Masked Language Model for XLM Framework
This package implements the MLM model with all necessary components:

- Model architecture (model_mlm.py)
- Loss function (loss_mlm.py)
- Predictor for inference (predictor_mlm.py)
- Data module (datamodule_mlm.py)
- Metrics computation (metrics_mlm.py)
- Type definitions (types_mlm.py)
- History tracking (history_mlm.py)

This model was migrated from xlm.lm.mlm to be an external model.
RotaryTransformerMLMModel
Bases: Module, Model
Rotary embedding based transformer decoder.
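Rotary position embeddings encode position by rotating consecutive pairs of feature dimensions through position-dependent angles. A minimal pure-Python sketch for illustration (the actual model applies this inside attention with its own base and head dimensions, so the function name and defaults here are assumptions):

```python
import math

def rotary_embed(vec, position, base=10000.0):
    """Rotate consecutive feature pairs of `vec` by position-dependent angles.

    vec: flat list of floats with even length (one token's features).
    position: integer position of the token in the sequence.
    """
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        # Each pair (x, y) gets its own frequency, decaying with depth i.
        theta = position / (base ** (i / dim))
        x, y = vec[i], vec[i + 1]
        out.append(x * math.cos(theta) - y * math.sin(theta))
        out.append(x * math.sin(theta) + y * math.cos(theta))
    return out

# Position 0 leaves the vector unchanged (all angles are zero),
# and rotation preserves the norm of each feature pair.
print(rotary_embed([1.0, 0.0, 1.0, 0.0], position=0))
```

Because rotation is norm-preserving, attention scores between rotated queries and keys depend only on relative position, which is the property the model relies on.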
forward(x_t, attention_mask=None, positions=None, token_type_ids=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_t | Integer[Tensor, ' *batch seq_len'] | The input tokens of shape (*batch, seq_len). | required |
| attention_mask | Optional[Bool[Tensor, ' *batch seq_len']] | The attention mask of shape (*batch, seq_len), which is True for non-padding tokens. | None |
| positions | Optional[Integer[Tensor, ' *batch seq_len']] | The positions of the tokens of shape (*batch, seq_len). | None |
MLMPredictor
Bases: Module, Predictor[MLMBatch, MLMPredictionDict]
Base predictor for MLM. Stochastically selects positions to unmask based on max_steps and max_new_tokens.
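The unmasking loop can be sketched as follows. This is a hedged, pure-Python illustration: the real predictor scores positions with the model and a noise schedule, so the random selection, token strings, and function name below are placeholder assumptions:

```python
import random

MASK = "<mask>"

def iterative_unmask(tokens, max_steps, max_new_tokens=None, seed=0):
    """Reveal masked positions over at most `max_steps` rounds.

    Each round unmasks a random subset of the remaining masked positions,
    standing in for the model's confidence-based selection.
    """
    rng = random.Random(seed)
    tokens = list(tokens)
    budget = max_new_tokens if max_new_tokens is not None else len(tokens)
    for _ in range(max_steps):
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        if not masked or budget <= 0:
            break
        # Unmask roughly an equal share of the remaining positions per step.
        n = max(1, min(budget, len(masked) // 2 or 1))
        for i in rng.sample(masked, n):
            tokens[i] = f"tok{i}"  # placeholder for the model's prediction
            budget -= 1
    return tokens

print(iterative_unmask(["a", MASK, MASK, "b"], max_steps=4))
```

Spreading the reveals across steps lets each later prediction condition on tokens committed in earlier steps, which is the point of multi-step MLM decoding.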
__init__(max_steps, max_new_tokens=None, tokenizer=None, model=None, noise_schedule=None, top_k=None, top_p=None, skip_special_tokens=True)
Initialize MLM Predictor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| max_steps | int | Maximum number of prediction steps. | required |
| tokenizer | Optional[Tokenizer] | Tokenizer for encoding/decoding. | None |
| noise_schedule | Optional[NoiseSchedule] | Noise schedule for the diffusion process. | None |
| top_k | Optional[int] | Top-k sampling parameter. | None |
| top_p | Optional[float] | Top-p sampling parameter. | None |
| model | Optional[MLMModel] | The MLM model to use for predictions. | None |
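top_k restricts sampling to the k highest-probability tokens, while top_p keeps the smallest "nucleus" of tokens whose cumulative probability reaches p. A standalone sketch of the filtering step (illustrative only; not the predictor's exact implementation):

```python
def filter_top_k_top_p(probs, top_k=None, top_p=None):
    """Zero out probabilities outside the top-k / top-p sets, then renormalize.

    probs: list of floats summing to 1.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k is not None:
        keep &= set(order[:top_k])
    if top_p is not None:
        # Walk tokens in descending probability until the mass reaches top_p.
        nucleus, total = set(), 0.0
        for i in order:
            nucleus.add(i)
            total += probs[i]
            if total >= top_p:
                break
        keep &= nucleus
    filtered = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    z = sum(filtered)
    return [p / z for p in filtered]

print(filter_top_k_top_p([0.5, 0.3, 0.15, 0.05], top_k=2))
```

When both parameters are set, this sketch applies the intersection of the two sets, which is one common convention; the package may combine them differently.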
decode(results)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| results | MLMStepResults | x: Integer[TT, " batch seq_len"]: the current predicted sequence. | required |

Returns:

- out (List[str]): Decoded sequence with special tokens.
- x (Integer[TT, " batch seq_len"]): Current predicted sequence.
DefaultMLMCollator
Bases: Collator
Used for MLM pre-training with padded-truncated sequences.
Batch
- input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
- attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
- target_ids: Integer[TT, " batch seq_len"]: The target ids to the model, where the input is copied as is and masks are replaced with the correct token.
Padding
- Padding is done on the right.
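The right-padding behavior can be sketched as below. This is a simplified stand-in: the real collator also applies masking and truncation, and the pad id and function name here are assumptions:

```python
PAD_ID = 0  # assumed pad id for illustration

def collate_right_pad(sequences):
    """Right-pad variable-length id sequences into rectangular batch lists."""
    max_len = max(len(s) for s in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        pad = [PAD_ID] * (max_len - len(seq))
        input_ids.append(list(seq) + pad)
        # 1 for real tokens, 0 for padding, matching the documented convention.
        attention_mask.append([1] * len(seq) + [0] * len(pad))
    return {"input_ids": input_ids, "attention_mask": attention_mask}

print(collate_right_pad([[5, 6, 7], [8, 9]]))
```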
MLMSeq2SeqTrainCollator
Bases: Collator
MLM training for seq2seq tasks.
Batch
- input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
- attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
- target_ids: Integer[TT, " batch seq_len"]: The target ids to the model, where the input is copied as is and masks are replaced with the correct token.
Padding
- Padding is done on the right.
MLMSeq2SeqCollator
Bases: Collator
MLM training for seq2seq tasks.
Batch
- input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
- attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
- target_ids: Integer[TT, " batch seq_len"]: The target ids to the model, where the input is copied as is and masks are replaced with the correct token.
Padding
- There is padding on both sides because all prefixes end at the same position. TODO (efficiency): this is not ideal for seq2seq training, since many tokens are wasted on padding; for training we should pad on one side only.
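Aligning all prefixes to end at the same column forces padding on both sides, as the following minimal sketch shows (the pad id and function name are illustrative assumptions):

```python
PAD_ID = 0  # assumed pad id for illustration

def collate_aligned(prefixes, suffixes):
    """Left-pad prefixes and right-pad suffixes so every prefix ends at the
    same column, at the cost of extra padding tokens on both sides."""
    pre_len = max(len(p) for p in prefixes)
    suf_len = max(len(s) for s in suffixes)
    rows = []
    for p, s in zip(prefixes, suffixes):
        left = [PAD_ID] * (pre_len - len(p))
        right = [PAD_ID] * (suf_len - len(s))
        rows.append(left + list(p) + list(s) + right)
    return rows

print(collate_aligned([[1, 2, 3], [4]], [[7], [8, 9]]))
```

Every row places its last prefix token at the same index, which keeps the prefix/suffix boundary constant across the batch but spends pad tokens at both ends.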
MLMSeq2SeqPredCollator
Bases: MLMSeq2SeqCollator
Input contains only the prefix, and target_ids contains only the suffix if present.
MLMBatch
Bases: TypedDict
Input to the MLM.

Attributes:

| Name | Type | Description |
|---|---|---|
| input_ids | Integer[TT, " batch seq_len"] | The input ids to the model. |
| attention_mask | Integer[TT, " batch seq_len"] | 1 for tokens that are not padding. |
| target_ids | Optional[Integer[TT, " batch seq_len"]] | The target ids to the model. |
MLMSeq2SeqPredictionBatch
Bases: TypedDict
Input to the MLM for predicting suffix given the prefix.
MLMUncondtionalPredictionBatch
Input to the MLM for unconditional generation.
Attributes:
| Name | Type | Description |
|---|---|---|
| input_ids | Integer[Tensor, ' batch seq_len'] | The input ids to the model. All masks. |
| attention_mask | Integer[Tensor, ' batch seq_len'] | 1 for tokens that are not padding. |
MLMLossDict
Bases: TypedDict
Output of the LossFunction Callable.
Attributes:
| Name | Type | Description |
|---|---|---|
| loss | Float[Tensor, ''] | The total loss value. |
MLMPredictionDict
Bases: TypedDict
Output of the Predictor for MLM.
Attributes:
| Name | Type | Description |
|---|---|---|
| loss | Optional[Float[Tensor, ' batch']] | The loss value. Typically None. |
| text | List[str] | The batch of generated text with special tokens. |
| ids | Integer[Tensor, ' batch seq_len'] | The batch of generated token_ids. |
| time_taken | List[float] | Time taken for each prediction. |
| output_start_idx | Integer[Tensor, ' batch'] | The index of the first token in the output. |
HistoryTopKPlugin
Bases: HistoryPluginBase
Dumps the top-k tokens and their probabilities as tensors, in a separate file for each step.
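A sketch of what dumping per-step top-k values might look like, using Python's pickle in place of the plugin's actual tensor serialization (the file layout, names, and format are assumptions, not the plugin's real API):

```python
import pickle
import tempfile
from pathlib import Path

def dump_top_k(step, probs, k, out_dir):
    """Write the k most probable (token_id, prob) pairs for one step
    to its own file, e.g. step_0003_topk.pkl."""
    pairs = sorted(enumerate(probs), key=lambda ip: ip[1], reverse=True)[:k]
    path = Path(out_dir) / f"step_{step:04d}_topk.pkl"
    with open(path, "wb") as f:
        pickle.dump({"tokens": [i for i, _ in pairs],
                     "probs": [p for _, p in pairs]}, f)
    return path

with tempfile.TemporaryDirectory() as d:
    path = dump_top_k(step=3, probs=[0.1, 0.6, 0.3], k=2, out_dir=d)
    with open(path, "rb") as f:
        print(pickle.load(f))  # tokens [1, 2] with probs [0.6, 0.3]
```

One file per step keeps each dump small and lets analysis tools load individual steps without reading the whole generation history.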