mlm

MLM - Masked Language Model for XLM Framework

This package implements the MLM model with all necessary components:

- Model architecture (model_mlm.py)
- Loss function (loss_mlm.py)
- Predictor for inference (predictor_mlm.py)
- Data module (datamodule_mlm.py)
- Metrics computation (metrics_mlm.py)
- Type definitions (types_mlm.py)
- History tracking (history_mlm.py)

This model was migrated from xlm.lm.mlm to be an external model.

RotaryTransformerMLMModel

Bases: Module, Model

Rotary embedding based transformer decoder.

forward(x_t, attention_mask=None, positions=None, block_mask=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Integer[Tensor, ' *batch seq_len']` | The input tokens of shape `(*batch, seq_len)`. | required |
| `attention_mask` | `Optional[Bool[Tensor, ' *batch seq_len']]` | Boolean mask of shape `(*batch, seq_len)` marking valid (non-padding) tokens. Ignored when `block_mask` is set (packed + FlexAttention). | `None` |
| `positions` | `Optional[Integer[Tensor, ' *batch seq_len']]` | Per-token RoPE positions of shape `(*batch, seq_len)`. Required for packed sequences (reset at each segment boundary). | `None` |
| `block_mask` | | FlexAttention `BlockMask` for packed training when `use_flex_attn=True`. When set, `attention_mask` is ignored. | `None` |
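Since `positions` must reset at each segment boundary in a packed sequence, a minimal pure-Python sketch of how such per-token RoPE positions could be derived from segment ids (the function name and plain-list representation are illustrative, not the package's API):

```python
# Hypothetical sketch: derive per-token RoPE positions for a packed
# sequence, restarting the position counter at each segment boundary.
# `segment_ids` gives the packed-segment index of each token.
def rope_positions(segment_ids):
    positions = []
    prev, pos = None, 0
    for seg in segment_ids:
        if seg != prev:  # new segment: restart the position count
            pos = 0
            prev = seg
        positions.append(pos)
        pos += 1
    return positions

# Two packed segments of lengths 3 and 2:
print(rope_positions([0, 0, 0, 1, 1]))  # [0, 1, 2, 0, 1]
```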

MLMPredictor

Bases: Module, Predictor[MLMBatch, MLMPredictionDict]

Base predictor for MLM. Stochastically selects positions to unmask based on max_steps and max_new_tokens.

__init__(max_steps, max_new_tokens=None, tokenizer=None, model=None, noise_schedule=None, top_k=None, top_p=None, confidence=None, threshold=None, skip_special_tokens=True)

Initialize MLM Predictor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `max_steps` | `int` | Maximum number of prediction steps. | required |
| `tokenizer` | `Optional[Tokenizer]` | Tokenizer for encoding/decoding. | `None` |
| `noise_schedule` | `Optional[NoiseSchedule]` | Noise schedule for the diffusion process. | `None` |
| `top_k` | `Optional[int]` | Top-k sampling parameter. | `None` |
| `top_p` | `Optional[float]` | Top-p sampling parameter. | `None` |
| `confidence` | `Optional[Literal['prob_diff', 'entropy', 'top_prob']]` | Confidence criterion for position sampling. | `None` |
| `threshold` | `Optional[float]` | Threshold for confidence-based position sampling. | `None` |
| `model` | `Optional[MLMModel]` | The MLM model to use for predictions. | `None` |
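A minimal sketch of what confidence-based position selection with the `'top_prob'` criterion and a `threshold` could look like: among still-masked positions, unmask those whose most likely token has probability above the threshold. Names and the plain-list representation are illustrative assumptions, not the package's actual API:

```python
# Hypothetical sketch of "top_prob" confidence-based position selection:
# a masked position is unmasked this step when the probability of its
# most likely token meets the threshold.
def select_positions(probs, masked, threshold):
    # probs: per-position probability distributions (list of lists)
    # masked: indices of positions that are still masked
    chosen = []
    for i in masked:
        if max(probs[i]) >= threshold:
            chosen.append(i)
    return chosen

probs = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]]
print(select_positions(probs, [0, 1, 2], 0.7))  # [0, 2]
```

With `'entropy'` or `'prob_diff'`, only the per-position score would change; the thresholded selection step stays the same.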

decode(results)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `results` | `MLMStepResults` | `x: Integer[TT, " batch seq_len"]`: the current predicted sequence. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `out` | `List[str]` | Decoded sequences with special tokens. |
| `x` | `Integer[TT, " batch seq_len"]` | Current predicted sequence. |

DefaultMLMCollator

Bases: Collator

Used for MLM pre-training with padded-truncated sequences.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as is and masks are replaced with the correct token.
Padding
  • Padding is done on the right.
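A minimal pure-Python sketch of the batch layout described above: mask tokens at random, right-pad to a common length, and keep the originals in `target_ids`. The `MASK`/`PAD` ids, the 15% rate, and the function name are illustrative assumptions, not the collator's actual implementation:

```python
import random

# Hypothetical sketch of MLM collation with right padding.
MASK, PAD = 103, 0

def collate(sequences, mask_prob=0.15, seed=0):
    rng = random.Random(seed)
    max_len = max(len(s) for s in sequences)
    batch = {"input_ids": [], "attention_mask": [], "target_ids": []}
    for seq in sequences:
        # Replace a random subset of tokens with the mask id.
        masked = [MASK if rng.random() < mask_prob else t for t in seq]
        pad = max_len - len(seq)
        batch["input_ids"].append(masked + [PAD] * pad)       # model input
        batch["attention_mask"].append([1] * len(seq) + [0] * pad)
        batch["target_ids"].append(list(seq) + [PAD] * pad)   # originals kept
    return batch
```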

MLMSeq2SeqTrainCollator

Bases: Collator

MLM training for seq2seq tasks.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as is and masks are replaced with the correct token.
Padding
  • Padding is done on the right.

MLMSeq2SeqCollator

Bases: Collator

MLM training for seq2seq tasks.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as is and masks are replaced with the correct token.
Padding
  • There is padding on both sides, so that all prefixes end at the same position. TODO (efficiency): this is not ideal for seq2seq training, since many tokens are wasted on padding; for training we should pad on only one side.
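A pure-Python sketch of this both-sided layout: left-pad each prefix so every prefix ends at the same position, then right-pad the suffixes to a common total length. The `PAD` id and function name are illustrative assumptions:

```python
# Hypothetical sketch of both-sided padding for seq2seq batches.
PAD = 0

def pad_both_sides(examples):
    # examples: list of (prefix, suffix) token-id lists
    max_prefix = max(len(p) for p, _ in examples)
    max_suffix = max(len(s) for _, s in examples)
    rows = []
    for prefix, suffix in examples:
        left = [PAD] * (max_prefix - len(prefix))    # align prefix ends
        right = [PAD] * (max_suffix - len(suffix))   # pad suffixes out
        rows.append(left + prefix + suffix + right)
    return rows

rows = pad_both_sides([([1, 2, 3], [7]), ([4], [8, 9])])
# Every prefix ends at index max_prefix - 1 == 2:
print(rows)  # [[1, 2, 3, 7, 0], [0, 0, 4, 8, 9]]
```

The wasted-padding concern in the TODO is visible here: short prefixes and short suffixes each pay padding on their own side.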

MLMSeq2SeqPredCollator

Bases: MLMSeq2SeqCollator

The input contains only the prefix, and target_ids contains only the suffix, if present.

MLMBatch

Bases: TypedDict

Input to the MLM.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `input_ids` | `Integer[Tensor, ' batch seq_len']` | The (possibly masked) input token ids. |
| `attention_mask` | `NotRequired[Tensor]` | Boolean mask of shape `(batch, seq_len)` for standard padded batches (`True` = valid token). Omitted when `model.use_flex_attn` is set and `PackedMLMCollator` is used (FlexAttention only). |
| `target_ids` | `Optional[Integer[Tensor, ' batch seq_len']]` | Ground-truth token ids (masks replaced with original tokens). |
| `positions` | `Optional[Integer[Tensor, ' batch seq_len']]` | Per-token RoPE positions. Required for packed FlexAttention batches; otherwise `MLMLoss` derives them from the 1-D `attention_mask`. |
| `segment_ids` | `NotRequired[Integer[Tensor, ' batch seq_len']]` | Packed batches only: per-token segment index (used by `mask_mod`). |
| `block_mask` | `Optional[Any]` | FlexAttention `BlockMask` from `PackedMLMCollator` when `model.use_flex_attn=True`. |
| `fixed_positions_mask` | `Optional[Bool[Tensor, ' batch seq_len']]` | Optional boolean mask marking positions that should not be masked (used by infilling collators). |

PackedFlexMLMBatch

Bases: TypedDict

Batch from PackedMLMCollator when model.use_flex_attn=True.

segment_ids are passed to MLMLoss.__call__ to build the FlexAttention BlockMask on the training device, avoiding pickling of locally-scoped mask_mod closures across DataLoader worker queues.
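A pure-Python sketch of the mask logic that such a `mask_mod` encodes: a query token may attend only to key tokens in the same packed segment. FlexAttention's `mask_mod` callables take `(batch, head, q_idx, kv_idx)` and return a boolean; the factory name and plain-list `segment_ids` here are illustrative assumptions:

```python
# Hypothetical sketch of building a segment-based mask_mod from
# segment_ids: attention is allowed only within a packed segment.
def make_mask_mod(segment_ids):
    def mask_mod(b, h, q_idx, kv_idx):
        # Allow attention only when query and key share a segment.
        return segment_ids[b][q_idx] == segment_ids[b][kv_idx]
    return mask_mod

seg = [[0, 0, 1, 1]]  # one packed row containing two segments
mask = make_mask_mod(seg)
print(mask(0, 0, 0, 1))  # True  (both tokens in segment 0)
print(mask(0, 0, 0, 2))  # False (crosses a segment boundary)
```

Building this closure from `segment_ids` on the training device, as described above, sidesteps pickling the closure through DataLoader worker queues.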

MLMSeq2SeqPredictionBatch

Bases: TypedDict

Input to the MLM for predicting suffix given the prefix.

MLMUncondtionalPredictionBatch

Input to the MLM for unconditional generation.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `input_ids` | `Integer[Tensor, ' batch seq_len']` | The input ids to the model; all masks. |
| `attention_mask` | `Integer[Tensor, ' batch seq_len']` | 1 for tokens that are not padding. |

MLMLossDict

Bases: TypedDict

Output of the LossFunction Callable.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `loss` | `Float[Tensor, '']` | The total loss value. |

MLMPredictionDict

Bases: TypedDict

Output of the Predictor for MLM.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `loss` | `Optional[Float[Tensor, ' batch']]` | The loss value; typically `None`. |
| `text` | `List[str]` | The batch of generated text with special tokens. |
| `ids` | `Integer[Tensor, ' batch seq_len']` | The batch of generated token ids. |
| `time_taken` | `List[float]` | Time taken for each prediction. |
| `output_start_idx` | `Integer[Tensor, ' batch']` | The index of the first token in the output. |
| `steps_taken` | `List[int]` | Number of steps taken per sample. |

HistoryTopKPlugin

Bases: HistoryPluginBase

Dumps the top-k tokens and probabilities as tensors to separate files for each step.