xlm.modules.rotary_transformer
Modules for a simple transformer decoder that uses rotary embeddings for positional encoding.
RotaryTransformerLayer
Bases: Module
One layer of DDiT.
It consists of a multi-head self-attention layer followed by a feedforward layer with normalization and gating in between.
__init__(d_model, nhead, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, force_flash_attn=False)
Initialize the DDiTBlock.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | the dimension of the input. | required |
| nhead | int | the number of attention heads. | required |
| mlp_ratio | | the ratio of the hidden size of the MLP/feedforward layer to the input size. | required |
| dropout | float | the dropout rate. | 0.1 |
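The pre-norm attention-plus-feedforward structure this layer describes can be sketched in plain numpy. This is a minimal illustration of the sub-layer ordering and residual connections only, not the library's implementation: multi-head splitting, gating, dropout, masking, and the rotary embeddings are all omitted, and the projection weights here are made-up illustrative values.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last dimension (a stand-in for nn.LayerNorm).
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def self_attention(x):
    # Single-head attention with identity projections, for shape illustration only.
    scores = x @ x.transpose(0, 2, 1) / np.sqrt(x.shape[-1])
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)
    return weights @ x

def block(x, w1, w2):
    # Pre-norm attention sub-layer with residual connection.
    x = x + self_attention(layer_norm(x))
    # Pre-norm feedforward sub-layer with residual connection
    # ('relu' activation, matching the signature's default).
    h = np.maximum(layer_norm(x) @ w1, 0.0)
    return x + h @ w2

bsz, seq_len, d_model, d_ff = 2, 5, 8, 32
rng = np.random.default_rng(0)
x = rng.standard_normal((bsz, seq_len, d_model))
out = block(x,
            rng.standard_normal((d_model, d_ff)) * 0.1,
            rng.standard_normal((d_ff, d_model)) * 0.1)
print(out.shape)  # (2, 5, 8) — the layer preserves the input shape
```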
forward(inp, attention_mask, positions=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inp | | the input tensor of shape (bsz, seq_len, dim). | required |
| attention_mask | Tensor | the attention mask of shape (bsz, seq_len), which is True for non-padding tokens. It can also be of shape (bsz, seq_len (query), seq_len (key-value)), where the mask indicates which tokens are valid in the context. | required |
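The two accepted mask shapes can be built as follows. This is a hypothetical example (the batch and lengths are made up), showing a 2-D padding mask and a 3-D per-query mask that additionally encodes a causal constraint:

```python
import numpy as np

# Hypothetical batch: two sequences of length 5, the second padded after 3 tokens.
lengths = np.array([5, 3])
seq_len = 5

# 2-D mask of shape (bsz, seq_len): True for non-padding tokens.
mask_2d = np.arange(seq_len)[None, :] < lengths[:, None]
print(mask_2d.shape)  # (2, 5)

# 3-D mask of shape (bsz, seq_len (query), seq_len (key-value)): query i may
# attend to key j only when token j is valid and j <= i (causal decoding).
causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
mask_3d = mask_2d[:, None, :] & causal[None, :, :]
print(mask_3d.shape)  # (2, 5, 5)
```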
apply_rotary_pos_emb(x, positions=None)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | | the input tensor of shape (batch_size, seq_len, num_heads, dim). | required |

Returns:

| Type | Description |
|---|---|
| | The tensor with rotary position embeddings applied to the first dim/2 of the last dimension. |
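The "first dim/2" behavior can be sketched in numpy: the leading half of the last axis is rotated pairwise by position-dependent angles, and the trailing half passes through unchanged. This is a sketch only; the library's frequency base, channel pairing, and cache layout may differ.

```python
import numpy as np

def apply_rotary_half(x, positions=None, base=10000.0):
    # Apply rotary position embeddings to the first dim // 2 channels of the
    # last axis; the remaining channels pass through unchanged (sketch).
    b, s, h, d = x.shape
    rot_dim = d // 2
    if positions is None:
        positions = np.arange(s)
    inv_freq = base ** (-np.arange(0, rot_dim, 2) / rot_dim)
    ang = positions[:, None] * inv_freq[None, :]      # (seq, rot_dim / 2)
    cos = np.cos(ang)[None, :, None, :]
    sin = np.sin(ang)[None, :, None, :]
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]       # interleaved channel pairs
    out = np.empty_like(x_rot)
    out[..., 0::2] = x1 * cos - x2 * sin              # 2-D rotation per pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return np.concatenate([out, x_pass], axis=-1)

x = np.random.default_rng(0).standard_normal((2, 4, 3, 8))
y = apply_rotary_half(x)
print(y.shape)                              # (2, 4, 3, 8)
print(np.allclose(y[..., 4:], x[..., 4:]))  # True — second half untouched
```

Because the rotation angle at position 0 is zero, the first token's features are left unchanged, which is a handy sanity check.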
RotaryTransformerLayerList
Bases: ModuleList
A module list of DDiT blocks that share the rotary cache for the rotary embeddings.
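Sharing a rotary cache means the cos/sin tables are computed once and reused by every block in the list, instead of being rebuilt per layer. A minimal sketch of the idea (the class name `RotaryCache` and its layout are assumptions, not the library's API):

```python
import numpy as np

class RotaryCache:
    # Cos/sin tables computed once; every layer in the list holds a reference
    # to the same object, so nothing is recomputed per block (sketch only).
    def __init__(self, dim, max_len, base=10000.0):
        inv_freq = base ** (-np.arange(0, dim, 2) / dim)
        ang = np.arange(max_len)[:, None] * inv_freq[None, :]
        self.cos, self.sin = np.cos(ang), np.sin(ang)

cache = RotaryCache(dim=8, max_len=16)
layers = [("layer", cache) for _ in range(4)]  # all four share one cache
print(all(c is cache for _, c in layers))      # True
print(cache.cos.shape)                         # (16, 4)
```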
RotaryTransformerFinalLayer
Bases: Module
Simple unembedding layer with optional layer norm.
forward(x)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | the input tensor of shape (bsz, seq_len, dim). | required |
RotaryTransformerFinalLayerForClassification
Bases: Module
Feedforward layer with pre-norm and residual connection followed by a linear layer for classification.
forward(x)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | the input tensor of shape (bsz, seq_len, dim). | required |
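The "feedforward with pre-norm and residual, then a linear classifier" pattern can be sketched in numpy. This is an illustration of the dataflow only, with made-up weights; dropout and the library's exact hidden size are not reproduced:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def classification_head(x, w_in, w_out, w_cls):
    # Pre-norm feedforward sub-layer with a residual connection ...
    x = x + np.maximum(layer_norm(x) @ w_in, 0.0) @ w_out
    # ... followed by a linear projection to per-token class logits.
    return x @ w_cls

rng = np.random.default_rng(0)
bsz, seq_len, dim, n_cls = 2, 5, 8, 3
x = rng.standard_normal((bsz, seq_len, dim))
logits = classification_head(
    x,
    rng.standard_normal((dim, 4 * dim)) * 0.1,   # hidden size is an assumption
    rng.standard_normal((4 * dim, dim)) * 0.1,
    rng.standard_normal((dim, n_cls)) * 0.1,
)
print(logits.shape)  # (2, 5, 3) — one logit vector per token
```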
add_bias_apply_dropout_scale(x, bias=None, dropout=0.0, scale=None, residual=None, training=True)
Adds bias, applies dropout, scales, and adds residual.

TODO: Consider creating fused implementation using jit and two wrappers

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | | The input tensor of shape (bsz, seq_len, dim). | required |
| bias | | The bias tensor of shape (bsz, 1, dim). | None |
| dropout | | The dropout rate. | 0.0 |
| scale | | The scale tensor of shape (bsz, 1, dim). | None |
| residual | | The residual tensor of shape (bsz, seq_len, dim). | None |

Returns:

| Type | Description |
|---|---|
| Tensor | The output tensor of shape (bsz, seq_len, dim). |
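A numpy sketch of the documented operation order (bias, then dropout, then scale, then residual); the `rng` argument is an addition for reproducibility, and the inverted-dropout scaling is a standard choice assumed here, not confirmed by the source:

```python
import numpy as np

def add_bias_apply_dropout_scale(x, bias=None, dropout=0.0, scale=None,
                                 residual=None, training=True, rng=None):
    # Order per the docstring: add bias, apply dropout, scale, add residual.
    if bias is not None:
        x = x + bias                       # bias broadcasts from (bsz, 1, dim)
    if training and dropout > 0.0:
        rng = rng or np.random.default_rng()
        keep = rng.random(x.shape) >= dropout
        x = x * keep / (1.0 - dropout)     # inverted dropout (assumption)
    if scale is not None:
        x = x * scale                      # scale broadcasts from (bsz, 1, dim)
    if residual is not None:
        x = x + residual
    return x

x = np.ones((2, 3, 4))
out = add_bias_apply_dropout_scale(
    x,
    bias=np.full((2, 1, 4), 0.5),
    scale=np.full((2, 1, 4), 2.0),
    residual=np.ones((2, 3, 4)),
    dropout=0.0,
)
print(out[0, 0, 0])  # (1 + 0.5) * 2 + 1 = 4.0
```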