
mdlm

MDLM - Masked Diffusion Language Model for XLM Framework

This package implements the MDLM model with all necessary components:
  • Model architecture (model_mdlm.py)
  • Loss function (loss_mdlm.py)
  • Predictor for inference (predictor_mdlm.py)
  • Data module (datamodule_mdlm.py)
  • Metrics computation (metrics_mdlm.py)
  • Type definitions (types_mdlm.py)
  • Noise functions (noise_mdlm.py)

This model was migrated from xlm.lm.mdlm to be an external model.

MDLMPredictor

Bases: Module, Predictor[MDLMBatch, MDLMPredictionDict]

Base predictor for MLM. Stochastically selects positions to unmask based on max_steps and max_new_tokens.
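The sketch below shows one way stochastic unmasking over `max_steps` can work. It uses the masked-diffusion reverse step, in which a still-masked token at time t is revealed by time s < t with probability (α_s − α_t)/(1 − α_t), under an assumed linear schedule α(t) = 1 − t. It is not the code in predictor_mdlm.py:

- `MASK_ID` and `predict_token` are hypothetical. `predict_token` stands in for the model plus sampling.
- `max_new_tokens`, `top_k`, and `top_p` are ignored here.

```python
import random
from typing import Callable, List

MASK_ID = -1  # hypothetical mask token id, used only in this sketch


def unmask_probability(t: float, s: float) -> float:
    """P(token still masked at time t is revealed by time s < t), assuming alpha(t) = 1 - t."""
    alpha_t, alpha_s = 1.0 - t, 1.0 - s
    return (alpha_s - alpha_t) / (1.0 - alpha_t)


def generate(
    x: List[int],
    max_steps: int,
    predict_token: Callable[[List[int], int], int],  # hypothetical stand-in for model + sampling
    rng: random.Random,
) -> List[int]:
    """Run max_steps reverse steps from t = 1 down to t = 0."""
    x = list(x)
    for k in range(max_steps):
        t = 1.0 - k / max_steps
        s = 1.0 - (k + 1) / max_steps
        p = unmask_probability(t, s)  # equals 1.0 on the final step, so nothing stays masked
        for i, tok in enumerate(x):
            # Each masked position is revealed independently with probability p.
            if tok == MASK_ID and rng.random() < p:
                x[i] = predict_token(x, i)
    return x


out = generate([5, MASK_ID, MASK_ID, 7], max_steps=4, predict_token=lambda x, i: 42, rng=random.Random(0))
# out == [5, 42, 42, 7]
```

Because the final step has probability 1, every masked position is filled after exactly `max_steps` steps. The randomness only affects *when* each position is revealed.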

__init__(max_steps, max_new_tokens=None, tokenizer=None, model=None, noise_schedule=None, top_k=None, top_p=None)

Initialize MDLM Predictor.

Parameters:

Name | Type | Description | Default
max_steps | int | Maximum number of prediction steps. | required
tokenizer | Optional[Tokenizer] | Tokenizer for encoding/decoding. | None
noise_schedule | Optional[NoiseSchedule] | Noise schedule for the diffusion process. | None
top_k | Optional[int] | Top-k sampling parameter. | None
top_p | Optional[float] | Top-p sampling parameter. | None
model | Optional[MDLMModel] | The MDLM model to use for predictions. | None
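The `top_k` and `top_p` parameters usually mean standard top-k and nucleus (top-p) filtering of the per-position token distribution. The sketch below applies both to a plain probability list and renormalizes the result. It is an assumption that the predictor filters this way; the function name `filter_probs` is hypothetical. `top_k`, when given, should be at least 1.

```python
from typing import List, Optional


def filter_probs(
    probs: List[float], top_k: Optional[int] = None, top_p: Optional[float] = None
) -> List[float]:
    """Zero out tokens outside the top-k set and/or the top-p nucleus, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order)
    if top_k is not None:
        keep &= set(order[:top_k])  # the k most probable tokens
    if top_p is not None:
        nucleus, cumulative = set(), 0.0
        for i in order:  # smallest prefix of sorted tokens with mass >= top_p
            nucleus.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        keep &= nucleus
    total = sum(probs[i] for i in keep)
    return [p / total if i in keep else 0.0 for i, p in enumerate(probs)]
```

When both limits are set, a token must pass both filters to survive.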

decode(results)

Parameters:

Name | Type | Description | Default
results | MDLMStepResults | Step results, including x: Integer[TT, " batch seq_len"], the current predicted sequence. | required

Returns:
  • out: List[str]: Decoded sequences, with special tokens.
  • x: Integer[TT, " batch seq_len"]: Current predicted sequence.
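The sketch below illustrates what "decoded with special tokens" means. Each row of ids is mapped back to text without dropping markers such as padding or mask tokens, and the ids are returned unchanged alongside the strings. The toy `VOCAB` and the space-joined output are hypothetical; the real method relies on the configured tokenizer.

```python
from typing import Dict, List, Tuple

# Hypothetical toy vocabulary; the real predictor uses its tokenizer instead.
VOCAB: Dict[int, str] = {0: "[PAD]", 1: "[BOS]", 2: "[MASK]", 3: "hello", 4: "world"}


def decode(x: List[List[int]]) -> Tuple[List[str], List[List[int]]]:
    """Decode each row of x, keeping special tokens, and return (out, x)."""
    out = [" ".join(VOCAB[i] for i in row) for row in x]
    return out, x
```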

DefaultMDLMCollator

Bases: Collator

Used for MDLM pre-training with padded-truncated sequences.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: The input for the model with masks.
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: The target ids for the model, where the input is copied as-is and masks are replaced with the correct token.
Padding
  • Padding is done on the right.
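The batch layout above can be sketched in plain Python as follows:

- Sequences are truncated to `max_len` and right-padded.
- Non-padding tokens are masked at random in `input_ids`.
- `target_ids` keep the original tokens.

Several details are assumptions made for this sketch. `PAD_ID` and `MASK_ID` are hypothetical. `mask_prob` stands in for whatever per-example noise level the real collator samples. Padding positions in `target_ids` are simply left as `PAD_ID`.

```python
import random
from typing import Dict, List

PAD_ID = 0   # hypothetical pad token id
MASK_ID = 2  # hypothetical mask token id


def collate(
    sequences: List[List[int]], max_len: int, mask_prob: float, rng: random.Random
) -> Dict[str, List[List[int]]]:
    input_ids, attention_mask, target_ids = [], [], []
    for seq in sequences:
        seq = seq[:max_len]                    # truncate
        pad = [PAD_ID] * (max_len - len(seq))  # pad on the right
        noisy = [MASK_ID if rng.random() < mask_prob else tok for tok in seq]
        input_ids.append(noisy + pad)
        attention_mask.append([1] * len(seq) + [0] * len(pad))
        target_ids.append(seq + pad)           # original tokens at every position
    return {"input_ids": input_ids, "attention_mask": attention_mask, "target_ids": target_ids}
```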

MDLMSeq2SeqPredCollator

Bases: Collator

input_ids contain only the prefix, and target_ids contain only the suffix, if one is present.

How is this different from MDLMSeq2SeqTrainCollator? In MDLMSeq2SeqTrainCollator, input_ids contain the joined sequence, and target_ids contain the target for the whole sequence. In MDLMSeq2SeqPredCollator, input_ids contain only the prefix, and target_ids contain only the suffix, if present.

Batch
  1. input_ids: Integer[TT, " batch seq_len"]: Input contains only the prefix
  2. attention_mask: Integer[TT, " batch seq_len"]: 1 for tokens that are not padding.
  3. target_ids: Integer[TT, " batch seq_len"]: Target contains only the suffix if present.
  4. noise_rate: Float[TT, " batch"]: The noise rate for the model.
  5. total_noise: Float[TT, " batch"]: The total noise for the model.
  6. t: Float[TT, " batch"]: The time step for the model.
Padding
  • There is padding on both sides because all prefixes end at the same position.
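One plausible reading of the padding note is that prefixes are left-padded so they all end at the same column, while suffix targets are right-padded. The sketch below follows that reading. It is an assumption, not the collator's actual code: `PAD_ID` is hypothetical, and `noise_rate`, `total_noise`, and `t` are omitted.

```python
from typing import Dict, List

PAD_ID = 0  # hypothetical pad token id


def collate_pred(prefixes: List[List[int]], suffixes: List[List[int]]) -> Dict[str, List[List[int]]]:
    p_len = max(len(p) for p in prefixes)
    s_len = max(len(s) for s in suffixes)
    input_ids, attention_mask, target_ids = [], [], []
    for p, s in zip(prefixes, suffixes):
        left = [PAD_ID] * (p_len - len(p))  # left-pad so every prefix ends at the same column
        input_ids.append(left + p)
        attention_mask.append([0] * len(left) + [1] * len(p))
        target_ids.append(s + [PAD_ID] * (s_len - len(s)))  # right-pad the suffix targets
    return {"input_ids": input_ids, "attention_mask": attention_mask, "target_ids": target_ids}
```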

MDLMBatch

Bases: TypedDict

Input to the MLM.

Attributes:

Name | Type | Description
input_ids | Integer[TT, " batch seq_len"] | The input ids to the model.
attention_mask | Integer[TT, " batch seq_len"] | 1 for tokens that are not padding.
target_ids | Optional[Integer[TT, " batch seq_len"]] | The target ids to the model.
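For illustration, the sketch below mirrors the batch structure as a `typing.TypedDict` with nested lists in place of tensors. The class name `MDLMBatchSketch` and the example values, including the mask id 2, are hypothetical. `target_ids` may be `None`, for example when no targets are available.

```python
from typing import List, Optional, TypedDict


class MDLMBatchSketch(TypedDict):
    """Hypothetical plain-Python mirror of MDLMBatch (lists instead of tensors)."""
    input_ids: List[List[int]]
    attention_mask: List[List[int]]
    target_ids: Optional[List[List[int]]]


batch: MDLMBatchSketch = {
    "input_ids": [[5, 2, 7, 0]],       # 2 plays the role of a mask token here
    "attention_mask": [[1, 1, 1, 0]],  # last position is padding
    "target_ids": [[5, 6, 7, 0]],      # mask replaced by the correct token
}
```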

MDLMSeq2SeqPredictionBatch

Bases: TypedDict

Input to the MLM for predicting suffix given the prefix.

MDLMUncondtionalPredictionBatch

Input to the MLM for unconditional generation.

Attributes:

Name | Type | Description
input_ids | Integer[Tensor, ' batch seq_len'] | The input ids to the model (all mask tokens).
attention_mask | Integer[Tensor, ' batch seq_len'] | 1 for tokens that are not padding.
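An unconditional prediction batch starts from nothing but masks. The sketch below builds such a batch. It assumes no padding, so every position is attended, and uses a hypothetical `MASK_ID`. The helper name `unconditional_batch` is also hypothetical.

```python
from typing import Dict, List

MASK_ID = 2  # hypothetical mask token id


def unconditional_batch(batch_size: int, seq_len: int) -> Dict[str, List[List[int]]]:
    """Every position is a mask token, and none are padding."""
    return {
        "input_ids": [[MASK_ID] * seq_len for _ in range(batch_size)],
        "attention_mask": [[1] * seq_len for _ in range(batch_size)],
    }
```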

MDLMLossDict

Bases: TypedDict

Output of the LossFunction Callable.

Attributes:

Name | Type | Description
loss | Float[Tensor, ''] | The total loss value.
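For context, the sketch below shows the masked-diffusion objective under an assumed linear schedule α(t) = 1 − t. Each masked position contributes a cross-entropy term weighted by −α′(t)/(1 − α(t)) = 1/t. Masking is detected via a hypothetical `MASK_ID`, and the average over masked tokens is a choice made for this sketch. How loss_mdlm.py weights and normalizes the loss is not documented here and may differ.

```python
import math
from typing import List

MASK_ID = 2  # hypothetical mask token id


def mdlm_loss(
    log_probs: List[List[List[float]]],  # [batch][seq_len][vocab] log-probabilities
    input_ids: List[List[int]],
    target_ids: List[List[int]],
    t: List[float],                      # per-example diffusion time in (0, 1]
) -> float:
    total, count = 0.0, 0
    for b, (row_in, row_tgt) in enumerate(zip(input_ids, target_ids)):
        weight = 1.0 / t[b]  # -alpha'(t) / (1 - alpha(t)) for alpha(t) = 1 - t
        for i, (tok_in, tok_tgt) in enumerate(zip(row_in, row_tgt)):
            if tok_in == MASK_ID:  # only masked positions contribute
                total += -weight * log_probs[b][i][tok_tgt]
                count += 1
    return total / max(count, 1)  # average over masked tokens (sketch choice)
```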

MDLMPredictionDict

Bases: TypedDict

Output of the Predictor for MLM.

Attributes:

Name | Type | Description
loss | Optional[Float[Tensor, batch]] | The loss value. Typically None.
text | List[str] | The batch of generated text with special tokens.
ids | Integer[Tensor, ' batch seq_len'] | The batch of generated token_ids.
time_taken | List[float] | Time taken for each prediction.
output_start_idx | Integer[Tensor, ' batch'] | The index of the first token in the output.
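A typical use of `output_start_idx` is to separate the generated output from any prompt or prefix in `ids`. The sketch below slices each row from its start index onward, using nested lists in place of tensors. The helper name `generated_part` is hypothetical, as is the assumption that everything before the start index is prefix.

```python
from typing import List


def generated_part(ids: List[List[int]], output_start_idx: List[int]) -> List[List[int]]:
    """Keep only the tokens at or after each row's output_start_idx."""
    return [row[start:] for row, start in zip(ids, output_start_idx)]
```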