arlm.model_arlm
RotaryTransformerARLMModel
Bases: Module, Model
Rotary-embedding-based transformer decoder for auto-regressive language modeling.
__init__(num_embeddings, d_model, num_layers, nhead, padding_idx=0, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, rotary_emb_dim=64, max_length=1024, force_flash_attn=False, final_layer_without_normalization=False)
Initialize the ARLM transformer model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_embeddings` | `int` | Size of the vocabulary. | *required* |
| `d_model` | `int` | Dimension of the model. | *required* |
| `num_layers` | `int` | Number of transformer layers. | *required* |
| `nhead` | `int` | Number of attention heads. | *required* |
| `padding_idx` | `int` | Index of the padding token. | `0` |
| `dim_feedforward` | `Optional[int]` | Dimension of the feedforward network. | `None` |
| `dropout` | `float` | Dropout rate. | `0.1` |
| `activation` | `str` | Activation function. | `'relu'` |
| `layer_norm_eps` | `float` | Epsilon for layer normalization. | `1e-05` |
| `rotary_emb_dim` | `int` | Dimension of rotary embeddings. | `64` |
| `max_length` | `int` | Maximum sequence length. | `1024` |
| `force_flash_attn` | `bool` | Whether to force flash attention. | `False` |
| `final_layer_without_normalization` | `bool` | Whether to use a final layer without normalization. | `False` |
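To illustrate what `rotary_emb_dim` controls, here is a minimal NumPy sketch of the standard rotary position embedding (RoPE) applied to one head's vectors. This shows the general technique, not this class's exact internals; the function name and `base` constant are illustrative assumptions.

```python
import numpy as np

def rotary_embed(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Apply rotary position embeddings to x of shape (seq_len, dim).

    dim must be even; each pair (x[2i], x[2i+1]) is rotated by an angle
    that grows with position and shrinks with i (standard RoPE scheme;
    a sketch, not this model's implementation).
    """
    seq_len, dim = x.shape
    assert dim % 2 == 0, "rotary dimension must be even"
    inv_freq = 1.0 / base ** (np.arange(0, dim, 2) / dim)  # (dim/2,)
    angles = positions[:, None] * inv_freq[None, :]        # (seq_len, dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

# Rotations preserve norms, and position 0 is the identity:
x = np.random.default_rng(0).standard_normal((5, 64))  # rotary_emb_dim=64
y = rotary_embed(x, np.arange(5))
print(np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1)))  # True
print(np.allclose(y[0], x[0]))  # True
```

Because positions enter only through these rotations, relative offsets between query and key positions determine attention scores, which is what makes RoPE attractive for decoder-only models.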
forward(x_t, attention_mask=None, positions=None, token_type_ids=None)
Forward pass of the ARLM model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x_t` | `Integer[Tensor, '*batch seq_len']` | The input tokens, of shape `(*batch, seq_len)`. | *required* |
| `attention_mask` | `Optional[Bool[Tensor, '*batch seq_len seq_len']]` | The attention mask, of shape `(*batch, seq_len, seq_len)` for a full attention matrix, or `(*batch, seq_len)` for a simple padding mask. `True` marks non-padding tokens. | `None` |
| `positions` | `Optional[Integer[Tensor, '*batch seq_len']]` | The positions of the tokens, of shape `(*batch, seq_len)`. | `None` |
| `token_type_ids` | `Optional[Integer[Tensor, '*batch seq_len']]` | The token type ids, of shape `(*batch, seq_len)`. | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `vocab_logits` | `Float[Tensor, '*batch seq_len vocab_size']` | The vocabulary logits, of shape `(*batch, seq_len, vocab_size)`. |
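The full-matrix form of `attention_mask` can be derived from the simple padding-mask form by combining it with a causal (lower-triangular) constraint. A NumPy sketch under the documented convention (`True` = attendable); the helper name is illustrative:

```python
import numpy as np

def build_causal_attention_mask(padding_mask: np.ndarray) -> np.ndarray:
    """Expand a (batch, seq_len) padding mask (True = real token) into a
    (batch, seq_len, seq_len) mask where query position i may attend to
    key position j iff j <= i and token j is not padding."""
    batch, seq_len = padding_mask.shape
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # j <= i
    return causal[None, :, :] & padding_mask[:, None, :]

# One sequence of length 4 whose last token is padding:
mask = build_causal_attention_mask(np.array([[True, True, True, False]]))
print(mask[0].astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]
```

Note the padded key column is masked out for every query row, while the causal triangle keeps each position from attending to the future.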
get_named_params_for_weight_decay()
Get parameters for weight decay (all parameters except biases and layer-norm parameters).
get_named_params_for_no_weight_decay()
Get parameters for no weight decay (biases and layer-norm parameters).
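A common use of these two accessors is building optimizer parameter groups with different `weight_decay` values. Here is a pure-Python sketch of the conventional name-based split they describe (biases and layer-norm parameters get no decay); the function, filter rules, and toy names are illustrative assumptions, not this class's implementation:

```python
def split_params_for_weight_decay(named_params):
    """Split (name, param) pairs into (decay, no_decay) groups using the
    conventional rule: biases and layer-norm parameters get no weight decay."""
    decay, no_decay = [], []
    for name, param in named_params:
        if name.endswith(".bias") or ".norm" in name or "layer_norm" in name:
            no_decay.append((name, param))
        else:
            decay.append((name, param))
    return decay, no_decay

# Toy names standing in for model.named_parameters():
named = [
    ("embedding.weight", "W_emb"),
    ("layers.0.self_attn.in_proj.weight", "W_qkv"),
    ("layers.0.self_attn.in_proj.bias", "b_qkv"),
    ("layers.0.norm1.weight", "g_ln"),
]
decay, no_decay = split_params_for_weight_decay(named)
print([n for n, _ in decay])     # ['embedding.weight', 'layers.0.self_attn.in_proj.weight']
print([n for n, _ in no_decay])  # ['layers.0.self_attn.in_proj.bias', 'layers.0.norm1.weight']
```

The two lists would typically become two parameter groups passed to an optimizer such as `torch.optim.AdamW`, one with the configured `weight_decay` and one with `weight_decay=0.0`.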