OWT FlexMDM

See also: FlexMDM package · Task docs

Dataset

Experiment config: experiment=owt_flexmdm (datamodule: owt_flexmdm).

Training and validation use OpenWebText, pre-tokenized with GPT-2 and filtered to sequences of at most 1,024 tokens. The processed split is hosted on Hugging Face as dhruveshpatel/owt-gpt2-1024-split: a 10k-example validation holdout (seed 2357) and the remainder for training.

| Setting | Value |
| --- | --- |
| Tokenizer | GPT-2 (gpt2) |
| Block size | 1,024 |
| Batching | Per-device batch size 32; global batch size 512 |
| Train split | dhruveshpatel/owt-gpt2-1024-split/train |
| Val split | dhruveshpatel/owt-gpt2-1024-split/validation |
| Train collator | FlexMDMTrainCollator (linear insertion/unmasking noise; variable-length segments truncated to block, EOS appended) |
| Unconditional eval | FlexMDMEmptyDataset (unconditional_prediction dataloader; empty prompts, max length 1,024) |

Training runs for up to 1M steps, with validation every 50k steps. Checkpoints are saved every 2,500 steps, and one checkpoint every 100k steps is kept permanently.
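As a quick sanity check on this schedule, the counts it implies can be worked out with shell arithmetic. This is only a sketch: the variable names are hypothetical and are not xlm config keys, and the counts are upper bounds because training runs for "up to" 1M steps. The last check assumes that "kept permanently" means every step that is a multiple of 100k, which would include the step-800000 checkpoint used in the Evaluation section.

```shell
# Hypothetical variable names; values are taken from the schedule described above.
MAX_TRAIN_STEPS=1000000
VAL_EVERY=50000
CKPT_EVERY=2500
KEEP_EVERY=100000
EVAL_STEP=800000   # checkpoint step used for the evaluation in this document

NUM_VAL_RUNS=$((MAX_TRAIN_STEPS / VAL_EVERY))
NUM_CKPT_SAVES=$((MAX_TRAIN_STEPS / CKPT_EVERY))
NUM_PERMANENT=$((MAX_TRAIN_STEPS / KEEP_EVERY))

echo "validation runs (at most):       ${NUM_VAL_RUNS}"
echo "checkpoint saves (at most):      ${NUM_CKPT_SAVES}"
echo "permanent checkpoints (at most): ${NUM_PERMANENT}"

# Assumption: a checkpoint is permanent when its step is a multiple of KEEP_EVERY.
if [ $((EVAL_STEP % KEEP_EVERY)) -eq 0 ]; then
  IS_PERMANENT=yes
else
  IS_PERMANENT=no
fi
echo "step ${EVAL_STEP} kept permanently: ${IS_PERMANENT}"
```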

Training

W&B run: owt_flexmdm (owt_flexmdm)

xlm job_name=owt_flexmdm job_type=train experiment=owt_flexmdm \
  per_device_batch_size=32 trainer_strategy=ddp trainer.devices=8 trainer.num_nodes=1 \
  ++trainer.precision=bf16-mixed compile=False \
  +loggers.wandb.resume=allow +loggers.wandb.id=owt_flexmdm
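The command above runs 8 devices on 1 node with a per-device batch size of 32, while the Dataset table lists a global batch size of 512. The sketch below works out what that implies, assuming the usual relation global batch = per-device batch × devices × nodes × gradient-accumulation steps. Whether xlm derives accumulation this way is an assumption, and the variable names are illustrative only.

```shell
# Values from the training command and the Dataset table; names are hypothetical.
PER_DEVICE_BATCH_SIZE=32
DEVICES=8
NUM_NODES=1
GLOBAL_BATCH_SIZE=512
BLOCK_SIZE=1024

# Sequences processed in one forward/backward pass across all devices.
PER_PASS=$((PER_DEVICE_BATCH_SIZE * DEVICES * NUM_NODES))

# Assumption: the remaining factor is made up by gradient accumulation.
ACCUM_STEPS=$((GLOBAL_BATCH_SIZE / PER_PASS))

# Upper bound on tokens per optimizer step, since sequences hold at most BLOCK_SIZE tokens.
MAX_TOKENS_PER_STEP=$((GLOBAL_BATCH_SIZE * BLOCK_SIZE))

echo "sequences per pass:              ${PER_PASS}"
echo "implied accumulation steps:      ${ACCUM_STEPS}"
echo "tokens per optimizer step (max): ${MAX_TOKENS_PER_STEP}"
```

Under these assumptions, each optimizer step covers two forward/backward passes of 256 sequences.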

Evaluation

Reference W&B eval run: owt_flexmdm_eval_step-800000_null_0.95_1024 (n55k2mel). Checkpoint is loaded from Hugging Face Hub (dhruveshpatel/flexmdm-owt, revision step-800000). Gen. PPL uses experiment=[owt_flexmdm,gpt2_generative_perplexity] (no MAUVE post-hoc eval).

Single eval: set HUB_REVISION, TOP_P, and MAX_STEPS (the sampling budget \(T\)). Setting confidence=null matches the logged sweep:

HUB_REVISION=step-800000
TOP_P=0.95
MAX_STEPS=1024
CHECKPOINT_TAG="${HUB_REVISION#step-}"

xlm job_name=owt_flexmdm_eval_${HUB_REVISION}_null_${TOP_P}_${MAX_STEPS} \
  job_type=eval 'experiment=[owt_flexmdm,gpt2_generative_perplexity]' \
  ++eval.checkpoint_path=None ++eval.split=validation \
  per_device_batch_size=16 per_device_val_batch_size=16 global_batch_size=16 \
  trainer_strategy=single_device ++trainer.precision=32-true compile=false \
  loggers=wandb +loggers.wandb.resume=allow +loggers.wandb.id=null \
  '~datamodule.dataset_managers.val.lm' \
  +hub.repo_id=dhruveshpatel/flexmdm-owt +hub.revision=${HUB_REVISION} \
  ++predictor.confidence=null ++predictor.top_k=null ++predictor.top_p=${TOP_P} \
  ++predictor.max_steps=${MAX_STEPS} \
  +tags.checkpoint=${CHECKPOINT_TAG} \
  paths.log_dir=logs/eval
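The variable handling in the command above can be checked without launching a job. The sketch below only echoes the derived values: `${HUB_REVISION#step-}` is standard POSIX parameter expansion that strips the leading `step-` prefix, and the assembled job name should match the reference W&B eval run named earlier in this section. `JOB_NAME` is a helper variable introduced here for illustration.

```shell
HUB_REVISION=step-800000
TOP_P=0.95
MAX_STEPS=1024

# "#step-" removes the shortest matching prefix, leaving only the step number.
CHECKPOINT_TAG="${HUB_REVISION#step-}"

# JOB_NAME is a hypothetical helper; the command above builds this string inline.
JOB_NAME="owt_flexmdm_eval_${HUB_REVISION}_null_${TOP_P}_${MAX_STEPS}"

echo "checkpoint tag: ${CHECKPOINT_TAG}"
echo "job name:       ${JOB_NAME}"
```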

Reproduce the Results table below (\(p=0.95\), \(T \in \{128,256,512,1024\}\), checkpoint step-800000):

HUB_REVISION=step-800000
TOP_P=0.95
CHECKPOINT_TAG="${HUB_REVISION#step-}"

for MAX_STEPS in 128 256 512 1024; do
  xlm job_name=owt_flexmdm_eval_${HUB_REVISION}_null_${TOP_P}_${MAX_STEPS} \
    job_type=eval 'experiment=[owt_flexmdm,gpt2_generative_perplexity]' \
    ++eval.checkpoint_path=None ++eval.split=validation \
    per_device_batch_size=16 per_device_val_batch_size=16 global_batch_size=16 \
    trainer_strategy=single_device ++trainer.precision=32-true compile=false \
    loggers=wandb +loggers.wandb.resume=allow +loggers.wandb.id=null \
    '~datamodule.dataset_managers.val.lm' \
    +hub.repo_id=dhruveshpatel/flexmdm-owt +hub.revision=${HUB_REVISION} \
    ++predictor.confidence=null ++predictor.top_k=null ++predictor.top_p=${TOP_P} \
    ++predictor.max_steps=${MAX_STEPS} \
    +tags.checkpoint=${CHECKPOINT_TAG} \
    paths.log_dir=logs/eval
done

Results

Unconditional generation metrics for FlexMDM (variable-length masked diffusion) checkpoints. Gen. PPL is computed under GPT-2 Large; entropy is the vocabulary entropy of the generated token distribution. Samples are drawn with nucleus sampling (\(p=0.95\)) for predictor budgets \(T \in \{128, 256, 512, 1024\}\), using 1,000 samples of up to 1,024 tokens each. MAUVE was not run for this baseline.
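For readers who want the two metrics spelled out, one common way to define them is sketched below. This is an assumption about the convention, not a statement of what the evaluation code computes. Here \(p_{\text{ref}}\) stands for GPT-2 Large, \(x^{(n)}\) is the \(n\)-th generated sample with \(L_n\) tokens, \(N = 1000\), \(\mathcal{V}\) is the GPT-2 vocabulary, and entropy is taken to be the per-sample unigram entropy in nats, averaged over samples.

```latex
% Assumed definitions; the evaluation code may aggregate differently.
\[
\text{Gen. PPL} = \exp\!\left( -\frac{1}{\sum_{n=1}^{N} L_n}
  \sum_{n=1}^{N} \sum_{i=1}^{L_n}
  \log p_{\text{ref}}\!\left( x^{(n)}_i \mid x^{(n)}_{<i} \right) \right)
\]
\[
\text{Entropy} = \frac{1}{N} \sum_{n=1}^{N}
  \left( -\sum_{v \in \mathcal{V}} \hat{p}^{(n)}_v \log \hat{p}^{(n)}_v \right),
\qquad
\hat{p}^{(n)}_v = \frac{\#\{\, i : x^{(n)}_i = v \,\}}{L_n}
\]
```

Under this assumed definition, the entropy of a sample cannot exceed \(\log L_n \le \log 1024 \approx 6.93\) nats, which is reached only when every token in a full-length sample is distinct.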

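| Checkpoint | Metric | \(T=128\) | \(T=256\) | \(T=512\) | \(T=1024\) |
| --- | --- | --- | --- | --- | --- |
| 800k | Gen. PPL (↓) | 64.68 | 62.08 | 59.61 | 59.27 |
| 800k | Entropy (↑) | 4.93 | 4.88 | 4.88 | 4.92 |
| 800k | MAUVE (↑) | not run | not run | not run | not run |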

Source: W&B project ilm-extensions/flexmdm (tags.checkpoint=800000, top_p=0.95).
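To put the Gen. PPL row in perspective, the relative change between the smallest and largest sampling budget can be computed directly from the two reported values. The sketch below uses awk for the floating-point arithmetic; `REL_DROP` is a name introduced here for illustration, and only numbers from the Results table are used.

```shell
# Gen. PPL values copied from the Results table (checkpoint 800k, p = 0.95).
PPL_T128=64.68
PPL_T1024=59.27

# Relative reduction in Gen. PPL when going from T=128 to T=1024, in percent.
REL_DROP=$(awk -v a="$PPL_T128" -v b="$PPL_T1024" \
  'BEGIN { printf "%.2f", (a - b) / a * 100 }')

echo "Gen. PPL reduction from T=128 to T=1024: ${REL_DROP}%"
```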