# xlm.modules.ddit_simple

## LayerNormAndScale

Bases: Module

Performs layer normalization followed by a learnable scale (no bias term).
### __init__(dim, eps=1e-05)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | The dimension of the input. | required |
### forward(x)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor of shape (bsz, seq_len, dim). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The normalized and scaled output tensor of shape (bsz, seq_len, dim). |
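The internals are not shown in this reference, so the following is a minimal sketch of what such a module might look like, assuming it wraps `torch.nn.functional.layer_norm` with a learnable weight and no bias (the attribute names here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNormAndScale(nn.Module):
    """Sketch: LayerNorm with a learnable scale and no bias term."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Scale parameter only; no bias is created.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the last dimension, then apply the learned scale.
        return F.layer_norm(x, self.weight.shape, eps=self.eps) * self.weight
```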
## TimestepEmbedder

Bases: Module

Embeds scalar timesteps into vector representations.

### __init__(hidden_size, frequency_embedding_size=256, max_period=10000)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_size | int | The size of the hidden layer and of the MLP output. | required |
| frequency_embedding_size | int | The size of the frequency embedding layer. | 256 |
| max_period | int | The maximum period of the sinusoidal embeddings. | 10000 |
### forward(t)

Embeds scalar timesteps into vector representations.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t | Tensor | A 1-D tensor of bsz indices, one per batch element. These may be fractional. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | A (bsz, hidden_size) tensor of positional embeddings. |
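A hedged sketch of the standard DiT-style implementation this signature suggests: a sinusoidal frequency embedding followed by a small MLP. The `timestep_embedding` helper and the SiLU MLP are assumptions, not taken from this module:

```python
import math
import torch
import torch.nn as nn

def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Sinusoidal embedding of (possibly fractional) timesteps; assumed helper."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class TimestepEmbedder(nn.Module):
    """Sketch: sinusoidal frequencies followed by a two-layer MLP."""

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256,
                 max_period: int = 10000):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period)
        return self.mlp(freq)  # (bsz, hidden_size)
```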
## LabelEmbedder

Bases: Module

Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | The number of classes. | required |
| cond_size | int | The size of the conditioning input. | required |
| label_dropout | Optional[float] | The dropout rate for class labels during training. | None |

Attributes:

| Name | Type | Description |
|---|---|---|
| embedding_table | Embedding | The embedding table for class labels. |
| num_classes | int | The number of classes. |
### drop_labels(labels)

Drops out class labels during training.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| labels | Tensor | The input tensor of class labels of shape (bsz,). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The class labels with some entries dropped by setting them to the extra "missing" label (the last index). |
### forward(labels)

Forward pass of the LabelEmbedder module.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| labels | Tensor | The input tensor of class labels of shape (bsz,). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The embedded vector representations of the class labels. |
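A sketch of how the class plausibly fits together, assuming an embedding table with one extra row reserved for the "missing" label used by `drop_labels` (the exact internals are not shown in this reference):

```python
from typing import Optional
import torch
import torch.nn as nn

class LabelEmbedder(nn.Module):
    """Sketch: embedding table with an extra last row for dropped labels (CFG)."""

    def __init__(self, num_classes: int, cond_size: int,
                 label_dropout: Optional[float] = None):
        super().__init__()
        # Row num_classes is the "missing" label embedding.
        self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
        self.num_classes = num_classes
        self.label_dropout = label_dropout

    def drop_labels(self, labels: torch.Tensor) -> torch.Tensor:
        # Replace a random subset of labels with the missing-label index.
        drop = torch.rand(labels.shape[0], device=labels.device) < self.label_dropout
        return torch.where(drop, torch.full_like(labels, self.num_classes), labels)

    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        if self.training and self.label_dropout:
            labels = self.drop_labels(labels)
        return self.embedding_table(labels)
```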
## AdaLNModulations

Bases: Module

Produces the modulation parameters for AdaLN.

### __init__(cond_dim, dim, num_modulation_parameters=6)

Initializes the AdaLNModulations module.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cond_dim | int | The dimension of the conditioning input. | required |
| dim | int | The hidden size. | required |
| num_modulation_parameters | int | The number of modulation parameter tensors to produce. | 6 |
### forward(c)

Forward pass of the AdaLNModulations module.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| c | Tensor | The conditioning input tensor. | required |

Returns:

| Type | Description |
|---|---|
| List[Tensor] | The modulation parameters for AdaLN; each tensor has shape (bsz, 1, dim). When num_modulation_parameters=6, these are the shift, scale, and gating parameters for the MHA and MLP layers: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp. |
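A hedged sketch of the usual way such a module is built (as in DiT): a SiLU activation and a single linear projection whose output is chunked into `num_modulation_parameters` pieces; the layer names are assumptions:

```python
import torch
import torch.nn as nn

class AdaLNModulations(nn.Module):
    """Sketch: SiLU + Linear producing num_modulation_parameters chunks of size dim."""

    def __init__(self, cond_dim: int, dim: int, num_modulation_parameters: int = 6):
        super().__init__()
        self.num_modulation_parameters = num_modulation_parameters
        self.proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, num_modulation_parameters * dim),
        )

    def forward(self, c: torch.Tensor) -> list:
        # (bsz, cond_dim) -> num_modulation_parameters tensors of shape (bsz, 1, dim)
        chunks = self.proj(c)[:, None].chunk(self.num_modulation_parameters, dim=-1)
        return list(chunks)
```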
### ada_ln_modulate(x, shift, scale)

staticmethod

Applies adaLN modulation to the input tensor.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor of shape (bsz, seq_len, dim). | required |
| shift | Tensor | The shift parameter tensor of shape (bsz, 1, dim). | required |
| scale | Tensor | The scale parameter tensor of shape (bsz, 1, dim). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The modulated output tensor of shape (bsz, seq_len, dim). |
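Assuming the usual DiT convention, the modulation is `x * (1 + scale) + shift`, with the `(bsz, 1, dim)` parameters broadcast over the sequence dimension:

```python
import torch

def ada_ln_modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Standard adaLN modulation: scale and shift broadcast over seq_len.
    # The (1 + scale) form makes zero-initialized scales act as the identity.
    return x * (1 + scale) + shift
```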
## DDiTLayer

Bases: Module

One layer of DDiT. It consists of a multi-head self-attention layer followed by a feedforward layer, with adaLN and gating in between.

### __init__(d_model, nhead, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, d_cond=None, force_flash_attn=False)

Initializes the DDiTLayer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | The dimension of the input. | required |
| nhead | int | The number of attention heads. | required |
| d_cond | Optional[int] | The dimension of the conditioning input. | None |
| dim_feedforward | Optional[int] | The hidden size of the MLP/feedforward layer. | None |
| dropout | float | The dropout rate. | 0.1 |
### forward(x, c, attention_mask, positions=None)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor of shape (bsz, seq_len, dim). | required |
| c | Tensor | The conditioning input of shape (bsz, cond_dim). | required |
| attention_mask | Tensor | The attention mask of shape (bsz, seq_len); True for non-padding tokens. | required |
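The data flow through the layer can be sketched as a free function; `norm1`, `norm2`, `attn`, and `mlp` are hypothetical stand-ins for the layer's submodules, and `modulations` stands for an `AdaLNModulations` instance returning the six parameters described above:

```python
import torch

def ddit_layer_forward(x, c, modulations, norm1, norm2, attn, mlp):
    """Sketch of the DDiT layer flow: adaLN-modulated attention and MLP,
    each gated before being added back to the residual stream."""
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulations(c)
    h = norm1(x) * (1 + scale_msa) + shift_msa   # adaLN before attention
    x = x + gate_msa * attn(h)                   # gated residual connection
    h = norm2(x) * (1 + scale_mlp) + shift_mlp   # adaLN before the MLP
    x = x + gate_mlp * mlp(h)
    return x
```

With zero-initialized gates, the layer starts out as the identity map, which is the usual DiT initialization strategy.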
### apply_rotary_pos_emb(x, positions=None)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor of shape (batch_size, seq_len, num_heads, dim). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The tensor with rotary position embeddings applied to the first dim/2 of the last dimension. |
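A sketch of partial rotary application under stated assumptions: only the first `dim/2` channels are rotated and the rest pass through, with `cos`/`sin` precomputed (e.g. from the shared rotary cache mentioned below) and broadcastable against the rotated half. The helper name and the exact pairing convention are illustrative:

```python
import torch

def apply_rotary_half(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (bsz, seq_len, num_heads, dim). Rotate only the first dim/2 channels.
    rot, passthrough = x.chunk(2, dim=-1)
    x1, x2 = rot.chunk(2, dim=-1)
    # 2-D rotation of each (x1, x2) pair by the position-dependent angle.
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, passthrough], dim=-1)
```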
## DDiTLayerList

Bases: ModuleList

A module list of DDiT blocks that share the rotary cache for the rotary embeddings.
## DDitFinalLayer

Bases: Module

### forward(x, c)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor of shape (bsz, seq_len, dim). | required |
| c | Tensor | The conditioning input of shape (bsz, cond_dim). | required |
## add_bias_apply_dropout_scale(x, bias=None, dropout=0.0, scale=None, residual=None, training=True)

Adds bias, applies dropout, scales, and adds the residual.

TODO: Consider creating a fused implementation using jit and two wrappers.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | The input tensor of shape (bsz, seq_len, dim). | required |
| bias | Optional[Tensor] | The bias tensor of shape (bsz, 1, dim). | None |
| dropout | float | The dropout rate. | 0.0 |
| scale | Optional[Tensor] | The scale tensor of shape (bsz, 1, dim). | None |
| residual | Optional[Tensor] | The residual tensor of shape (bsz, seq_len, dim). | None |
| training | bool | Whether dropout is active. | True |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (bsz, seq_len, dim). |
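An unfused reference sketch matching the documented semantics (the real version may fuse these steps, per the TODO above):

```python
from typing import Optional
import torch
import torch.nn.functional as F

def add_bias_apply_dropout_scale(
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dropout: float = 0.0,
    scale: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None,
    training: bool = True,
) -> torch.Tensor:
    # Apply each optional step in the documented order:
    # bias -> dropout -> scale -> residual.
    if bias is not None:
        x = x + bias
    x = F.dropout(x, p=dropout, training=training)
    if scale is not None:
        x = x * scale
    if residual is not None:
        x = x + residual
    return x
```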