mdlm.model_mdlm

MDLMModel

Bases: BaseMDLMModel

DDiT-based transformer that conditions on time/noise through AdaLN (adaptive layer normalization) and uses rotary positional embeddings.
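To make the AdaLN conditioning concrete, here is a minimal pure-Python sketch of the DiT-style modulation `LayerNorm(x) * (1 + scale) + shift`. It is an illustration under assumptions, not this module's code: the helper names `layer_norm` and `adaln_modulate` are hypothetical, and in a real model the `shift` and `scale` vectors would be produced from the noise embedding by a learned projection.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize one feature vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def adaln_modulate(x, shift, scale):
    """Hypothetical helper: LayerNorm(x) * (1 + scale) + shift, applied elementwise."""
    normed = layer_norm(x)
    return [n * (1.0 + s) + b for n, s, b in zip(normed, scale, shift)]


hidden = [1.0, 2.0, 3.0, 4.0]

# Zero shift and zero scale leave the normalized activations unchanged.
identity = adaln_modulate(hidden, shift=[0.0] * 4, scale=[0.0] * 4)

# A non-zero shift moves every feature by the same conditioning-dependent offset.
shifted = adaln_modulate(hidden, shift=[1.0] * 4, scale=[0.0] * 4)
```

Writing the scale as `1 + scale` means a conditioning signal of zero reduces the layer to a plain layer norm, which is why this form is common in DiT-like blocks.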

forward(x_t, noise, attention_mask=None, positions=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_t | `Integer[Tensor, ' *batch seq_len']` | The input tokens of shape `(*batch, seq_len)` | required |
| noise | `Float[Tensor, ' *batch']` | The noise of shape `(*batch)` | required |
| attention_mask | `Optional[Bool[Tensor, ' *batch seq_len']]` | The attention mask of shape `(*batch, seq_len)`, which is True for non-padding tokens. | None |
| positions | `Optional[Integer[Tensor, ' *batch seq_len']]` | The positions of the tokens of shape `(*batch, seq_len)` | None |
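The sketch below shows one way the `attention_mask` and `positions` arguments could be laid out for a right-padded batch. It uses plain nested lists so it runs without dependencies; in practice these would be converted to boolean and integer tensors of shape `(*batch, seq_len)`. The helper `build_mask_and_positions` is hypothetical, and both right-padding and 0-based consecutive positions are assumptions, not behavior confirmed by this module.

```python
def build_mask_and_positions(lengths, seq_len):
    """Hypothetical helper: build per-sequence masks and position ids.

    lengths: number of real (non-padding) tokens in each sequence.
    seq_len: padded length shared by every sequence in the batch.
    """
    # True marks a non-padding token, matching the documented mask convention.
    mask = [[i < n for i in range(seq_len)] for n in lengths]
    # Assumed convention: positions simply count up from 0 along the sequence.
    positions = [list(range(seq_len)) for _ in lengths]
    return mask, positions


# Two sequences with 3 and 5 real tokens, padded to a common length of 5.
attention_mask, positions = build_mask_and_positions([3, 5], seq_len=5)
```

Both results have one row per sequence and `seq_len` entries per row, mirroring the `(*batch, seq_len)` shape the parameters expect.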