mlm.model_mlm

RotaryTransformerMLMModel

Bases: Module, Model

Rotary-embedding-based transformer decoder.
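As background for the class name, rotary position embeddings (RoPE) rotate each pair of feature channels by a position-dependent angle, so that query–key dot products depend only on relative position. The sketch below is a minimal NumPy illustration of the general technique, not this model's actual implementation; `rotary_embed` is a hypothetical helper name:

```python
import numpy as np

def rotary_embed(x, base=10000.0):
    """Apply rotary position embeddings to x of shape (seq_len, dim).

    Each channel pair (x[:, i], x[:, half + i]) is rotated by the angle
    position * base ** (-i / half), i.e. a 2-D rotation per pair.
    """
    seq_len, dim = x.shape
    half = dim // 2
    # Per-pair rotation frequencies, decaying geometrically across pairs.
    inv_freq = base ** (-np.arange(half) / half)
    # angles[m, i] = m * inv_freq[i]: angle for position m, pair i.
    angles = np.outer(np.arange(seq_len), inv_freq)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Standard 2-D rotation applied to each (x1, x2) channel pair.
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)
```

Because each pair undergoes a pure rotation, vector norms are preserved and position 0 (angle 0) is left unchanged.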

forward(x_t, attention_mask=None, positions=None, token_type_ids=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Integer[Tensor, '*batch seq_len']` | The input tokens of shape `(*batch, seq_len)`. | required |
| `attention_mask` | `Optional[Bool[Tensor, '*batch seq_len']]` | The attention mask of shape `(*batch, seq_len)`; `True` for non-padding tokens. | `None` |
| `positions` | `Optional[Integer[Tensor, '*batch seq_len']]` | The positions of the tokens, of shape `(*batch, seq_len)`. | `None` |
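To illustrate how the `x_t` and `attention_mask` arguments relate, the helper below pads variable-length token sequences to a common length and builds the boolean mask (`True` for real tokens, `False` for padding). This is a hedged sketch: `batchify` and the pad id of `0` are assumptions for illustration, not part of the documented API.

```python
import numpy as np

# Hypothetical pad id; the real tokenizer's pad id may differ.
PAD_ID = 0

def batchify(seqs, pad_id=PAD_ID):
    """Pad token sequences to a common length.

    Returns (x_t, attention_mask) of shape (batch, seq_len), where the
    mask is True exactly at non-padding positions, matching the
    convention described for forward().
    """
    max_len = max(len(s) for s in seqs)
    x_t = np.full((len(seqs), max_len), pad_id, dtype=np.int64)
    mask = np.zeros((len(seqs), max_len), dtype=bool)
    for i, s in enumerate(seqs):
        x_t[i, : len(s)] = s
        mask[i, : len(s)] = True
    return x_t, mask
```

The resulting arrays would be converted to tensors and passed as `forward(x_t, attention_mask=mask)`; positions default to `None`, in which case a model typically falls back to sequential positions.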