mlm.types_mlm
MLMBatch
Bases: TypedDict
Input to the MLM.
Attributes:
| Name | Type | Description |
|---|---|---|
| `input_ids` | `Integer[Tensor, ' batch seq_len']` | The (possibly masked) input token ids. |
| `attention_mask` | `NotRequired[Tensor]` | Boolean mask, shape |
| `target_ids` | `Optional[Integer[Tensor, ' batch seq_len']]` | Ground-truth token ids (masks replaced with original tokens). |
| `positions` | `Optional[Integer[Tensor, ' batch seq_len']]` | Per-token RoPE positions. Required for packed FlexAttention batches; otherwise |
| `segment_ids` | `NotRequired[Integer[Tensor, ' batch seq_len']]` | Packed batches only: per-token segment index (for |
| `block_mask` | `Optional[Any]` | FlexAttention |
| `fixed_positions_mask` | `Optional[Bool[Tensor, ' batch seq_len']]` | Optional boolean mask marking positions that should not be masked (used by infilling collators). |
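The `NotRequired` and `Optional` annotations in the table behave differently: a `NotRequired` key may be absent from the dict, while an `Optional` key is expected to be present but may hold `None`. The following is a minimal sketch of that distinction, using nested Python lists as stand-ins for tensors. `ToyMLMBatch`, `make_toy_batch`, and `MASK_ID` are hypothetical names for illustration and are not part of `mlm.types_mlm`; only three of the fields are reproduced.

```python
from typing import List, Optional, Set, Tuple, TypedDict

try:
    from typing import NotRequired  # Python 3.11+
except ImportError:  # older interpreters
    from typing_extensions import NotRequired

# Nested lists stand in for tensors of shape (batch, seq_len).
IdGrid = List[List[int]]

MASK_ID = 0  # hypothetical mask token id, chosen only for this sketch


class ToyMLMBatch(TypedDict):
    """Hypothetical, reduced stand-in for MLMBatch."""

    input_ids: IdGrid                              # always present
    attention_mask: NotRequired[List[List[bool]]]  # key may be missing
    target_ids: Optional[IdGrid]                   # key present, value may be None


def make_toy_batch(
    target_ids: IdGrid,
    masked_positions: Set[Tuple[int, int]],
    mask_id: int = MASK_ID,
) -> ToyMLMBatch:
    """Replace the tokens at (batch, position) pairs with the mask id."""
    input_ids = [
        [mask_id if (b, t) in masked_positions else tok for t, tok in enumerate(row)]
        for b, row in enumerate(target_ids)
    ]
    # attention_mask is deliberately omitted: NotRequired allows that.
    return {"input_ids": input_ids, "target_ids": target_ids}


batch = make_toy_batch([[5, 6, 7], [8, 9, 10]], {(0, 1), (1, 2)})
has_attention_mask = "attention_mask" in batch
```

Code that consumes such a batch would typically read `NotRequired` keys with `batch.get(...)` and check `Optional` values against `None`.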
PackedFlexMLMBatch
Bases: TypedDict
Batch from PackedMLMCollator when model.use_flex_attn=True.
segment_ids are passed to MLMLoss.__call__ to build the FlexAttention
BlockMask on the training device, avoiding pickling of locally-scoped
mask_mod closures across DataLoader worker queues.
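The design choice described above can be illustrated without PyTorch. The sketch below assumes the mask only allows attention between tokens that share a segment index; the actual predicate used by `MLMLoss` is not stated here. The `mask_mod(b, h, q_idx, kv_idx)` signature follows the FlexAttention convention, while `make_segment_mask_mod`, `dense_mask`, and `is_picklable` are hypothetical helpers. The sketch also shows why shipping `segment_ids` rather than the closure is convenient: the closure is a locally-scoped function and does not pickle, whereas the plain data does.

```python
import pickle
from typing import Callable, List

# Nested lists stand in for an integer tensor of shape (batch, seq_len).
SegmentIds = List[List[int]]


def make_segment_mask_mod(segment_ids: SegmentIds) -> Callable[[int, int, int, int], bool]:
    """Build a mask_mod-style predicate (assumed rule: same segment only)."""

    def mask_mod(b: int, h: int, q_idx: int, kv_idx: int) -> bool:
        # A query may attend to a key only inside its own packed segment.
        return segment_ids[b][q_idx] == segment_ids[b][kv_idx]

    return mask_mod


def dense_mask(segment_ids: SegmentIds) -> List[List[List[bool]]]:
    """Materialize the predicate as a (batch, seq_len, seq_len) boolean grid."""
    mask_mod = make_segment_mask_mod(segment_ids)
    return [
        [[mask_mod(b, 0, q, k) for k in range(len(row))] for q in range(len(row))]
        for b, row in enumerate(segment_ids)
    ]


def is_picklable(obj: object) -> bool:
    """Return True if pickle can serialize obj, as a worker queue would need."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


segment_ids = [[0, 0, 1, 1]]  # two packed sequences of length 2
mask = dense_mask(segment_ids)
closure_ok = is_picklable(make_segment_mask_mod(segment_ids))  # local closure
data_ok = is_picklable(segment_ids)                            # plain data
```

In a real training loop the predicate would presumably be handed to PyTorch's block-mask construction on the training device instead of being expanded densely as done here.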
MLMSeq2SeqPredictionBatch
Bases: TypedDict
Input to the MLM for predicting suffix given the prefix.
MLMUncondtionalPredictionBatch
Input to the MLM for unconditional generation.
Attributes:
| Name | Type | Description |
|---|---|---|
| `input_ids` | `Integer[Tensor, ' batch seq_len']` | The input ids to the model. All masks. |
| `attention_mask` | `Integer[Tensor, ' batch seq_len']` | 1 for tokens that are not padding. |
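As a rough illustration of the two fields, the sketch below builds an all-mask batch with no padding, again using nested lists in place of tensors. `make_unconditional_batch` and the `MASK_ID` value are hypothetical; the real mask token id depends on the tokenizer in use.

```python
from typing import Dict, List

MASK_ID = 103  # hypothetical mask token id, for illustration only


def make_unconditional_batch(
    batch_size: int, seq_len: int, mask_id: int = MASK_ID
) -> Dict[str, List[List[int]]]:
    """Build a batch in which every position holds the mask token."""
    return {
        # Every token is a mask, so the model generates the whole sequence.
        "input_ids": [[mask_id] * seq_len for _ in range(batch_size)],
        # No padding in this toy batch, so every position is marked 1.
        "attention_mask": [[1] * seq_len for _ in range(batch_size)],
    }


uncond_batch = make_unconditional_batch(batch_size=2, seq_len=4)
```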
MLMLossDict
Bases: TypedDict
Output of the LossFunction Callable.
Attributes:
| Name | Type | Description |
|---|---|---|
| `loss` | `Float[Tensor, '']` | The total loss value. |
MLMPredictionDict
Bases: TypedDict
Output of the Predictor for MLM.
Attributes:
| Name | Type | Description |
|---|---|---|
| `loss` | `Optional[Float[Tensor, batch]]` | The loss value. Typically None. |
| `text` | `List[str]` | The batch of generated text with special tokens. |
| `ids` | `Integer[Tensor, ' batch seq_len']` | The batch of generated token_ids. |
| `time_taken` | `List[float]` | Time taken for each prediction. |
| `output_start_idx` | `Integer[Tensor, ' batch']` | The index of the first token in the output. |
| `steps_taken` | `List[int]` | Number of steps taken per sample. |
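One plausible way to consume such a dict is to use `output_start_idx` to separate each sample's generated output from whatever precedes it. The sketch below assumes `output_start_idx[b]` is an offset along the `seq_len` axis of `ids[b]`, which the table suggests but does not state outright. `split_outputs` and the toy values are hypothetical, and lists stand in for tensors.

```python
from typing import Any, Dict, List


def split_outputs(pred: Dict[str, Any]) -> List[List[int]]:
    """Return, per sample, the token ids from output_start_idx onward."""
    return [
        row[start:]
        for row, start in zip(pred["ids"], pred["output_start_idx"])
    ]


# Toy values shaped like MLMPredictionDict; not real model output.
toy_pred = {
    "loss": None,  # described as typically None
    "text": ["a b c d", "e f g h"],
    "ids": [[11, 12, 13, 14], [21, 22, 23, 24]],
    "time_taken": [0.5, 0.5],
    "output_start_idx": [2, 1],
    "steps_taken": [4, 4],
}

outputs = split_outputs(toy_pred)
```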