ARLM — Auto-Regressive Language Model
1. Overview
arlm is a standard left-to-right causal Transformer language model with rotary position embeddings: at each step the model predicts the next token given the prefix. The package supports both language-modeling and seq2seq tasks (prompt + target with a -100-ignored prompt region). See xlm-models/arlm/README.md for the high-level description.
2. Files at a glance
| Module | Public classes / helpers |
|---|---|
| model_arlm.py | RotaryTransformerARLMModel |
| loss_arlm.py | ARLMLoss |
| predictor_arlm.py | ARLMPredictor |
| datamodule_arlm.py | DefaultARLMCollator, ARLMSeq2SeqCollator, ARLMSeq2SeqPredCollator, ARLMEmptyDataset, prepare_prefix_ids_arlm, helpers |
| metrics_arlm.py | seq2seq_exact_match_update_fn, seq2seq_token_accuracy_update_fn, mean_metric_update_fn, perplexity_metric_update_fn, token_nll_metric_update_fn, sequence_length_metric_update_fn, valid_tokens_metric_update_fn |
| types_arlm.py | ARLMBatch, ARLMSeq2SeqBatch, ARLMSeq2SeqPredictionBatch, ARLMLossDict, ARLMModel (Protocol), ARLMPredictionDict |
3. Architecture
RotaryTransformerARLMModel(num_embeddings, d_model, num_layers, nhead, ...) is structurally similar to the MLM backbone (RoPE + RotaryTransformerLayers + RotaryTransformerFinalLayer), but the loss path always supplies a 3-D causal attention mask:
forward(
x_t: Integer[TT, " *batch seq_len"],
attention_mask: Optional[Bool[TT, " *batch seq_len seq_len"]] = None,
positions: Optional[Integer[TT, " *batch seq_len"]] = None,
token_type_ids: Optional[Integer[TT, " *batch seq_len"]] = None,
) -> Float[TT, " *batch seq_len vocab_size"]
attention_maskaccepts a(B, L, L)or(B, L)mask. The loss buildscausal_attention_mask = tril(ones(L, L)) & attention_mask.unsqueeze(1)so future tokens are masked out.positionsare computed asattention_mask.cumsum(dim=1) - 1(zeroed at padding) in the loss path.token_type_idsis accepted but currently passed straight through.
4. Batch contract
ARLMBatch (types_arlm.py):
| Field | Shape | Notes |
|---|---|---|
input_ids |
(B, L) int |
Tokens (prompt + target for seq2seq). |
attention_mask |
(B, L) int |
1 for real tokens, 0 for padding. |
target_ids |
(B, L) int |
Already left-shifted by 1; positions to ignore (prompt or padding) carry -100. |
Seq2seq variants (ARLMSeq2SeqBatch, ARLMSeq2SeqPredictionBatch) carry the same fields plus a token_type_ids channel for downstream code; the loss treats -100 as the universal ignore signal.
5. Loss
ARLMLoss(model, tokenizer):
configure(pl_module)is a no-op (nomask_token_id_tensorcache needed).loss_fn:- Builds
positions = (attention_mask.cumsum(dim=1) - 1) * attention_mask. - Builds
causal_attention_mask = attention_mask.unsqueeze(1) & tril(ones(L, L))of shape(B, L, L). - Runs
logits = model(input_ids, causal_attention_mask, positions). - Returns
loss = cross_entropy(logits.transpose(1, 2), target_ids, reduction="mean", ignore_index=-100).
Because target_ids already comes shifted by 1 (and prompt positions are -100), the loss reduces to next-token CE over the target segment.
6. Collators
BaseCollatorInput for LM training, Seq2SeqCollatorInput for seq2seq. The helper prepare_prefix_ids_arlm left-pads prompts and supports BOS/EOS placement.
| Class | Input | Output batch | Special behavior |
|---|---|---|---|
DefaultARLMCollator |
BaseCollatorInput |
ARLMBatch |
Pad-right to block_size, builds target_ids = input_ids[1:] + [-100] (last position and padding become -100). |
ARLMSeq2SeqCollator |
Seq2SeqCollatorInput |
ARLMSeq2SeqBatch |
Concatenates [prompt][BOS][target][EOS]; prompt positions in target_ids set to -100. |
ARLMSeq2SeqPredCollator |
Seq2SeqCollatorInput |
ARLMSeq2SeqPredictionBatch |
Left-padded prompt only as input_ids; target kept separate for evaluation. |
7. Predictor
ARLMPredictor(max_steps, max_length, tokenizer, noise_schedule=None, sampling_method="sample", top=1000, p=0.9, model):
sampling_method:"sample"-> argmax/sample from full logits."sample_top_k"->sample_from_top_k(top, logits)."sample_top_p"->sample_from_top_p(p, logits).predict()runs incremental greedy/sampled generation from the (left-padded) prompt, appending one token per step (or stopping early when EOS is produced ormax_lengthreached).- Output
ARLMPredictionDict:{text, text_with_spl_tokens, ids, attention_mask, positions, time_taken, output_start_idx}.
The noise_schedule argument is accepted for interface symmetry with the other families and is unused.
8. Metrics
All *_update_fn(batch, loss_dict, tokenizer=None). Worked examples: tests/models/arlm/test_metrics_arlm.py.
| Function | Returned keys | Notes |
|---|---|---|
seq2seq_exact_match_update_fn |
pred, target, pred_length=None, target_length=None |
pred = loss_dict["ids"][:, output_start_idx:]. |
seq2seq_token_accuracy_update_fn |
pred, target, pred_mask = ones_like(pred) |
All suffix tokens counted. |
mean_metric_update_fn |
value = loss_dict["loss"] |
Generic scalar. |
perplexity_metric_update_fn |
value = exp(nlls.mean()) |
Uses loss_dict["nlls"]. |
token_nll_metric_update_fn |
value = loss_dict["nlls"] |
Per-token NLL pass-through. |
sequence_length_metric_update_fn |
value = attention_mask.sum(dim=1).float() |
Per-example sequence length. |
valid_tokens_metric_update_fn |
value = (attention_mask & (target_ids != -100)).sum(dim=1).float() |
Excludes padding and ignored targets. |
9. Configs / experiments
Hydra groups under xlm-models/arlm/configs/. Available experiment entry points:
experiment=star_easy_arlmexperiment=tinygsm_arlm(seq2seq; TinyGSM)
TinyGSM
Task dataset and preprocessing: TinyGSM. GSM8K and code-execution eval: tinygsm_gsm8k.md.
| Experiment | experiment=tinygsm_arlm |
| Datamodule | tinygsm_arlm |
| Experiment YAML | tinygsm_arlm |
Training settings
| Setting | Value |
|---|---|
| Tokenizer | Qwen2-0.5B (Qwen/Qwen2-0.5B) with added <|mask|> |
block_size |
512 |
input_block_size |
0 |
| Batching | Per-device 32; global 512 |
| Collators | STAR seq2seq (seq2seq_* / seq2seq_pred_*); no BOS between question and code |
| Val / test prediction | Post-hoc code_exec_accuracy (Gsm8kCodeEval); token EM disabled |
| Monitored metric | val/lm/accumulated_loss |
| Training schedule | Up to 1M steps; validation every 50k steps; checkpoint every 2.5k steps (keep every 100k) |
Collators are reused from existing STAR seq2seq configs; no TinyGSM-specific collator YAMLs.
Commands
Prepare cache (rank 0 before multi-GPU training):
xlm job_type=prepare_data experiment=tinygsm_arlm num_dataset_workers=8
Managers that share full_name: TinyGSM/TinyGSM/train use distinct manual cache
directories via filter_suffix (e.g. val_holdout, pred_preprocess). After
changing prediction preprocess, rebuild with
datamodule.rewrite_manual_cache=true or delete the stale
.../TinyGSM/TinyGSM_pred_preprocess/train tree if needed.
On SLURM, see submit_prepare_data.py.
Train (DDP):
xlm job_name=tinygsm_arlm job_type=train experiment=tinygsm_arlm \
per_device_batch_size=32 trainer_strategy=ddp trainer.devices=8 trainer.num_nodes=1 \
++trainer.precision=bf16-mixed compile=False
Debug: set DEBUG_OVERFIT=1 to use full_name_debug (same HF split; short smoke runs).
Experiment results
Full W&B write-ups under docs/experiments/ are deferred until runs exist. Use experiment=tinygsm_arlm with the document experiment workflow when ready.
10. Testing
Tests in tests/models/arlm/:
test_model_arlm.py— extendsBaseModelTests, plus a causal-mask leakage test (added in this plan).test_loss_arlm.py— extendsBaseLossTests, plus a-100ignore-index test (added in this plan).test_collator_arlm.py— extendsBaseCollatorTestsand adds an ARLM-specific "target has ignore index" assertion.test_predictor_arlm.py— predictor smoke + vocab-range tests.test_metrics_arlm.py— pure-function helpers.
Shared fixtures (tiny_arlm_model, arlm_batch, simple_tokenizer) live in tests/conftest.py and tests/models/arlm/conftest.py.