FlexMDM — Flexible Masked Diffusion Model

Overview

flexmdm implements variable-length masked diffusion for text (linear insertion/unmasking noise, seq2seq and unconditional setups). Hydra configs live under xlm-models/flexmdm/configs/.

For OWT-scale training and evaluation, see the OWT FlexMDM experiment.

TinyGSM

For the task dataset and preprocessing, see TinyGSM. For GSM8K and code-execution evaluation, see tinygsm_gsm8k.md.

Experiment        experiment=tinygsm_flexmdm
Datamodule        tinygsm_flexmdm
Experiment YAML   tinygsm_flexmdm

Register a distinct pad token (pad_token: "<|pad|>" in global_components.tokenizer.special_tokens) so that pad_token_id != eos_token_id; otherwise both training and prediction raise an error at initialization.
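The requirement above can be pictured as a simple guard. The following is a minimal sketch, not the library's actual check: the function name check_pad_token and the token ids are hypothetical and only illustrate the condition pad_token_id != eos_token_id.

```python
def check_pad_token(pad_token_id, eos_token_id):
    """Hypothetical guard mirroring the documented pad/EOS requirement."""
    if pad_token_id is None:
        raise ValueError("pad_token_id is not set; register a pad token such as <|pad|>")
    if pad_token_id == eos_token_id:
        raise ValueError(
            "pad_token_id must differ from eos_token_id; "
            "register a distinct pad token such as <|pad|>"
        )


# Illustrative ids only (not the real tokenizer ids).
check_pad_token(pad_token_id=3, eos_token_id=2)  # passes silently

try:
    check_pad_token(pad_token_id=2, eos_token_id=2)
except ValueError as err:
    print(f"rejected: {err}")
```

Running a check like this before launching a job surfaces the misconfiguration without waiting for model initialization.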

Training settings

Setting                 Value
Tokenizer               Qwen2-0.5B (Qwen/Qwen2-0.5B) with added <|mask|>
block_size              512
input_block_size        0
Batching                Per-device 32; global 512
Collators               STAR seq2seq (seq2seq_* / seq2seq_pred_*); no BOS between question and code
Val / test prediction   Post-hoc code_exec_accuracy (Gsm8kCodeEval); token EM disabled
Monitored metric        val/lm/accumulated_loss
Training schedule       Up to 1M steps; validation every 50k steps; checkpoint every 2.5k steps (keep every 100k)

Collators are reused from existing STAR seq2seq configs; no TinyGSM-specific collator YAMLs.
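To get a feel for the training schedule listed in the settings above, the arithmetic can be sketched as follows. This assumes the run goes to the full 1M steps and that "keep every 100k" means checkpoints at multiples of 100k steps are retained; the variable names are illustrative and are not config keys.

```python
# Values taken from the training settings; names are illustrative only.
max_steps = 1_000_000
val_every = 50_000
ckpt_every = 2_500
keep_every = 100_000

# Kept checkpoints can only fall on steps where a checkpoint is written.
assert keep_every % ckpt_every == 0

num_validations = max_steps // val_every   # validation passes over a full run
num_checkpoints = max_steps // ckpt_every  # checkpoints written over a full run
num_kept = max_steps // keep_every         # checkpoints retained under the assumption above

print(num_validations, num_checkpoints, num_kept)  # 20 400 10
```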

Commands

Prepare the cache (on rank 0, before multi-GPU training):

xlm job_type=prepare_data experiment=tinygsm_flexmdm num_dataset_workers=8

Managers that share full_name: TinyGSM/TinyGSM/train write to distinct manual cache directories, selected by filter_suffix (e.g. val_holdout, pred_preprocess). After changing the prediction preprocessing, rebuild the cache with datamodule.rewrite_manual_cache=true, or delete the stale .../TinyGSM/TinyGSM_pred_preprocess/train tree.
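The relationship between full_name, filter_suffix, and the cache directory can be sketched as below. This is an assumption inferred from the single path shown above (TinyGSM/TinyGSM_pred_preprocess/train); the helper manual_cache_subdir is hypothetical and not part of xlm, and the cache root is left out because the source elides it.

```python
from pathlib import PurePosixPath


def manual_cache_subdir(full_name, filter_suffix=None):
    """Hypothetical mapping from full_name and filter_suffix to a cache subdirectory."""
    owner, name, split = full_name.split("/")
    if filter_suffix:
        # Assumed convention: the suffix is appended to the dataset name component.
        name = f"{name}_{filter_suffix}"
    return str(PurePosixPath(owner, name, split))


full_name = "TinyGSM/TinyGSM/train"
print(manual_cache_subdir(full_name))                     # TinyGSM/TinyGSM/train
print(manual_cache_subdir(full_name, "pred_preprocess"))  # TinyGSM/TinyGSM_pred_preprocess/train
# Extrapolated by the same rule, not confirmed by the source:
print(manual_cache_subdir(full_name, "val_holdout"))      # TinyGSM/TinyGSM_val_holdout/train
```

Under this reading, two managers with the same full_name never collide as long as their filter_suffix values differ.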

On SLURM, see submit_prepare_data.py.

Train (DDP):

xlm job_name=tinygsm_flexmdm job_type=train experiment=tinygsm_flexmdm \
  per_device_batch_size=32 trainer_strategy=ddp trainer.devices=8 trainer.num_nodes=1 \
  ++trainer.precision=bf16-mixed compile=False
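The command above runs 32 samples per device on 8 devices and 1 node, while the training settings list a global batch of 512. The source does not say how the gap is bridged; a minimal sketch, assuming gradient accumulation makes up the difference, is shown below. The function name is hypothetical.

```python
def implied_grad_accumulation(global_batch, per_device_batch, devices, num_nodes=1):
    """Accumulation steps needed to reach global_batch (assumes accumulation is used)."""
    per_step = per_device_batch * devices * num_nodes  # samples per optimizer micro-step
    if global_batch % per_step != 0:
        raise ValueError(
            f"global batch {global_batch} is not a multiple of {per_step} samples per step"
        )
    return global_batch // per_step


accum = implied_grad_accumulation(global_batch=512, per_device_batch=32, devices=8, num_nodes=1)
print(accum)  # 2
```

The same helper shows how the factor would change if the device count does: for example, 4 devices at 32 per device would need 4 accumulation steps to keep the global batch at 512.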

Debug overfit (one TinyGSM train example)

Configs: debug/overfit_tinygsm_flexmdm, datasets/tinygsm_debug_one, datasets/tinygsm_debug_one_pred.

Train and val/lm share a single row (filter_suffix: debug_one); val/prediction applies the production tinygsm_pred_preprocess_fn to that same row (debug_one_pred). Prefer this setup over the generic debug=overfit for TinyGSM FlexMDM.

Prepare debug caches (once; num_dataset_workers=1 required for the first-row filter):

xlm job_type=prepare_data experiment=tinygsm_flexmdm debug=overfit_tinygsm_flexmdm \
  datamodule.rewrite_manual_cache=true num_dataset_workers=1

Debug train:

xlm job_type=train experiment=tinygsm_flexmdm debug=overfit_tinygsm_flexmdm

Experiment results

Full W&B write-ups under docs/experiments/ are deferred until runs exist. Use experiment=tinygsm_flexmdm with the document experiment workflow when ready.