
ARLM - Auto-Regressive Language Model for XLM Framework

This package implements the ARLM model with all necessary components:

- Model architecture (`model_arlm.py`)
- Loss function (`loss_arlm.py`)
- Predictor for inference (`predictor_arlm.py`)
- Data module (`datamodule_arlm.py`)
- Metrics computation (`metrics_arlm.py`)
- Type definitions (`types_arlm.py`)

This model was migrated from xlm.lm.arlm to be an external model.

RotaryTransformerARLMModel

Bases: Module, Model

Rotary embedding based transformer decoder for auto-regressive language modeling.

__init__(num_embeddings, d_model, num_layers, nhead, padding_idx=0, dim_feedforward=None, dropout=0.1, activation='relu', layer_norm_eps=1e-05, rotary_emb_dim=64, max_length=1024, force_flash_attn=False, final_layer_without_normalization=False)

Initialize the ARLM transformer model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_embeddings` | `int` | Size of the vocabulary. | *required* |
| `d_model` | `int` | Dimension of the model. | *required* |
| `num_layers` | `int` | Number of transformer layers. | *required* |
| `nhead` | `int` | Number of attention heads. | *required* |
| `padding_idx` | `int` | Index of the padding token. | `0` |
| `dim_feedforward` | `Optional[int]` | Dimension of the feedforward network. | `None` |
| `dropout` | `float` | Dropout rate. | `0.1` |
| `activation` | `str` | Activation function. | `'relu'` |
| `layer_norm_eps` | `float` | Epsilon for layer normalization. | `1e-05` |
| `rotary_emb_dim` | `int` | Dimension of rotary embeddings. | `64` |
| `max_length` | `int` | Maximum sequence length. | `1024` |
| `force_flash_attn` | `bool` | Whether to force flash attention. | `False` |
| `final_layer_without_normalization` | `bool` | Whether to use the final layer without normalization. | `False` |

forward(x_t, attention_mask=None, positions=None, token_type_ids=None)

Forward pass of the ARLM model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x_t` | `Integer[Tensor, '*batch seq_len']` | The input tokens of shape `(*batch, seq_len)`. | *required* |
| `attention_mask` | `Optional[Bool[Tensor, '*batch seq_len seq_len']]` | The attention mask: shape `(batch, seq_len, seq_len)` for a full attention matrix, or `(batch, seq_len)` for a simple padding mask. `True` for non-padding tokens. | `None` |
| `positions` | `Optional[Integer[Tensor, '*batch seq_len']]` | The positions of the tokens, of shape `(*batch, seq_len)`. | `None` |
| `token_type_ids` | `Optional[Integer[Tensor, '*batch seq_len']]` | The token type ids, of shape `(*batch, seq_len)`. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `vocab_logits` | `Float[Tensor, '*batch seq_len vocab_size']` | The vocabulary logits of shape `(*batch, seq_len, vocab_size)`. |
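The documented full-matrix form of `attention_mask` can be derived from a plain padding mask. A minimal sketch in plain PyTorch (not code from this package) that combines a causal lower-triangular mask with a `(batch, seq_len)` padding mask, where `True` marks attendable positions:

```python
import torch

def full_causal_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """Expand a (batch, seq_len) padding mask (True = real token) into the
    documented (batch, seq_len, seq_len) form: position i may attend to
    position j only when j <= i and token j is not padding."""
    seq_len = padding_mask.shape[-1]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return causal.unsqueeze(0) & padding_mask.unsqueeze(1)

pad = torch.tensor([[True, True, True, False]])  # last token is padding
mask = full_causal_mask(pad)  # shape (1, 4, 4)
```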

get_named_params_for_weight_decay()

Get parameters for weight decay (all parameters except biases and layer-norm parameters).

get_named_params_for_no_weight_decay()

Get parameters for no weight decay (biases and layer-norm parameters).
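The split these two helpers describe, where weight matrices get decay but biases and layer-norm parameters do not, is commonly wired into the optimizer as parameter groups. A sketch with a stand-in `nn.Sequential` model (not the actual ARLM model):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

decay, no_decay = [], []
for name, param in model.named_parameters():
    # Biases and LayerNorm parameters are 1-D; matrices get weight decay.
    if param.ndim < 2 or name.endswith(".bias"):
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.01},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=1e-3,
)
```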

ARLMLoss

Bases: LossFunction[ARLMBatch, ARLMLossDict]

Loss function for Auto-Regressive Language Modeling (ARLM).

This loss function implements causal language modeling where the model predicts the next token given the previous tokens. The loss is computed using cross-entropy on the target sequence (which is already shifted in the batch).

For seq2seq tasks, loss is only computed on suffix tokens (non-prompt tokens).

__init__(model=None, tokenizer=None)

Initialize the ARLM loss function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Optional[ARLMModel]` | The ARLM model to use for predictions. | `None` |
| `tokenizer` | `Optional[Tokenizer]` | The tokenizer for processing tokens. | `None` |

configure(pl_module)

Configure the loss function with the lightning module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pl_module` | `Harness` | The lightning module instance. | *required* |

__call__(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Compute the loss for the given batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `ARLMBatch` | The input batch containing `input_ids`, `attention_mask`, and `target_ids`. | *required* |
| `batch_idx` | `Optional[int]` | The batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | The dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | The dataloader name. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ARLMLossDict` | Dictionary containing the `loss`, `batch_loss`, and `nlls`. |

loss_fn(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Compute the causal language modeling loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `ARLMBatch` | The input batch. | *required* |
| `batch_idx` | `Optional[int]` | The batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | The dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | The dataloader name. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ARLMLossDict` | Dictionary containing the `loss`, `batch_loss`, and `nlls`. |
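The `loss`, `batch_loss`, and `nlls` fields can be illustrated in plain PyTorch. A sketch of a causal LM loss over already-shifted targets that ignores positions labeled `-100` (the function name and exact reductions are illustrative, not this package's implementation):

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(vocab_logits, target_ids):
    """Per-token cross-entropy over already-shifted targets; positions
    labeled -100 (prompt tokens or padding) are excluded from the loss."""
    batch, seq_len, vocab = vocab_logits.shape
    per_token = F.cross_entropy(
        vocab_logits.reshape(-1, vocab),
        target_ids.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).reshape(batch, seq_len)  # 0.0 at ignored positions
    valid = target_ids.ne(-100)
    batch_loss = per_token.sum(dim=-1) / valid.sum(dim=-1).clamp(min=1)
    nlls = per_token[valid]  # NLL of each real target token
    loss = nlls.mean()       # scalar loss over all real tokens
    return {"loss": loss, "batch_loss": batch_loss, "nlls": nlls}
```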

ARLMPredictor

Bases: Module, Predictor[Dict[str, Any], ARLMPredictionDict]

__init__(max_steps, max_length, tokenizer=None, noise_schedule=None, sampling_method='sample', top=1000, p=0.9, model=None)

Constructor for ARLMPredictor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `max_steps` | `int` | Maximum number of prediction steps. | *required* |
| `max_length` | `int` | Maximum sequence length. | *required* |
| `tokenizer` | `Optional[Tokenizer]` | The tokenizer to use. | `None` |
| `noise_schedule` | `Optional[NoiseSchedule]` | Noise schedule (not used in ARLM but kept for interface consistency). | `None` |
| `sampling_method` | `Literal['sample', 'sample_top_k', 'sample_top_p']` | Sampling method to use. | `'sample'` |
| `top` | `int` | Top-k parameter for sampling. | `1000` |
| `p` | `float` | Top-p parameter for sampling. | `0.9` |
| `model` | `Optional[ARLMModel]` | The ARLM model to use for predictions. | `None` |
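The three `sampling_method` options correspond to standard decoding strategies. A sketch of plain multinomial, top-k, and nucleus (top-p) sampling over per-step logits (illustrative only; the predictor's actual implementation may differ):

```python
import torch

def sample_next(logits, method="sample", top=1000, p=0.9):
    """Sample the next token id from (batch, vocab) logits using plain
    sampling, top-k, or nucleus (top-p) filtering."""
    if method == "sample_top_k":
        k = min(top, logits.size(-1))
        kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    elif method == "sample_top_p":
        sorted_logits, idx = torch.sort(logits, descending=True, dim=-1)
        probs = torch.softmax(sorted_logits, dim=-1)
        cum = probs.cumsum(dim=-1)
        # Drop tokens whose preceding cumulative mass exceeds p,
        # always keeping the most likely token.
        remove = cum - probs > p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)
```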

predict_single_step(step_results, current_step)

Predict the next token in the sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step_results` | `ARLMStepResults` | Current step results containing `x`, `attention_mask`, and `logits`. | *required* |
| `current_step` | `int` | Current prediction step. | *required* |
| `final_step` | | Whether this is the final step. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ARLMStepResults` | Updated step results with the next token predicted. |

stop(step_results, current_length)

Check if prediction should stop.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `step_results` | `ARLMStepResults` | Current step results. | *required* |
| `current_length` | `int` | Current sequence length. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `bool` | `True` if prediction should stop, `False` otherwise. |
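The stopping rule this method implements, halting on an end-of-sequence token or at the maximum length, can be sketched as (illustrative; `should_stop` is a hypothetical name, not this class's method):

```python
def should_stop(last_tokens, current_length, eos_id, max_length):
    """Stop once every sequence in the batch has emitted EOS,
    or the maximum length has been reached."""
    return current_length >= max_length or all(t == eos_id for t in last_tokens)
```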

decode(results)

Decode the predicted sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `results` | `ARLMStepResults` | Step results containing the predicted sequence. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[List[str], List[str], Integer[Tensor, 'batch seq_len']]` | Tuple of `(decoded_text, decoded_text_with_special_tokens, token_ids)`. |

predict(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None, max_len=0)

Predict the complete sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Dict[str, Any]` | Input batch containing `input_ids` and `attention_mask`. | *required* |
| `batch_idx` | `Optional[int]` | Batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | Dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | Dataloader name. | `None` |
| `max_len` | `int` | Maximum length for prediction. | `0` |

Returns:

| Type | Description |
| --- | --- |
| `ARLMPredictionDict` | Prediction results containing text, token IDs, and attention mask. |

to_dict(batch, preds, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Convert predictions to dictionary format.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Dict[str, Any]` | Input batch. | *required* |
| `preds` | `ARLMPredictionDict` | Prediction results. | *required* |
| `batch_idx` | `Optional[int]` | Batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | Dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | Dataloader name. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `List[Dict[str, Any]]` | List of dictionaries containing prediction results. |

DefaultARLMCollator

Bases: Collator

Used for pre-training.

__init__(tokenizer, block_size, noise_schedule, truncate='block', add_eos=False)

Initialize the ARLM collator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tokenizer` | `Tokenizer` | The tokenizer to use. | *required* |
| `block_size` | `int` | Maximum sequence length. | *required* |
| `noise_schedule` | `NoiseSchedule` | Noise schedule (not used in ARLM but kept for interface consistency). | *required* |
| `truncate` | `Literal['max', 'block', None]` | Truncation strategy. | `'block'` |
| `add_eos` | `bool` | Whether to add an EOS token at the end of the sequence. | `False` |

__call__(examples)

Collate examples into a batch for ARLM training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `examples` | `List[BaseCollatorInput]` | List of examples with `input_ids`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ARLMBatch` | `ARLMBatch` with `input_ids`, `attention_mask`, and `target_ids`. |
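The batch such a pre-training collator returns can be illustrated with a minimal padding-and-shifting sketch (`collate` is a hypothetical stand-in, and `PAD_ID = 0` is an assumption matching the model's default `padding_idx`):

```python
import torch

PAD_ID = 0  # assumed padding index

def collate(examples, block_size):
    """Pad each example's input_ids to block_size, build the attention
    mask, and produce next-token target_ids (-100 where the target is
    padding), mirroring the documented ARLMBatch fields."""
    batch = torch.full((len(examples), block_size), PAD_ID, dtype=torch.long)
    for i, ex in enumerate(examples):
        ids = torch.as_tensor(ex["input_ids"][:block_size])
        batch[i, : ids.numel()] = ids
    attention_mask = batch.ne(PAD_ID).long()
    # Next-token targets: position t predicts token t + 1.
    target_ids = torch.full_like(batch, -100)
    target_ids[:, :-1] = batch[:, 1:]
    target_ids[target_ids == PAD_ID] = -100
    return {"input_ids": batch, "attention_mask": attention_mask, "target_ids": target_ids}
```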

ARLMSeq2SeqCollator

__init__(tokenizer, noise_schedule, block_size=None, input_block_size=None, add_bos=None, add_eos=False, truncate='block')

Initialize the ARLM sequence-to-sequence collator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tokenizer` | `Tokenizer` | The tokenizer to use. | *required* |
| `noise_schedule` | `NoiseSchedule` | Noise schedule (not used in ARLM but kept for interface consistency). | *required* |
| `block_size` | `Optional[int]` | Maximum sequence length for the target. | `None` |
| `input_block_size` | `Optional[int]` | Maximum sequence length for the input. | `None` |
| `add_bos` | `Optional[str]` | Where to add the BOS token (`"input"` for before the prefix, `"output"` for after the prefix, `None` for no BOS). | `None` |
| `add_eos` | `bool` | Whether to add an EOS token at the end of the suffix. | `False` |
| `truncate` | `Literal['max', 'block', None]` | Truncation strategy. | `'block'` |

__call__(examples)

Collate examples into a batch for ARLM sequence-to-sequence training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `examples` | `List[Seq2SeqCollatorInput]` | List of examples with `prompt_ids` and `input_ids`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ARLMSeq2SeqBatch` | `ARLMSeq2SeqBatch` with `input_ids`, `attention_mask`, and `target_ids`. |
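The prompt masking used for seq2seq training, where loss is computed only on suffix tokens, can be sketched as (hypothetical helper, assuming padding index 0):

```python
import torch

def seq2seq_targets(input_ids, prompt_len, pad_id=0):
    """Build shifted next-token targets where both prompt positions and
    padding are set to -100, so the loss covers only suffix tokens."""
    target_ids = torch.full_like(input_ids, -100)
    target_ids[:, :-1] = input_ids[:, 1:]
    target_ids[target_ids == pad_id] = -100
    for i, n in enumerate(prompt_len):
        # Position n - 1 (the last prompt token) predicts the first
        # suffix token, so only positions before it are masked.
        target_ids[i, : n - 1] = -100
    return target_ids
```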

ARLMSeq2SeqPredCollator

Bases: ARLMSeq2SeqCollator

Drops all suffix/target tokens from the input and returns them as `target_ids` of shape `(batch_size, target_seq_len)`.

__call__(examples)

Collate examples into a batch for ARLM sequence-to-sequence prediction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `examples` | `List[Seq2SeqCollatorInput]` | List of examples with `prompt_ids` and `input_ids`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ARLMSeq2SeqBatch` | `ARLMSeq2SeqBatch` with `input_ids`, `attention_mask`, and `target_ids`. |

ARLMBatch

Bases: TypedDict

Input to the ARLM.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `input_ids` | `Integer[Tensor, 'batch seq_len']` | The input ids to the model. |
| `attention_mask` | `Integer[Tensor, 'batch seq_len']` | 1 for tokens that are not padding. |
| `target_ids` | `Integer[Tensor, 'batch seq_len']` | The target ids for language modeling (shifted by 1). Positions with -100 are ignored during loss computation (prompt tokens or padding). |

ARLMSeq2SeqBatch

Bases: TypedDict

Input to the ARLM for sequence-to-sequence training.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `input_ids` | `Integer[Tensor, 'batch seq_len']` | The input ids to the model (prompt + target). |
| `attention_mask` | `Integer[Tensor, 'batch seq_len']` | 1 for tokens that are not padding. |
| `token_type_ids` | `Integer[Tensor, 'batch seq_len']` | Token type ids (not used in ARLM but kept for interface consistency). |
| `target_ids` | `Integer[Tensor, 'batch seq_len']` | The target ids for language modeling (shifted by 1). Positions with -100 are ignored during loss computation (prompt tokens or padding). |

ARLMLossDict

Bases: TypedDict

Output of the LossFunction Callable.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `loss` | `Float[Tensor, '']` | The total loss value. |
| `batch_loss` | `Float[Tensor, 'batch']` | Loss value for each example in the batch. |
| `nlls` | `Float[Tensor, 'num_tokens']` | The negative log likelihoods of the real predicted tokens (non-padding, and masked in the input). |

ARLMPredictionDict

Bases: TypedDict

Output of the Predictor for ARLM.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `text` | `List[str]` | The batch of generated text without special tokens. |
| `text_with_spl_tokens` | `List[str]` | The batch of generated text with special tokens. |
| `ids` | `Integer[Tensor, 'batch seq_len']` | The batch of generated token ids. |
| `attention_mask` | `Bool[Tensor, 'batch seq_len']` | Attention mask accompanying the generated ids. |
| `positions` | `Integer[Tensor, 'batch seq_len']` | The batch of positions of the generated tokens accompanying the ids. |
| `time_taken` | `List[float]` | Time taken for each prediction. |
| `output_start_idx` | `int` | The index of the first output token. |