
mlm.types_mlm

MLMBatch

Bases: TypedDict

Input to the MLM.

Attributes:

Name Type Description
input_ids Integer[Tensor, ' batch seq_len']

The (possibly masked) input token ids.

attention_mask NotRequired[Tensor]

Boolean mask of shape (batch, seq_len) for standard padded batches (True = valid token). Omitted when model.use_flex_attn is enabled and the batch comes from PackedMLMCollator (FlexAttention only).

target_ids Optional[Integer[Tensor, ' batch seq_len']]

Ground-truth token ids (masks replaced with original tokens).

positions Optional[Integer[Tensor, ' batch seq_len']]

Per-token RoPE positions. Required for packed FlexAttention batches; otherwise MLMLoss derives them from the 1-D attention_mask.

segment_ids NotRequired[Integer[Tensor, ' batch seq_len']]

Packed batches only: per-token segment index (used by mask_mod).

block_mask Optional[Any]

FlexAttention BlockMask from PackedMLMCollator when model.use_flex_attn=True.

fixed_positions_mask Optional[Bool[Tensor, ' batch seq_len']]

Optional boolean mask marking positions that should not be masked (used by infilling collators).
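The positions entry above says that, for padded batches, positions are derived from the attention mask. The exact rule used by MLMLoss is not shown here, so the following is only a minimal sketch under one common assumption: each valid token gets the count of valid tokens before it, and padding gets 0. Nested lists stand in for tensors to keep the example dependency-free, and the helper name positions_from_attention_mask is hypothetical.

```python
from typing import List


def positions_from_attention_mask(attention_mask: List[List[bool]]) -> List[List[int]]:
    """Hypothetical helper: derive per-token positions from a padding mask.

    Assumed rule (not confirmed by the library): a valid token's position is
    the number of valid tokens that precede it; padding positions are 0.
    """
    positions = []
    for row in attention_mask:
        seen = 0  # number of valid tokens encountered so far in this row
        row_positions = []
        for valid in row:
            row_positions.append(seen if valid else 0)
            seen += int(valid)
        positions.append(row_positions)
    return positions


# Two right-padded sequences of true lengths 3 and 2.
mask = [[True, True, True, False], [True, True, False, False]]
positions = positions_from_attention_mask(mask)
```

Under this assumption the rule also handles left padding, since the counter only advances on valid tokens.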

PackedFlexMLMBatch

Bases: TypedDict

Batch from PackedMLMCollator when model.use_flex_attn=True.

segment_ids are passed to MLMLoss.__call__ to build the FlexAttention BlockMask on the training device, avoiding pickling of locally-scoped mask_mod closures across DataLoader worker queues.
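To illustrate how segment_ids can define the attention pattern for a packed batch, here is a minimal sketch of a segment-based mask_mod. It follows the (b, h, q_idx, kv_idx) calling convention that FlexAttention uses for mask_mod functions, but it works on plain nested lists and materialises a dense mask, so it is an illustration of the idea rather than the library's implementation. The names make_segment_mask_mod and dense_mask are hypothetical.

```python
from typing import Callable, List


def make_segment_mask_mod(
    segment_ids: List[List[int]],
) -> Callable[[int, int, int, int], bool]:
    """Hypothetical helper: build a mask_mod closure over segment_ids."""

    def mask_mod(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        # A query may attend to a key only if both tokens belong to the
        # same packed segment. No causal constraint: MLM is bidirectional.
        return segment_ids[b][q_idx] == segment_ids[b][kv_idx]

    return mask_mod


def dense_mask(segment_ids: List[List[int]], b: int = 0) -> List[List[bool]]:
    """Materialise the full (seq_len, seq_len) mask for one batch row."""
    mask_mod = make_segment_mask_mod(segment_ids)
    seq_len = len(segment_ids[b])
    return [
        [mask_mod(b, 0, q, kv) for kv in range(seq_len)] for q in range(seq_len)
    ]


# One packed row holding two sequences: tokens 0-1 and tokens 2-4.
segment_ids = [[0, 0, 1, 1, 1]]
mask = dense_mask(segment_ids)
```

Because mask_mod is defined inside another function, it is a locally scoped closure, which is the kind of object the note above says should not be sent through DataLoader worker queues; passing segment_ids instead keeps the batch made of plain data.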

MLMSeq2SeqPredictionBatch

Bases: TypedDict

Input to the MLM for predicting a suffix given a prefix.

MLMUncondtionalPredictionBatch

Input to the MLM for unconditional generation.

Attributes:

Name Type Description
input_ids Integer[Tensor, ' batch seq_len']

The input ids to the model. Every position holds the mask token.

attention_mask Integer[Tensor, ' batch seq_len']

1 for tokens that are not padding.
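As a rough sketch of what such a batch looks like, the example below builds an all-mask input with a matching attention mask. Nested lists stand in for tensors, and the names UnconditionalBatchSketch, all_mask_batch, and the mask_token_id value are hypothetical; the real mask token id depends on the tokenizer in use.

```python
from typing import List, TypedDict


class UnconditionalBatchSketch(TypedDict):
    """Hypothetical stand-in mirroring the two documented fields."""

    input_ids: List[List[int]]
    attention_mask: List[List[int]]


def all_mask_batch(
    batch_size: int, seq_len: int, mask_token_id: int
) -> UnconditionalBatchSketch:
    """Build a batch in which every input position is the mask token."""
    return {
        # Every token starts out masked; generation fills these in.
        "input_ids": [[mask_token_id] * seq_len for _ in range(batch_size)],
        # No padding in this sketch, so every position is marked 1.
        "attention_mask": [[1] * seq_len for _ in range(batch_size)],
    }


# mask_token_id=103 is an arbitrary illustrative value.
batch = all_mask_batch(batch_size=2, seq_len=4, mask_token_id=103)
```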

MLMLossDict

Bases: TypedDict

Output of the LossFunction Callable.

Attributes:

Name Type Description
loss Float[Tensor, '']

The total loss value.

MLMPredictionDict

Bases: TypedDict

Output of the Predictor for MLM.

Attributes:

Name Type Description
loss Optional[Float[Tensor, batch]]

The loss value. Typically None.

text List[str]

The batch of generated text with special tokens.

ids Integer[Tensor, ' batch seq_len']

The batch of generated token_ids.

time_taken List[float]

Time taken for each prediction.

output_start_idx Integer[Tensor, ' batch']

The index of the first token in the output.

steps_taken List[int]

Number of steps taken per sample.