ARLM - Auto-Regressive Language Model for XLM Framework
This package implements the ARLM model with all necessary components:

- Model architecture (model_arlm.py)
- Loss function (loss_arlm.py)
- Predictor for inference (predictor_arlm.py)
- Data module (datamodule_arlm.py)
- Metrics computation (metrics_arlm.py)
- Type definitions (types_arlm.py)
This model was migrated from xlm.lm.arlm to be an external model.
RotaryTransformerARLMModel
Bases: Module, Model
Rotary embedding based transformer decoder for auto-regressive language modeling.
__init__(num_embeddings, d_model, num_layers, nhead, padding_idx=0, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, rotary_emb_dim=64, max_length=1024, force_flash_attn=False, final_layer_without_normalization=False)
Initialize the ARLM transformer model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_embeddings | int | Size of the vocabulary. | required |
| d_model | int | Dimension of the model. | required |
| num_layers | int | Number of transformer layers. | required |
| nhead | int | Number of attention heads. | required |
| padding_idx | int | Index of the padding token. | 0 |
| dim_feedforward | Optional[int] | Dimension of the feedforward network. | None |
| dropout | float | Dropout rate. | 0.1 |
| activation | str | Activation function. | 'relu' |
| layer_norm_eps | float | Epsilon for layer normalization. | 1e-05 |
| rotary_emb_dim | int | Dimension of rotary embeddings. | 64 |
| max_length | int | Maximum sequence length. | 1024 |
| force_flash_attn | bool | Whether to force flash attention. | False |
| final_layer_without_normalization | bool | Whether to skip normalization in the final layer. | False |
forward(x_t, attention_mask=None, positions=None, token_type_ids=None)
Forward pass of the ARLM model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x_t | Integer[Tensor, ' *batch seq_len'] | The input tokens of shape (*batch, seq_len). | required |
| attention_mask | Optional[Bool[Tensor, ' *batch seq_len seq_len']] | The attention mask: shape (batch, seq_len, seq_len) for a full attention matrix, or (batch, seq_len) for a simple padding mask. True for non-padding tokens. | None |
| positions | Optional[Integer[Tensor, ' *batch seq_len']] | The positions of the tokens, of shape (*batch, seq_len). | None |
| token_type_ids | Optional[Integer[Tensor, ' *batch seq_len']] | The token type ids, of shape (*batch, seq_len). | None |

Returns:

| Name | Type | Description |
|---|---|---|
| vocab_logits | Float[Tensor, ' *batch seq_len vocab_size'] | The vocabulary logits of shape (*batch, seq_len, vocab_size). |
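`forward` accepts either a (batch, seq_len) padding mask or a full (batch, seq_len, seq_len) matrix. A minimal sketch, in plain PyTorch, of how a padding mask can be combined with a causal mask into the full form (the helper name is illustrative, not part of this package):

```python
import torch

def build_causal_attention_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """Combine a (batch, seq_len) padding mask (True = real token) with a
    lower-triangular causal mask into a (batch, seq_len, seq_len) matrix."""
    batch, seq_len = padding_mask.shape
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # A query position may attend to a key position only if the key token
    # is not padding and does not lie in the future.
    return causal.unsqueeze(0) & padding_mask.unsqueeze(1)

# One sequence of length 3 whose last token is padding:
mask = build_causal_attention_mask(torch.tensor([[True, True, False]]))
```

Each query row of the result is the causal prefix intersected with the key-side padding mask.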
get_named_params_for_weight_decay()
Get parameters for weight decay (all parameters except biases and layer-norm parameters).
get_named_params_for_no_weight_decay()
Get parameters for no weight decay (biases and layer-norm parameters).
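The two groupings above are typically consumed as separate optimizer parameter groups. A sketch with a stand-in model (the `param.ndim < 2` heuristic and the model are illustrative, not this package's internals):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Mimic the decay/no-decay split: biases and layer-norm parameters
# (all 1-D tensors here) are excluded from weight decay.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if param.ndim < 2 else decay).append(param)

optimizer = torch.optim.AdamW([
    {"params": decay, "weight_decay": 0.01},
    {"params": no_decay, "weight_decay": 0.0},
])
```

Here only the Linear weight matrix receives decay; the Linear bias and both LayerNorm parameters do not.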
ARLMLoss
Bases: LossFunction[ARLMBatch, ARLMLossDict]
Loss function for Auto-Regressive Language Modeling (ARLM).
This loss function implements causal language modeling where the model predicts the next token given the previous tokens. The loss is computed using cross-entropy on the target sequence (which is already shifted in the batch).
For seq2seq tasks, loss is only computed on suffix tokens (non-prompt tokens).
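The mechanics can be sketched with plain PyTorch: cross-entropy over shifted targets, where positions set to -100 (prompt or padding) are ignored. This is an illustrative sketch of the computation, not the package's implementation:

```python
import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)            # (batch, seq_len, vocab)
target_ids = torch.randint(0, vocab_size, (2, 5))  # already shifted by 1
target_ids[0, :2] = -100                           # e.g. prompt/padding positions

# cross_entropy skips targets equal to -100 (its default ignore_index)
loss = F.cross_entropy(logits.reshape(-1, vocab_size), target_ids.reshape(-1))

# Per-token NLLs restricted to the real target positions:
flat_targets = target_ids.reshape(-1)
nlls = F.cross_entropy(
    logits.reshape(-1, vocab_size), flat_targets, reduction="none"
)[flat_targets != -100]
```

The mean of `nlls` equals `loss`, since the mean reduction also averages over non-ignored positions only.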
__init__(model=None, tokenizer=None)
Initialize the ARLM loss function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Optional[ARLMModel] | The ARLM model to use for predictions. | None |
| tokenizer | Optional[Tokenizer] | The tokenizer for processing tokens. | None |
configure(pl_module)
Configure the loss function with the lightning module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pl_module | Harness | The lightning module instance. | required |
__call__(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)
Compute the loss for the given batch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | ARLMBatch | The input batch containing input_ids, attention_mask, and target_ids. | required |
| batch_idx | Optional[int] | The batch index. | None |
| dataloader_idx | Optional[int] | The dataloader index. | None |
| dataloader_name | Optional[str] | The dataloader name. | None |

Returns:

| Type | Description |
|---|---|
| ARLMLossDict | Dictionary containing the loss, batch_loss, and nlls. |
loss_fn(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)
Compute the causal language modeling loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | ARLMBatch | The input batch. | required |
| batch_idx | Optional[int] | The batch index. | None |
| dataloader_idx | Optional[int] | The dataloader index. | None |
| dataloader_name | Optional[str] | The dataloader name. | None |

Returns:

| Type | Description |
|---|---|
| ARLMLossDict | Dictionary containing the loss, batch_loss, and nlls. |
ARLMPredictor
Bases: Module, Predictor[Dict[str, Any], ARLMPredictionDict]
__init__(max_steps, max_length, tokenizer=None, noise_schedule=None, sampling_method='sample', top=1000, p=0.9, model=None)
Constructor for ARLMPredictor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| max_steps | int | Maximum number of prediction steps. | required |
| max_length | int | Maximum sequence length. | required |
| tokenizer | Optional[Tokenizer] | The tokenizer to use. | None |
| noise_schedule | Optional[NoiseSchedule] | Noise schedule (not used in ARLM but kept for interface consistency). | None |
| sampling_method | Literal['sample', 'sample_top_k', 'sample_top_p'] | Sampling method to use. | 'sample' |
| top | int | Top-k parameter for sampling. | 1000 |
| p | float | Top-p parameter for sampling. | 0.9 |
| model | Optional[ARLMModel] | The ARLM model to use for predictions. | None |
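The three sampling methods differ in how they restrict the candidate set before drawing a token. Sketches of the top-k and top-p (nucleus) variants in plain PyTorch, as illustrations of the techniques the `sampling_method` names refer to, not the predictor's actual code:

```python
import torch

def sample_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Sample from the k highest-probability tokens only."""
    topk_vals, topk_idx = logits.topk(k, dim=-1)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx.gather(-1, choice).squeeze(-1)

def sample_top_p(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus sampling: keep the smallest set of tokens whose
    cumulative probability exceeds p."""
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cumprobs = probs.cumsum(dim=-1)
    # Drop tokens once the mass *before* them already exceeds p,
    # so the highest-probability token is always kept.
    sorted_logits[cumprobs - probs > p] = float("-inf")
    choice = torch.multinomial(torch.softmax(sorted_logits, dim=-1), 1)
    return sorted_idx.gather(-1, choice).squeeze(-1)

logits = torch.tensor([[4.0, 3.0, 0.1, -2.0]])
token = sample_top_k(logits, k=2)   # one of the two most likely tokens
```

Plain `'sample'` is the same idea with no restriction: `torch.multinomial(softmax(logits), 1)` over the full vocabulary.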
predict_single_step(step_results, current_step)
Predict the next token in the sequence.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step_results | ARLMStepResults | Current step results containing x, attention_mask, and logits. | required |
| current_step | int | Current prediction step. | required |
| final_step |  | Whether this is the final step. | required |

Returns:

| Type | Description |
|---|---|
| ARLMStepResults | Updated step results with the next token predicted. |
stop(step_results, current_length)
Check if prediction should stop.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| step_results | ARLMStepResults | Current step results. | required |
| current_length | int | Current sequence length. | required |

Returns:

| Type | Description |
|---|---|
| bool | True if prediction should stop, False otherwise. |
decode(results)
Decode the predicted sequence.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| results | ARLMStepResults | Step results containing the predicted sequence. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[List[str], List[str], Integer[Tensor, ' batch seq_len']] | Tuple of (decoded_text, decoded_text_with_special_tokens, token_ids). |
predict(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None, max_len=0)
Predict the complete sequence.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Dict[str, Any] | Input batch containing input_ids and attention_mask. | required |
| batch_idx | Optional[int] | Batch index. | None |
| dataloader_idx | Optional[int] | Dataloader index. | None |
| dataloader_name | Optional[str] | Dataloader name. | None |
| max_len | int | Maximum length for prediction. | 0 |

Returns:

| Type | Description |
|---|---|
| ARLMPredictionDict | Prediction results containing text, token IDs, and attention mask. |
to_dict(batch, preds, batch_idx=None, dataloader_idx=None, dataloader_name=None)
Convert predictions to dictionary format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Dict[str, Any] | Input batch. | required |
| preds | ARLMPredictionDict | Prediction results. | required |
| batch_idx | Optional[int] | Batch index. | None |
| dataloader_idx | Optional[int] | Dataloader index. | None |
| dataloader_name | Optional[str] | Dataloader name. | None |

Returns:

| Type | Description |
|---|---|
| List[Dict[str, Any]] | List of dictionaries containing prediction results. |
DefaultARLMCollator
Bases: Collator
Used for pre-training.
__init__(tokenizer, block_size, noise_schedule, truncate='block', add_eos=False)
Initialize the ARLM collator.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokenizer | Tokenizer | The tokenizer to use. | required |
| block_size | int | Maximum sequence length. | required |
| noise_schedule | NoiseSchedule | Noise schedule (not used in ARLM but kept for interface consistency). | required |
| truncate | Literal['max', 'block', None] | Truncation strategy. | 'block' |
| add_eos | bool | Whether to add an EOS token at the end of the sequence. | False |
__call__(examples)
Collate examples into a batch for ARLM training.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| examples | List[BaseCollatorInput] | List of examples with input_ids. | required |

Returns:

| Type | Description |
|---|---|
| ARLMBatch | ARLMBatch with input_ids, attention_mask, and target_ids. |
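The essence of the collation step can be sketched in plain PyTorch: pad every example to the block size, mark real tokens in the attention mask, and build next-token targets shifted by one with -100 wherever there is no valid successor. This is a minimal illustration of the batch contract, not the collator's actual code:

```python
import torch

def collate(examples, block_size, pad_id=0):
    """Sketch of ARLM pre-training collation."""
    input_ids = torch.full((len(examples), block_size), pad_id, dtype=torch.long)
    attention_mask = torch.zeros(len(examples), block_size, dtype=torch.long)
    for i, ex in enumerate(examples):
        ids = torch.tensor(ex["input_ids"][:block_size])
        input_ids[i, : len(ids)] = ids
        attention_mask[i, : len(ids)] = 1
    # target_ids holds the *next* token at each position; positions with
    # no valid successor (padding, or last real token) are set to -100.
    target_ids = torch.full_like(input_ids, -100)
    target_ids[:, :-1] = input_ids[:, 1:]
    target_ids[attention_mask == 0] = -100
    target_ids[:, :-1][attention_mask[:, 1:] == 0] = -100
    return {"input_ids": input_ids, "attention_mask": attention_mask,
            "target_ids": target_ids}

batch = collate([{"input_ids": [5, 6, 7]}], block_size=5)
```

For the 3-token example above, only the first two positions carry a loss target (the tokens 6 and 7 that follow them).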
ARLMSeq2SeqCollator
__init__(tokenizer, noise_schedule, block_size=None, input_block_size=None, add_bos=None, add_eos=False, truncate='block')
Initialize the ARLM sequence-to-sequence collator.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokenizer | Tokenizer | The tokenizer to use. | required |
| noise_schedule | NoiseSchedule | Noise schedule (not used in ARLM but kept for interface consistency). | required |
| block_size | Optional[int] | Maximum sequence length for the target. | None |
| input_block_size | Optional[int] | Maximum sequence length for the input. | None |
| add_bos | Optional[str] | Where to add a BOS token ("input" for prefix, "output" for after prefix, None for no BOS). | None |
| add_eos | bool | Whether to add an EOS token at the end of the suffix. | False |
| truncate | Literal['max', 'block', None] | Truncation strategy. | 'block' |
|
__call__(examples)
Collate examples into a batch for ARLM sequence-to-sequence training.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| examples | List[Seq2SeqCollatorInput] | List of examples with prompt_ids and input_ids. | required |

Returns:

| Type | Description |
|---|---|
| ARLMSeq2SeqBatch | ARLMSeq2SeqBatch with input_ids, attention_mask, and target_ids. |
ARLMSeq2SeqPredCollator
Bases: ARLMSeq2SeqCollator
Drops all suffix/target tokens from the input and instead returns them as target_ids of shape (batch_size, target_seq_len).
__call__(examples)
Collate examples into a batch for ARLM sequence-to-sequence prediction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| examples | List[Seq2SeqCollatorInput] | List of examples with prompt_ids and input_ids. | required |

Returns:

| Type | Description |
|---|---|
| ARLMSeq2SeqBatch | ARLMSeq2SeqBatch with input_ids, attention_mask, and target_ids. |
ARLMBatch
Bases: TypedDict
Input to the ARLM.
Attributes:

| Name | Type | Description |
|---|---|---|
| input_ids | Integer[Tensor, ' batch seq_len'] | The input ids to the model. |
| attention_mask | Integer[Tensor, ' batch seq_len'] | 1 for tokens that are not padding. |
| target_ids | Integer[Tensor, ' batch seq_len'] | The target ids for language modeling (shifted by 1). Positions with -100 are ignored during loss computation (prompt tokens or padding). |
ARLMSeq2SeqBatch
Bases: TypedDict
Input to the ARLM for sequence-to-sequence training.
Attributes:

| Name | Type | Description |
|---|---|---|
| input_ids | Integer[Tensor, ' batch seq_len'] | The input ids to the model (prompt + target). |
| attention_mask | Integer[Tensor, ' batch seq_len'] | 1 for tokens that are not padding. |
| token_type_ids | Integer[Tensor, ' batch seq_len'] | Token type ids (not used in ARLM but kept for interface consistency). |
| target_ids | Integer[Tensor, ' batch seq_len'] | The target ids for language modeling (shifted by 1). Positions with -100 are ignored during loss computation (prompt tokens or padding). |
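How prompt masking interacts with the one-position shift can be seen in a small worked example (illustrative values, plain PyTorch):

```python
import torch

prompt_ids = [11, 12]        # prompt portion: no loss on these predictions
suffix_ids = [21, 22, 23]    # target portion: loss is computed here
input_ids = torch.tensor([prompt_ids + suffix_ids])

# Next-token targets, shifted by one; the last position has no successor.
target_ids = torch.full_like(input_ids, -100)
target_ids[:, :-1] = input_ids[:, 1:]
# Mask predictions made *within* the prompt: the last prompt token is the
# first position whose target (the first suffix token) is kept.
target_ids[:, : len(prompt_ids) - 1] = -100
```

Note that the last prompt position keeps its target, since that is where the model predicts the first suffix token.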
ARLMLossDict
Bases: TypedDict
Output of the LossFunction Callable.
Attributes:

| Name | Type | Description |
|---|---|---|
| loss | Float[Tensor, ''] | The total loss value. |
| batch_loss | Float[Tensor, ' batch'] | Loss value for each example in the batch. |
| nlls | Float[Tensor, ' num_tokens'] | The negative log likelihoods of the real predicted tokens (non-padding). |
ARLMPredictionDict
Bases: TypedDict
Output of the Predictor for ARLM.
Attributes:

| Name | Type | Description |
|---|---|---|
| text | List[str] | The batch of generated text without special tokens. |
| text_with_spl_tokens | List[str] | The batch of generated text with special tokens. |
| ids | Integer[Tensor, ' batch seq_len'] | The batch of generated token_ids. |
| attention_mask | Bool[Tensor, ' batch seq_len'] | Attention mask accompanying the generated ids. |
| positions | Integer[Tensor, ' batch seq_len'] | The batch of positions of the generated tokens accompanying the ids. |
| time_taken | List[float] | Time taken for each prediction. |
| output_start_idx | int | The index of the first output token. |