arlm.model_arlm

RotaryTransformerARLMModel

Bases: Module, Model

Rotary embedding based transformer decoder for auto-regressive language modeling.
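To illustrate what "rotary embedding based" means here: rotary position embeddings (RoPE) encode position by rotating consecutive pairs of query/key dimensions through a position-dependent angle, instead of adding a learned position vector. Below is a minimal pure-Python sketch of the mechanism; the base of 10000.0 and the even/odd pairing follow the common RoPE convention and are assumptions, not details confirmed by this class.

```python
import math

def rotary_embed(vec, pos, base=10000.0):
    """Rotate consecutive (even, odd) dimension pairs of `vec` by a
    position-dependent angle, as in rotary position embeddings (RoPE)."""
    dim = len(vec)
    out = list(vec)
    for i in range(0, dim, 2):
        # Lower-index pairs rotate faster; higher pairs encode coarser position.
        theta = pos * base ** (-i / dim)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out[i] = x * cos_t - y * sin_t
        out[i + 1] = x * sin_t + y * cos_t
    return out

# Position 0 is the identity rotation, and any position preserves the norm:
# the rotation encodes *where* a token sits without rescaling *what* it is.
v = [1.0, 0.0, 0.5, -0.5]
rotated = rotary_embed(v, pos=3)
```

In the actual model, this rotation would be applied inside attention to the query/key vectors, over the first `rotary_emb_dim` dimensions of each head.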

__init__(num_embeddings, d_model, num_layers, nhead, padding_idx=0, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, rotary_emb_dim=64, max_length=1024, force_flash_attn=False, final_layer_without_normalization=False)

Initialize the ARLM transformer model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_embeddings | int | Size of the vocabulary. | required |
| d_model | int | Dimension of the model. | required |
| num_layers | int | Number of transformer layers. | required |
| nhead | int | Number of attention heads. | required |
| padding_idx | int | Index of the padding token. | 0 |
| dim_feedforward | Optional[int] | Dimension of the feedforward network. | None |
| dropout | float | Dropout rate. | 0.1 |
| activation | str | Activation function. | 'relu' |
| layer_norm_eps | float | Epsilon for layer normalization. | 1e-05 |
| rotary_emb_dim | int | Dimension of the rotary embeddings. | 64 |
| max_length | int | Maximum sequence length. | 1024 |
| force_flash_attn | bool | Whether to force the use of flash attention. | False |
| final_layer_without_normalization | bool | Whether to omit normalization in the final layer. | False |

forward(x_t, attention_mask=None, positions=None, token_type_ids=None)

Forward pass of the ARLM model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_t | Integer[Tensor, '*batch seq_len'] | The input tokens of shape (*batch, seq_len). | required |
| attention_mask | Optional[Bool[Tensor, '*batch seq_len seq_len']] | The attention mask: shape (*batch, seq_len, seq_len) for a full attention matrix, or (*batch, seq_len) for a simple padding mask. True marks non-padding tokens. | None |
| positions | Optional[Integer[Tensor, '*batch seq_len']] | The positions of the tokens, of shape (*batch, seq_len). | None |
| token_type_ids | Optional[Integer[Tensor, '*batch seq_len']] | The token type ids, of shape (*batch, seq_len). | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| vocab_logits | Float[Tensor, '*batch seq_len vocab_size'] | The vocabulary logits of shape (*batch, seq_len, vocab_size). |
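The simple (*batch, seq_len) form of attention_mask is just a padding mask: True where a token is real, False at padding positions. A sketch of building one from token ids, assuming the constructor default padding_idx=0:

```python
def make_padding_mask(token_ids, padding_idx=0):
    """Boolean mask following the documented convention:
    True for non-padding tokens, False at padding positions."""
    return [[tok != padding_idx for tok in seq] for seq in token_ids]

# Two sequences of length 5, right-padded with padding_idx=0.
batch = [
    [12, 7, 99, 0, 0],
    [5, 0, 0, 0, 0],
]
mask = make_padding_mask(batch)
# mask[0] -> [True, True, True, False, False]
```

In practice this would be a Bool tensor of shape (*batch, seq_len); the full (*batch, seq_len, seq_len) form instead specifies, per query position, exactly which key positions may be attended to.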

get_named_params_for_weight_decay()

Get the named parameters that should receive weight decay (all parameters except biases and layer-norm parameters).

get_named_params_for_no_weight_decay()

Get the named parameters that should not receive weight decay (biases and layer-norm parameters).
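These two accessors exist so an optimizer can apply weight decay selectively; biases and layer-norm parameters are conventionally excluded because decaying them tends to hurt rather than regularize. A sketch of the usual name-based split; matching on ".bias" and "norm" is an assumption about the naming scheme, not this class's actual logic:

```python
def split_params_for_weight_decay(named_params):
    """Partition (name, param) pairs the way the two accessors above do:
    decay for weight matrices, no decay for biases and layer norms."""
    decay, no_decay = {}, {}
    for name, param in named_params:
        if name.endswith(".bias") or "norm" in name:
            no_decay[name] = param
        else:
            decay[name] = param
    return decay, no_decay

# Stand-in names mimicking typical transformer parameter naming.
named = [
    ("layers.0.self_attn.in_proj.weight", "W0"),
    ("layers.0.self_attn.in_proj.bias", "b0"),
    ("layers.0.norm1.weight", "g0"),
    ("embedding.weight", "E"),
]
decay, no_decay = split_params_for_weight_decay(named)
```

With real modules, the two groups would typically become AdamW parameter groups: one with weight_decay set, the other with weight_decay=0.0.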