
xlm.modules.rotary_transformer

Modules for a simple transformer decoder that uses rotary embeddings for positional encoding.

RotaryTransformerLayer

Bases: Module

One layer of DDiT.

It consists of a multi-head self-attention layer followed by a feedforward layer with normalization and gating in between.

__init__(d_model, nhead, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, force_flash_attn=False)

Initialize the DDiTBlock.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `d_model` | `int` | the dimension of the input. | *required* |
| `nhead` | `int` | the number of attention heads. | *required* |
| `dim_feedforward` | `int` | the hidden size of the MLP/feedforward layer. | `None` |
| `dropout` | `float` | the dropout rate. | `0.1` |
| `activation` | `str` | the activation function of the feedforward layer. | `'relu'` |
| `layer_norm_eps` | `float` | the epsilon used by the layer normalizations. | `1e-05` |
| `force_flash_attn` | `bool` | whether to force the use of flash attention. | `False` |
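As a rough sketch of the structure described above, the block can be thought of as pre-norm self-attention followed by a feedforward sub-layer with residual connections. The class name, the 4× feedforward widening, and the use of `nn.MultiheadAttention` here are illustrative assumptions; the real layer also applies rotary embeddings inside attention, which this sketch omits.

```python
import torch
import torch.nn as nn


class RotaryTransformerLayerSketch(nn.Module):
    """Illustrative pre-norm block: self-attention, then feedforward,
    each with layer norm and a residual connection (rotary embeddings
    inside attention are omitted in this sketch)."""

    def __init__(self, d_model, nhead, dim_feedforward=None,
                 dropout=0.1, layer_norm_eps=1e-5):
        super().__init__()
        # Assumed convention when no feedforward size is given.
        dim_feedforward = dim_feedforward or 4 * d_model
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        # Note: nn.MultiheadAttention's key_padding_mask marks padding
        # with True, the opposite of the attention_mask convention
        # documented for this module.
        self.attn = nn.MultiheadAttention(d_model, nhead,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp, key_padding_mask=None):
        h = self.norm1(inp)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask)
        x = inp + self.dropout(attn_out)
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x


layer = RotaryTransformerLayerSketch(d_model=64, nhead=4)
out = layer(torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 10, 64])
```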

forward(inp, attention_mask, positions=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inp` | `Tensor` | the input tensor of shape (bsz, seq_len, dim). | *required* |
| `attention_mask` | `Tensor` | the attention mask of shape (bsz, seq_len), which is True for non-padding tokens. It can also be of shape (bsz, seq_len (query), seq_len (key-value)), where the mask indicates which tokens are valid in the context. | *required* |
| `positions` | `Tensor` | the positions of the tokens, used for the rotary embeddings. | `None` |
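The two mask shapes can be constructed like this (a generic sketch, not tied to this library's API):

```python
import torch

# Boolean padding mask for a batch of two sequences with lengths 3 and 5,
# padded to seq_len = 5; True marks real (non-padding) tokens.
lengths = torch.tensor([3, 5])
seq_len = 5
attention_mask = torch.arange(seq_len)[None, :] < lengths[:, None]
# attention_mask:
# [[ True,  True,  True, False, False],
#  [ True,  True,  True,  True,  True]]

# The 3-D form gives each query row its own set of valid key/value
# positions, e.g. a causal mask combined with the padding mask:
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
mask_3d = attention_mask[:, None, :] & causal[None, :, :]
print(mask_3d.shape)  # torch.Size([2, 5, 5])
```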

apply_rotary_pos_emb(x, positions=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | the input tensor of shape (batch_size, seq_len, num_heads, dim). | *required* |
| `positions` | `Tensor` | the positions of the tokens. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The tensor with rotary position embeddings applied to the first dim/2 of the last dimension. |
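A hypothetical re-implementation of this behavior, assuming the standard RoPE formulation and that, as documented, only the first dim/2 features of the last dimension are rotated while the rest pass through unchanged:

```python
import torch


def apply_rotary_half(x, positions=None, base=10000.0):
    """Sketch: rotate the first dim // 2 features of the last dimension
    with rotary position embeddings; pass the remainder through."""
    bsz, seq_len, num_heads, dim = x.shape
    rot = dim // 2  # only this slice is rotated
    if positions is None:
        positions = torch.arange(seq_len, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (torch.arange(0, rot, 2).float() / rot))
    freqs = positions.float()[:, None] * inv_freq[None, :]  # (seq_len, rot // 2)
    cos = freqs.cos()[None, :, None, :]
    sin = freqs.sin()[None, :, None, :]
    x1, x2 = x[..., : rot // 2], x[..., rot // 2 : rot]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x[..., rot:]], dim=-1)


x = torch.randn(2, 6, 4, 16)
out = apply_rotary_half(x)
print(out.shape)  # torch.Size([2, 6, 4, 16])
```

With all positions equal to zero the rotation is the identity, which is a quick sanity check on any implementation.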

RotaryTransformerLayerList

Bases: ModuleList

A module list of DDiT blocks that share the rotary cache for the rotary embeddings.
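Sharing a single rotary cache means the cos/sin tables are computed once and reused by every block. A minimal sketch of what that might look like (the class names, cache contents, and wiring are assumptions, not the library's actual design):

```python
import torch
import torch.nn as nn


class RotaryCache(nn.Module):
    """Precomputed cos/sin tables, assumed to be what the layers share."""

    def __init__(self, head_dim, max_len=2048, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_len).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("cos", freqs.cos())
        self.register_buffer("sin", freqs.sin())


class SharedCacheLayerList(nn.ModuleList):
    """A ModuleList that hands every layer a reference to one cache."""

    def __init__(self, layers, head_dim):
        super().__init__(layers)
        cache = RotaryCache(head_dim)
        for layer in self:
            layer.rotary_cache = cache  # all layers point at the same cache


layers = SharedCacheLayerList([nn.Identity() for _ in range(3)], head_dim=16)
print(layers[0].rotary_cache is layers[2].rotary_cache)  # True
```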

RotaryTransformerFinalLayer

Bases: Module

Simple unembedding layer with optional layer norm.

forward(x)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | the input tensor of shape (bsz, seq_len, dim). | *required* |

RotaryTransformerFinalLayerForClassification

Bases: Module

Feedforward layer with pre-norm and residual connection followed by a linear layer for classification.

forward(x)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | the input tensor of shape (bsz, seq_len, dim). | *required* |
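The head described above might be sketched as follows; the class name, activation, and hidden-size choice are assumptions for illustration:

```python
import torch
import torch.nn as nn


class FinalClassificationHeadSketch(nn.Module):
    """Sketch: pre-norm feedforward with a residual connection,
    followed by a linear classification projection."""

    def __init__(self, d_model, num_classes, hidden=None):
        super().__init__()
        hidden = hidden or 4 * d_model  # assumed widening factor
        self.norm = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
        )
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        x = x + self.ff(self.norm(x))  # pre-norm + residual
        return self.classifier(x)


head = FinalClassificationHeadSketch(d_model=32, num_classes=5)
logits = head(torch.randn(2, 7, 32))
print(logits.shape)  # torch.Size([2, 7, 5])
```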

add_bias_apply_dropout_scale(x, bias=None, dropout=0.0, scale=None, residual=None, training=True)

Adds bias, applies dropout, scales, and adds residual, in that order.

TODO: Consider creating a fused implementation using jit and two wrappers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The input tensor of shape (bsz, seq_len, dim). | *required* |
| `bias` | `Tensor` | The bias tensor of shape (bsz, 1, dim). | `None` |
| `dropout` | `float` | The dropout rate. | `0.0` |
| `scale` | `Tensor` | The scale tensor of shape (bsz, 1, dim). | `None` |
| `residual` | `Tensor` | The residual tensor of shape (bsz, seq_len, dim). | `None` |
| `training` | `bool` | whether dropout is applied (training mode). | `True` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The output tensor of shape (bsz, seq_len, dim). |
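An unfused reference implementation matching the documented order of operations (a sketch under the stated semantics, not the library's code):

```python
import torch
import torch.nn.functional as F


def add_bias_apply_dropout_scale(x, bias=None, dropout=0.0,
                                 scale=None, residual=None, training=True):
    # Order follows the docstring: bias -> dropout -> scale -> residual,
    # with each step optional.
    if bias is not None:
        x = x + bias
    x = F.dropout(x, p=dropout, training=training)
    if scale is not None:
        x = x * scale
    if residual is not None:
        x = x + residual
    return x


x = torch.ones(2, 4, 8)
out = add_bias_apply_dropout_scale(
    x,
    bias=torch.full((2, 1, 8), 2.0),
    scale=torch.full((2, 1, 8), 0.5),
    residual=x,
    training=False,  # dropout disabled, so the result is deterministic
)
print(out[0, 0, 0].item())  # (1 + 2) * 0.5 + 1 = 2.5
```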