
xlm.modules.ddit_simple

LayerNormAndScale

Bases: Module

Performs layer normalization followed by a learned elementwise scale, with no bias term.

__init__(dim, eps=1e-05)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dim | int | the dimension of the input. | required |
| eps | float | the epsilon added to the variance for numerical stability. | 1e-05 |

forward(x)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | the input tensor of shape (bsz, seq_len, dim). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | the normalized and scaled output tensor of shape (bsz, seq_len, dim). |
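As a rough pure-Python sketch (hypothetical helper name; the real module operates on torch tensors), the forward pass amounts to standard layer normalization followed by multiplication with a learned weight vector, with no bias added afterwards:

```python
import math

def layer_norm_and_scale(x, weight, eps=1e-05):
    """Normalize one feature vector to zero mean / unit variance, then scale.

    x: list of floats (one feature vector); weight: learned scale, same length.
    Sketch only -- no bias is added after scaling, matching the class docstring.
    """
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [w * (v - mean) * inv_std for v, w in zip(x, weight)]
```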

TimestepEmbedder

Bases: Module

Embeds scalar timesteps into vector representations.

__init__(hidden_size, frequency_embedding_size=256, max_period=10000)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_size | int | The size of the hidden layer and the output of the MLP. | required |
| frequency_embedding_size | int | The size of the frequency embedding layer. | 256 |
| max_period | int | Controls the minimum frequency of the sinusoidal embeddings. | 10000 |

forward(t)

Embeds scalar timesteps into vector representations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| t | Tensor | A 1-D tensor of shape (bsz,), one timestep index per batch element. These may be fractional. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | A (bsz, hidden_size) tensor of positional embeddings. |
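The embedding itself is typically the DiT-style sinusoidal scheme: cosines and sines at geometrically spaced frequencies, followed by the MLP. A pure-Python sketch for a single scalar timestep (the ordering and exact frequency spacing are assumptions based on the standard DiT implementation; this module may differ in detail):

```python
import math

def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of a scalar timestep t into a dim-vector.

    The first dim//2 entries are cosines and the rest sines, at frequencies
    spaced geometrically from 1 down to roughly 1/max_period.
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]
```

In the module, this frequency embedding (of size `frequency_embedding_size`) is then projected by an MLP to `hidden_size`.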

LabelEmbedder

Bases: Module

Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_classes | int | The number of classes. | required |
| cond_size | int | The size of the conditioning input. | required |
| label_dropout | Optional[float] | The dropout rate for class labels during training. | None |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| embedding_table | Embedding | The embedding table for class labels. |
| num_classes | int | The number of classes. |

drop_labels(labels)

Drop out class labels during training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| labels | Tensor | The input tensor of class labels of shape (bsz,). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The modified class labels, with some labels dropped by replacing them with the reserved missing-label index (the last label). |
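A pure-Python sketch of the dropout logic (hypothetical helper; the real method operates on torch tensors), assuming the usual classifier-free-guidance convention that the embedding table holds `num_classes + 1` rows, with the extra last row standing for "no label":

```python
import random

def drop_labels(labels, num_classes, label_dropout, rng=random.random):
    """With probability label_dropout, replace each label with the reserved
    'missing' index num_classes (the last row of the embedding table).

    Sketch of the classifier-free-guidance label-dropout trick: at sampling
    time, the missing index lets the model produce unconditional predictions.
    """
    return [num_classes if rng() < label_dropout else y for y in labels]
```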

forward(labels)

Forward pass of the LabelEmbedder module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| labels | Tensor | The input tensor of class labels of shape (bsz,). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The embedded vector representations of the class labels. |

AdaLNModulations

Bases: Module

Produces the modulation parameters for AdaLN.

__init__(cond_dim, dim, num_modulation_parameters=6)

Initializes the AdaLNModulations module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cond_dim | int | The dimension of the conditioning input. | required |
| dim | int | The hidden size. | required |
| num_modulation_parameters | int | The number of modulation tensors to produce. | 6 |

forward(c)

Forward pass of the AdaLNModulations module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c | Tensor | The conditioning input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| List[Tensor] | The modulation parameters for AdaLN. Each tensor has shape (bsz, 1, dim). When num_modulation_parameters=6, these are the shift, scale, and gating parameters for the MHA and MLP layers: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp. |
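In DiT-style implementations this is typically an activation followed by a linear layer whose output is split into `num_modulation_parameters` equal chunks. A pure-Python sketch of the chunking step (the conditioning MLP itself is elided; the helper name is hypothetical):

```python
def chunk_modulations(out, num_modulation_parameters=6):
    """Split the (num_modulation_parameters * dim)-vector produced by the
    conditioning network into equal-sized modulation vectors
    (shift/scale/gate for the MSA and MLP sublayers)."""
    dim = len(out) // num_modulation_parameters
    return [out[i * dim:(i + 1) * dim]
            for i in range(num_modulation_parameters)]
```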

ada_ln_modulate(x, shift, scale) staticmethod

Applies adaLN modulation to the input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | The input tensor of shape (bsz, seq_len, dim). | required |
| shift | Tensor | The shift parameter tensor of shape (bsz, 1, dim). | required |
| scale | Tensor | The scale parameter tensor of shape (bsz, 1, dim). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The modulated output tensor of shape (bsz, seq_len, dim). |
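The modulation is presumably the standard DiT formula `x * (1 + scale) + shift`; a scalar-per-feature sketch (an assumption based on common adaLN-Zero implementations, not a verbatim copy of this module):

```python
def ada_ln_modulate(x, shift, scale):
    """Apply adaLN modulation elementwise: x * (1 + scale) + shift.

    The +1 means a zero-initialized scale leaves the input unchanged, which
    is how adaLN-Zero blocks start out as near-identity maps.
    """
    return [v * (1.0 + s) + sh for v, s, sh in zip(x, scale, shift)]
```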

DDiTLayer

Bases: Module

One layer of DDiT.

It consists of a multi-head self-attention layer followed by a feedforward layer with adaLN and gating in between.

__init__(d_model, nhead, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, d_cond=None, force_flash_attn=False)

Initialize the DDiTLayer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| d_model | int | the dimension of the input. | required |
| nhead | int | the number of attention heads. | required |
| dim_feedforward | Optional[int] | the hidden size of the MLP/feedforward layer. | None |
| dropout | float | the dropout rate. | 0.1 |
| d_cond | Optional[int] | the dimension of the conditioning input. | None |

forward(x, c, attention_mask, positions=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | the input tensor of shape (bsz, seq_len, dim). | required |
| c | Tensor | the conditioning input of shape (bsz, cond_dim). | required |
| attention_mask | Tensor | the attention mask of shape (bsz, seq_len), which is True for non-padding tokens. | required |

apply_rotary_pos_emb(x, positions=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | the input tensor of shape (batch_size, seq_len, num_heads, dim). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The tensor with rotary position embeddings applied to the first dim/2 of the last dimension. |
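A pure-Python sketch of partial rotary embedding on one head's feature vector (the interleaved pairing and frequency base of 10000 are assumptions from the standard RoPE formulation; the real module caches the cos/sin tables and shares them across layers):

```python
import math

def apply_rotary(x, pos, rot_dim):
    """Rotate the first rot_dim entries of vector x by position-dependent
    angles (RoPE); the remaining entries pass through unchanged.

    Pairs (x[2i], x[2i+1]) are rotated by pos * theta_i with geometrically
    spaced theta_i -- a sketch of the standard rotary scheme.
    """
    out = list(x)
    for i in range(rot_dim // 2):
        theta = pos * (10000.0 ** (-2.0 * i / rot_dim))
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out
```

Because each pair is a pure rotation, norms are preserved and position 0 is the identity.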

DDiTLayerList

Bases: ModuleList

A module list of DDiT blocks that share the rotary cache for the rotary embeddings.

DDitFinalLayer

Bases: Module

forward(x, c)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | the input tensor of shape (bsz, seq_len, dim). | required |
| c | Tensor | the conditioning input of shape (bsz, cond_dim). | required |

add_bias_apply_dropout_scale(x, bias=None, dropout=0.0, scale=None, residual=None, training=True)

Adds bias, applies dropout, scales, and adds residual.

TODO: Consider creating a fused implementation using jit and two wrappers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | The input tensor of shape (bsz, seq_len, dim). | required |
| bias | Optional[Tensor] | The bias tensor of shape (bsz, 1, dim). | None |
| dropout | float | The dropout rate. | 0.0 |
| scale | Optional[Tensor] | The scale tensor of shape (bsz, 1, dim). | None |
| residual | Optional[Tensor] | The residual tensor of shape (bsz, seq_len, dim). | None |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The output tensor of shape (bsz, seq_len, dim). |
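A scalar sketch of the helper, assuming the operation order stated in the docstring (bias, then dropout, then scale, then residual) and standard inverted-dropout semantics; each step is skipped when its argument is None:

```python
import random

def add_bias_apply_dropout_scale(x, bias=None, dropout=0.0, scale=None,
                                 residual=None, training=True):
    """Sketch: out = residual + scale * dropout(x + bias), each step optional.

    Dropout zeroes the value with probability `dropout` and rescales
    survivors by 1/(1 - dropout), matching inverted-dropout semantics;
    it is a no-op when training is False.
    """
    if bias is not None:
        x = x + bias
    if training and dropout > 0.0:
        x = 0.0 if random.random() < dropout else x / (1.0 - dropout)
    if scale is not None:
        x = x * scale
    if residual is not None:
        x = x + residual
    return x
```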