mdlm.model_mdlm
MDLMModel
Bases: BaseMDLMModel
DDiT-based transformer that conditions on the time/noise level through AdaLN (adaptive layer normalization) and uses rotary positional embeddings.
forward(x_t, noise, attention_mask=None, positions=None)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x_t` | `Integer[Tensor, ' *batch seq_len']` | The input tokens of shape `(*batch, seq_len)` | required |
| `noise` | `Float[Tensor, ' *batch']` | The noise of shape `(*batch)` | required |
| `attention_mask` | `Optional[Bool[Tensor, ' *batch seq_len']]` | The attention mask of shape `(*batch, seq_len)`, which is True for non-padding tokens. | `None` |
| `positions` | `Optional[Integer[Tensor, ' *batch seq_len']]` | The positions of the tokens of shape `(*batch, seq_len)` | `None` |
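
The shapes and mask convention above can be illustrated with a small, library-free sketch that right-pads variable-length token sequences and derives a matching `attention_mask` and `positions`. The helper `build_inputs`, the `PAD_ID` value, and the 0-based position scheme are assumptions made for illustration and are not part of `mdlm`. The noise values are arbitrary placeholders.

```python
from typing import List, Tuple

PAD_ID = 0  # hypothetical padding token id; the real id depends on the tokenizer


def build_inputs(
    sequences: List[List[int]], pad_id: int = PAD_ID
) -> Tuple[List[List[int]], List[List[bool]], List[List[int]]]:
    """Right-pad token id sequences and derive attention_mask and positions."""
    seq_len = max(len(seq) for seq in sequences)
    x_t, attention_mask, positions = [], [], []
    for seq in sequences:
        n_pad = seq_len - len(seq)
        x_t.append(seq + [pad_id] * n_pad)
        # True marks real (non-padding) tokens, matching the documented convention.
        attention_mask.append([True] * len(seq) + [False] * n_pad)
        # Assumption: positions are 0-based indices along the sequence axis.
        positions.append(list(range(seq_len)))
    return x_t, attention_mask, positions


x_t, attention_mask, positions = build_inputs([[5, 7, 9], [4, 2]])
noise = [0.3, 0.8]  # one placeholder noise value per batch element: shape (batch,)
```

Before calling `forward`, these nested lists would presumably be converted to tensors, for example with `torch.tensor(x_t, dtype=torch.long)`, `torch.tensor(attention_mask, dtype=torch.bool)`, and `torch.tensor(noise, dtype=torch.float)`, so that they match the annotated types.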