mlm.model_mlm

RotaryTransformerMLMModel

Bases: Module, Model

Rotary-embedding-based transformer decoder.

forward(x_t, attention_mask=None, positions=None, block_mask=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Integer[Tensor, '*batch seq_len']` | The input tokens of shape `(*batch, seq_len)`. | *required* |
| `attention_mask` | `Optional[Bool[Tensor, '*batch seq_len']]` | Boolean mask of shape `(*batch, seq_len)` marking valid (non-padding) tokens. Ignored when `block_mask` is set (packed sequences + FlexAttention). | `None` |
| `positions` | `Optional[Integer[Tensor, '*batch seq_len']]` | Per-token RoPE positions of shape `(*batch, seq_len)`. Required for packed sequences (positions reset at each segment boundary). | `None` |
| `block_mask` | | FlexAttention `BlockMask` for packed training when `use_flex_attn=True`. When set, `attention_mask` is ignored. | `None` |
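Since `positions` must reset at each segment boundary for packed sequences, the tensor cannot simply be `torch.arange(seq_len)`. The sketch below shows one way to build it; `packed_positions` is a hypothetical helper, not part of this module.

```python
import torch


def packed_positions(segment_lengths: list[int]) -> torch.Tensor:
    """Hypothetical helper: per-token RoPE positions for one packed sequence.

    Positions restart from 0 at each segment boundary, matching what the
    `positions` argument of `forward` expects for packed training.
    """
    return torch.cat([torch.arange(n) for n in segment_lengths])


# Three documents of lengths 3, 2, and 4 packed into one sequence of length 9.
pos = packed_positions([3, 2, 4])
# pos -> tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
```

For a batch of packed sequences, such per-sequence position tensors would be stacked (padding as needed) into the `(*batch, seq_len)` shape the signature requires.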