MDLM — Masked Diffusion Language Model
1. Overview
mdlm implements a continuous-time absorbing-state discrete diffusion language model. Training samples a per-example time $t \in [\epsilon, 1]$, maps it through a noise schedule $\sigma(t)$ to a corruption rate, masks tokens with that rate, and trains the model to denoise with a KL-weighted cross-entropy. The backbone is a DDiT-style Transformer with AdaLN time conditioning and rotary positional embeddings.
@misc{sahoo2024simpleeffectivemaskeddiffusion,
title = {Simple and Effective Masked Diffusion Language Models},
author = {Subham Sekhar Sahoo and Marianne Arriola and Yair Schiff and Aaron Gokaslan and Edgar Marroquin and Justin T Chiu and Alexander Rush and Volodymyr Kuleshov},
year = {2024},
eprint = {2406.07524},
archivePrefix = {arXiv}
}
See xlm-models/mdlm/README.md.
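The pieces above compose roughly as follows. This is a compressed, hypothetical sketch of one training step: the log-linear schedule form and tensor names are assumptions for illustration, and in the real pipeline the corruption happens in the collator and the weighting in the loss module.

```python
import torch
import torch.nn.functional as F

def toy_mdlm_step(model, input_ids, mask_token_id, eps=1e-3):
    """End-to-end sketch: sample t, corrupt via the schedule, weight the CE.
    The schedule here is an assumed log-linear form, not the repo's wiring."""
    B, L = input_ids.shape
    t = eps + (1 - eps) * torch.rand(B)              # per-example time in [eps, 1]
    total_noise = -torch.log1p(-(1 - eps) * t)       # bar-sigma(t)
    noise_rate = (1 - eps) / (1 - (1 - eps) * t)     # sigma(t) = d/dt bar-sigma(t)

    p_mask = 1 - torch.exp(-total_noise)             # per-token masking probability
    masked = torch.rand(B, L) < p_mask[:, None]
    x_t = input_ids.masked_fill(masked, mask_token_id)

    logits = model(x_t, total_noise)                 # (B, L, vocab_size)
    ce = F.cross_entropy(logits.transpose(1, 2), input_ids, reduction="none")
    weight = noise_rate / torch.expm1(total_noise)   # KL / ELBO weight
    return (ce * weight[:, None])[masked].mean()     # loss only on masked tokens
```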
2. Files at a glance
| Module | Public classes / helpers |
|---|---|
| model_mdlm.py | BaseMDLMModel, MDLMModel |
| loss_mdlm.py | MDLMLoss |
| predictor_mdlm.py | MDLMPredictor |
| datamodule_mdlm.py | DefaultMDLMCollator, MDLMSeq2SeqTrainCollator, MDLMSeq2SeqPredCollator, MDLMEmptyDataset, mdlm_single_segment_collate_fn |
| noise_mdlm.py | ContinousTimeNoiseSchedule, ContinuousTimeLinearSchedule, ContinuousTimeLogLinearSchedule, _convert_to_correlated |
| metrics_mdlm.py | seq2seq_exact_match_update_fn, seq2seq_token_accuracy_update_fn, mean_metric_update_fn |
| types_mdlm.py | MDLMBatch, MDLMSeq2SeqPredictionBatch, MDLMLossDict, MDLMModel (Protocol), MDLMPredictionDict |
3. Architecture
MDLMModel(num_embeddings, d_model, num_layers, nhead, ...) wraps a DDiTLayerList (DDiTLayer blocks with AdaLN time conditioning) around a TimestepEmbedder and projects through DDitFinalLayer. The forward signature differs from MLM in that it takes a per-example total-noise value, encoded as the AdaLN condition vector:
    forward(
        x_t: Integer[TT, " *batch seq_len"],
        noise: Float[TT, " *batch"],  # ``total_noise`` (passed positionally)
        attention_mask: Optional[Bool[TT, " *batch seq_len"]] = None,
        positions: Optional[Integer[TT, " *batch seq_len"]] = None,
    ) -> Float[TT, " *batch seq_len vocab_size"]
- noise is the conditioning signal; internally c = SiLU(TimestepEmbedder(noise)).
- positions are derived from attention_mask.cumsum(dim=1) - 1 when None.
- attention_mask is cast to bool internally.
The Protocol MDLMModel in types_mdlm.py uses total_noise: Float[TT, " batch"] as the second argument; MDLMLoss and MDLMPredictor both pass total_noise here.
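A minimal call sketch (constructor keywords taken from the signature above; the small sizes and the import path are illustrative assumptions):

```python
import torch
from model_mdlm import MDLMModel  # import path assumed

model = MDLMModel(num_embeddings=128, d_model=64, num_layers=2, nhead=4)

B, L = 2, 16
x_t = torch.randint(0, 128, (B, L))                  # partially masked token ids
attention_mask = torch.ones(B, L, dtype=torch.bool)  # 1 = real token
total_noise = torch.rand(B)                          # AdaLN condition per example

# positions may be omitted; they default to attention_mask.cumsum(dim=1) - 1
logits = model(x_t, total_noise, attention_mask)
assert logits.shape == (B, L, 128)                   # (*batch, seq_len, vocab_size)
```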
4. Batch contract
MDLMBatch (types_mdlm.py):
| Field | Shape | Notes |
|---|---|---|
| input_ids | (B, L) int | Random fraction of tokens replaced by [MASK]. |
| attention_mask | (B, L) int | 1 for real tokens, 0 for padding. |
| target_ids | (B, L) int | Original tokens; -100 at ignored positions when loss_on_padding=False. |
| noise_rate | (B,) float | $\sigma(t)$, the output of noise_schedule.noise_rate(t). |
| total_noise | (B,) float | $\bar\sigma(t)$, the output of noise_schedule.total_noise(t). |
| t | (B,) float | The sampled time. |
noise_rate, total_noise, and t are produced by DefaultMDLMCollator from the wired NoiseSchedule (so the collator requires a real schedule, not DummyNoiseSchedule).
5. Loss
MDLMLoss(loss_on_padding=False, loss_on_visible_tokens=False, model, tokenizer):
- configure(pl_module) caches mask_token_id_tensor on the right device.
- loss_fn (condensed in the sketch below):
  - Derives positions = (attention_mask.cumsum(dim=1) - 1).clamp(min=0) and zeroes them out at padding.
  - Runs logits = model(input_ids, total_noise, attention_mask, positions).
  - Builds ignore = (input_ids != mask_token_id) when loss_on_visible_tokens=False (default).
  - ce = cross_entropy(logits_T, targets, reduction="none", ignore_index=-100).
  - Diffusion weight: weight = noise_rate / torch.expm1(total_noise); kl = ce * weight[:, None].
  - Returns loss = masked_mean(kl.flatten(), ~ignore.flatten()).
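Put together, loss_fn reduces to roughly the following. This is a sketch with the default flags; batch access as a mapping and the trailing mean are stand-ins for the repository's MDLMBatch fields and masked_mean helper:

```python
import torch
import torch.nn.functional as F

def mdlm_loss_fn_sketch(model, batch, mask_token_id):
    """Rough shape of MDLMLoss.loss_fn with loss_on_padding=False and
    loss_on_visible_tokens=False (assumed mapping-style batch access)."""
    am = batch["attention_mask"]                             # (B, L) int
    positions = (am.cumsum(dim=1) - 1).clamp(min=0) * am     # zeroed at padding

    logits = model(batch["input_ids"], batch["total_noise"], am.bool(), positions)
    ce = F.cross_entropy(logits.transpose(1, 2), batch["target_ids"],
                         reduction="none", ignore_index=-100)

    ignore = batch["input_ids"] != mask_token_id             # default: score only [MASK]
    weight = batch["noise_rate"] / torch.expm1(batch["total_noise"])
    kl = ce * weight[:, None]
    return kl[~ignore].mean()                                # masked_mean equivalent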
6. Collators
The internal helper mdlm_single_segment_collate_fn(examples, noise_schedule, pad_token_id, mask_token_id, ...) samples t = noise_schedule.sample_t(batch_size) and uses noise_schedule(t) to compute (noise_rate, total_noise); masks each example with rate 1 - exp(-total_noise).
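The corruption step of that helper can be sketched as below; BOS/EOS and padding handling are omitted, and the schedule API (sample_t plus a call returning the (noise_rate, total_noise) pair) is taken from the description above:

```python
import torch

def corrupt_batch(input_ids, noise_schedule, mask_token_id):
    """Sketch of the masking inside mdlm_single_segment_collate_fn."""
    B, L = input_ids.shape
    t = noise_schedule.sample_t(B)                  # antithetic by default
    noise_rate, total_noise = noise_schedule(t)     # (sigma(t), bar-sigma(t))
    p_mask = 1 - torch.exp(-total_noise)            # per-example masking rate
    masked = torch.rand(B, L) < p_mask[:, None]
    x_t = input_ids.masked_fill(masked, mask_token_id)
    return x_t, noise_rate, total_noise, t
```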
| Class | Input | Output batch | Special behavior |
|---|---|---|---|
| DefaultMDLMCollator | BaseCollatorInput | MDLMBatch | Pad-right to block_size, BOS/EOS optional. Requires a real NoiseSchedule. |
| MDLMSeq2SeqTrainCollator | Seq2SeqCollatorInput | MDLMBatch | Concatenates [prompt][BOS][target][EOS] with right padding; masks only suffix positions (sketched below). |
| MDLMSeq2SeqPredCollator | Seq2SeqCollatorInput | MDLMBatch | input_ids = left-padded prompt only; target_ids = right-padded target (used for seq2seq prediction). |
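For the seq2seq training collator, the layout and suffix-only masking look roughly like this. A hypothetical single-example sketch: the exact masking boundary and padding handling live in MDLMSeq2SeqTrainCollator, and "suffix" is taken here to mean everything after the prompt.

```python
import torch

def seq2seq_train_layout(prompt_ids, target_ids, bos_id, eos_id, mask_id, p_mask):
    """[prompt][BOS][target][EOS] with masking restricted to the suffix."""
    seq = torch.cat([prompt_ids,
                     torch.tensor([bos_id]),
                     target_ids,
                     torch.tensor([eos_id])])
    masked = torch.rand(seq.shape) < p_mask
    masked[: prompt_ids.numel()] = False      # prompt tokens stay visible
    x_t = seq.masked_fill(masked, mask_id)
    return x_t, seq                           # (input_ids, target_ids) before padding
```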
Noise schedules live in noise_mdlm.py:
- ContinuousTimeLinearSchedule(sigma_min, sigma_max): affine $\bar\sigma(t)$ (with the exponential total_noise); t_from_noise_rate raises RuntimeError.
- ContinuousTimeLogLinearSchedule(sigma_min, sigma_max): log-linear total-noise; both t_from_noise_rate and t_from_total_noise are exact inverses. Requires sigma_min == 0.0 (raises NotImplementedError otherwise).
- Both support antithetic_sampling=True (default): sample t ~ U[0,1], then spread via _convert_to_correlated (t / B + arange(B) / B); see the sketch after this list. grad=True and importance_sampling=True are not implemented and raise.
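The antithetic (stratified) time sampling can be sketched as below; the helper name is a stand-in for the module's _convert_to_correlated:

```python
import torch

def _convert_to_correlated_sketch(t):
    """Spread B i.i.d. uniform draws so each lands in its own 1/B-wide bin,
    giving low-variance coverage of [0, 1) across the batch."""
    B = t.shape[0]
    return t / B + torch.arange(B, dtype=t.dtype) / B

t = torch.rand(8)
t_strat = _convert_to_correlated_sketch(t)
bins = torch.arange(8, dtype=t.dtype) / 8
assert ((t_strat >= bins) & (t_strat < bins + 1 / 8)).all()   # one sample per bin
```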
7. Predictor
MDLMPredictor(max_steps, max_new_tokens=None, tokenizer, model, noise_schedule, top_k=None, top_p=None):
- Sampling function selected at __init__: top-k -> sample_from_top_k, top-p -> sample_from_top_p, both -> sample_categorical, neither -> ValueError.
- predict() clones input_ids, optionally appends max_new_tokens [MASK] tokens, derives positions from attention_mask.cumsum - 1, and starts the diffusion chain at t = 1.
- predict_single_step (sketched below):
  - s = t - dt where dt = (1 - 1e-5) / (max_steps + 1).
  - dot_sigma_t, dot_sigma_s = noise_schedule(t)[1], noise_schedule(s)[1] (i.e. total_noise).
  - chance_t = 1 - exp(-dot_sigma_t), chance_s = 1 - exp(-dot_sigma_s).
  - For non-final steps, builds the categorical $q(x_s \mid x_t)$: softmax(logits) * (chance_t - chance_s) with the mask-token bucket set to chance_s, then samples x_s.
  - For the final step, argmax(logits).
  - Non-mask positions in x_t are preserved via torch.where(masked, x_s, x_t).
- stop() is true when done.all(), or t <= 0 everywhere, or no [MASK] remains.
- Output MDLMPredictionDict: {text, ids, loss=None, time_taken, output_start_idx}.
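A condensed sketch of one reverse step (chance_t / chance_s passed as per-example tensors; the Categorical normalizes the unnormalized rows, and categorical sampling stands in for the configured top-k / top-p variant):

```python
import torch

def reverse_step_sketch(logits, x_t, mask_token_id, chance_t, chance_s, final=False):
    """Sketch of predict_single_step: sample x_s from q(x_s | x_t), keeping
    already-revealed tokens fixed. chance_t / chance_s have shape (B,)."""
    if final:
        x_s = logits.argmax(dim=-1)                              # last step: greedy denoise
    else:
        probs = torch.softmax(logits, dim=-1) * (chance_t - chance_s)[:, None, None]
        probs[..., mask_token_id] = chance_s[:, None]            # stay-masked bucket
        x_s = torch.distributions.Categorical(probs=probs).sample()
    still_masked = x_t == mask_token_id
    return torch.where(still_masked, x_s, x_t)                   # never re-noise revealed tokens
```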
8. Metrics
See tests/models/mdlm/test_metrics_mdlm.py.
| Function | Returned keys |
|---|---|
| seq2seq_exact_match_update_fn | pred = loss_dict["ids"][:, output_start_idx:], target, pred_length = pred.shape[-1], target_length |
| seq2seq_token_accuracy_update_fn | pred, target, pred_mask = ones_like(pred) |
| mean_metric_update_fn | value = loss_dict["loss"] |
9. Configs / experiments
Hydra groups under xlm-models/mdlm/configs/. Available experiment entry points:
- experiment=owt_mdlm (OpenWebText)
- experiment=text_mdlm
10. Testing
Tests in tests/models/mdlm/:
- test_model_mdlm.py: extends BaseModelTests, plus a positions-from-mask check (added in this plan).
- test_loss_mdlm.py: extends BaseLossTests.
- test_collator_mdlm.py: now uses real_loglinear_schedule (added in this plan) to exercise DefaultMDLMCollator.
- test_predictor_mdlm.py: now uses real_loglinear_schedule to exercise MDLMPredictor.predict().
- test_noise_mdlm.py, test_metrics_mdlm.py: pure-function helpers.
Shared fixtures (tiny_mdlm_model, mdlm_batch, simple_tokenizer, real_loglinear_schedule) live in tests/conftest.py and tests/models/mdlm/conftest.py.