
xlm.datamodule

BaseBatch

Bases: TypedDict

Dict with the keys that are present in input batches for all models.

Attributes:

Name Type Description
input_ids Integer[Tensor, 'batch seq_len']

The input ids to the model.

attention_mask Integer[Tensor, 'batch seq_len']

1 for tokens that are not padding, 0 for padding tokens.

token_type_ids Integer[Tensor, 'batch seq_len']

Can depend on the model type. For ILM: 0 for CLS, 1 for BOS and prefix tokens, 2 for all other tokens.
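
The layout above can be sketched as a plain TypedDict; note this is an illustration only, with the jaxtyping-annotated torch Tensors replaced by nested lists for readability:

```python
from typing import List, TypedDict

# Illustrative sketch: real BaseBatch values are torch Tensors of
# shape (batch, seq_len); nested lists stand in for them here.
class BaseBatchSketch(TypedDict):
    input_ids: List[List[int]]
    attention_mask: List[List[int]]
    token_type_ids: List[List[int]]

batch: BaseBatchSketch = {
    "input_ids": [[4, 10, 11, 3], [4, 12, 3, 0]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 0]],  # trailing 0 = padding
    "token_type_ids": [[1, 1, 2, 2], [1, 2, 2, 2]],
}
```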

GPT2TokenizerWithCyclicPads

Bases: GPT2Tokenizer

GPT2Tokenizer with cyclic pad tokens (pad_0..pad_{n-1}).

GPT2TokenizerFastWithCyclicPads

Bases: GPT2TokenizerFast

GPT2TokenizerFast with cyclic pad tokens (pad_0..pad_{n-1}).

BertTokenizerWithCyclicPads

Bases: BertTokenizer

BertTokenizer with cyclic pad tokens (pad_0..pad_{n-1}).

BertTokenizerFastWithCyclicPads

Bases: BertTokenizerFast

BertTokenizerFast with cyclic pad tokens (pad_0..pad_{n-1}).

SimpleSpaceTokenizer

Bases: PreTrainedTokenizer

Tokenizer that splits input text on spaces.

__init__(vocab, **kwargs)

Parameters:

Name Type Description Default
vocab Sequence[str]

List of desired tokens. The special tokens and their fixed ids are: "[PAD]": 0, "[UNK]": 1, "[MASK]": 2, "[EOS]": 3, "[BOS]": 4; each entry of vocab is then assigned an id starting at 5.

required
model_max_length int

Model maximum sequence length.

required
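
The id-assignment scheme described above can be sketched in plain Python; this is a hypothetical reconstruction, not the actual PreTrainedTokenizer subclass:

```python
# Fixed special-token ids as documented for SimpleSpaceTokenizer.
SPECIAL = {"[PAD]": 0, "[UNK]": 1, "[MASK]": 2, "[EOS]": 3, "[BOS]": 4}

def build_vocab(tokens):
    # Vocabulary entries receive ids starting at 5, after the specials.
    vocab = dict(SPECIAL)
    for i, tok in enumerate(tokens):
        vocab[tok] = 5 + i
    return vocab

def encode(text, vocab):
    # Split on spaces; unknown tokens map to "[UNK]".
    return [vocab.get(t, vocab["[UNK]"]) for t in text.split(" ")]

vocab = build_vocab(["hello", "world"])
```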

SimpleSpaceTokenizerWithCyclicPads

Bases: SimpleSpaceTokenizer

SimpleSpaceTokenizer with cyclic pad tokens (pad_0..pad_{n-1}).

DatasetManager

Manages a single dataset through its lifecycle: download, preprocess, cache, setup, and dataloader creation.

Used per split (train/val/test/predict) by :class:TextDataModule, which orchestrates multiple DatasetManager instances.

__init__(collator, full_name, full_name_debug, dataloader_kwargs, filter_fn=None, filter_suffix=None, preprocess_function=None, preprocess_function_kwargs=None, on_the_fly_filter_fn=None, on_the_fly_processor=None, on_the_fly_processor_kwargs=None, on_the_fly_group_processor=None, on_the_fly_group_processor_kwargs=None, on_the_fly_group_processor_remove_columns=None, model_name=None, columns_to_remove=None, columns_to_keep=None, stages=None, iterable_dataset_shards=None, shuffle_buffer_size=None, shuffle_seed=42, split_by_node=True, rewrite_manual_cache=False, use_manual_cache=True, train_test_split=None, make_infinite=False)

Initialize the dataset manager.

Parameters:

Name Type Description Default
collator Collator

Collates examples into batches; model-specific.

required
full_name str

Full dataset path (e.g., "repo/ds_name/split").

required
full_name_debug str

Used when DEBUG_OVERFIT is True; typically the train split path so val/test managers overfit on train data.

required
dataloader_kwargs DataLoaderKwargs

batch_size, num_workers, shuffle, pin_memory.

required
preprocess_function Optional[str]

Dotted path to preprocessing function (e.g., tokenization).

None
preprocess_function_kwargs Optional[Dict[str, Any]]

Kwargs passed to the preprocess function.

None
on_the_fly_processor Optional[str]

Dotted path to per-example processor (e.g., ids_to_example_fn). For iterable (streaming) datasets, this is applied to each example on the fly, before the group processor.

None
on_the_fly_processor_kwargs Optional[Dict[str, Any]]

Kwargs for the on-the-fly processor.

None
on_the_fly_group_processor Optional[str]

Dotted path to the group processor function, which receives large batches of examples and applies transforms that need a big chunk of data at once, for example packing sequences without padding.

None
on_the_fly_group_processor_kwargs Optional[Dict[str, Any]]

Kwargs for the group processor.

None
on_the_fly_group_processor_remove_columns Optional[List[str]]

Passed to dataset.map(..., remove_columns=...) after the group step. When the processor changes the number of examples per batch (e.g. :func:pack_sequences), input columns such as token_ids must be dropped or HF merges them with outputs and column lengths disagree (IndexError in _batch_to_examples). Omit or use [] for 1:1 batch transforms.

None
model_name Optional[str]

Used for model-specific cache subdirectories.

None
columns_to_remove Optional[List[str]]

Columns to drop during preprocessing (e.g., ["text"]).

None
stages Optional[List[Literal['fit', 'validate', 'test', 'predict']]]

Lightning stages this manager participates in.

None
iterable_dataset_shards Optional[int]

If set, dataset is converted to IterableDataset with this many shards.

None
shuffle_buffer_size Optional[int]

Buffer size for dataset.shuffle() (IterableDataset).

None
shuffle_seed Optional[int]

Seed for shuffle.

42
split_by_node bool

Whether to split IterableDataset by rank in DDP.

True
rewrite_manual_cache bool

If True, re-download and overwrite cached data.

False
use_manual_cache bool

If True, use manual cache dir created using save_to_disk(), else let HF Datasets do automatic caching.

True
train_test_split Optional[TrainTestSplitConfig]

If set, split the loaded split into train/test via Dataset.train_test_split. Keys: size (float, passed as test_size) and optional seed (int, default 42).

None
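
A hypothetical set of constructor kwargs for a training-split manager, mirroring the parameters above; the dataset path, dotted preprocess path, and dataloader values are illustrative, and the required model-specific collator is omitted:

```python
# Illustrative kwargs only; "my_project.preprocessing.tokenize_fn" and
# "repo/ds_name/train" are placeholder names, not real paths.
train_manager_kwargs = dict(
    full_name="repo/ds_name/train",
    full_name_debug="repo/ds_name/train",  # overfit on train when DEBUG_OVERFIT
    dataloader_kwargs={"batch_size": 32, "num_workers": 4,
                       "shuffle": True, "pin_memory": True},
    preprocess_function="my_project.preprocessing.tokenize_fn",
    columns_to_remove=["text"],
    stages=["fit"],
    shuffle_seed=42,
    use_manual_cache=True,
)
```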

set_epoch(epoch)

Set epoch for IterableDataset shuffle buffer reproducibility.

For IterableDataset, calls dataset.set_epoch(epoch). For map-style datasets with DistributedSampler, set_epoch must be called on the sampler separately.

Parameters:

Name Type Description Default
epoch int

Current training epoch.

required
Warning

Note for future extensions. For map-style datasets with DistributedSampler, set_epoch must be called on the sampler separately.

prepare_data(manual_cache_dir, tokenizer, num_proc=None, load=False)

Download, preprocess, and cache the dataset.

If use_manual_cache: checks the cache first; downloads and caches (using save_to_disk()) if missing. Note that this differs from HF Datasets' automatic caching. If rewrite_manual_cache: re-downloads and overwrites the cache. Called before setup() by TextDataModule.prepare_data().

Parameters:

Name Type Description Default
manual_cache_dir str

Base directory for manual cache.

required
tokenizer Tokenizer

Tokenizer for preprocessing.

required
num_proc Optional[int]

Number of processes for parallel map operations.

None
load bool

If True and cache exists, load and return the dataset; else return None.

False

Returns:

Type Description
Optional[Dataset]

The dataset if load=True and cache was found/created; None otherwise.

Note

This method is only called on rank 0 and therefore does not set any state on the instance. setup() is called on all ranks.

setup(stage, manual_cache_dir, tokenizer, block_size, is_ddp, rank, world_size, num_dataset_workers=None)

Load dataset (should have already been downloaded and preprocessed by prepare_data()), optionally convert to IterableDataset, apply on-the-fly and group processors, and split by node if DDP.

Called by

TextDataModule.setup().

Parameters:

Name Type Description Default
stage Literal['fit', 'validate', 'test', 'predict']

Lightning stage (fit, validate, test, predict).

required
manual_cache_dir str

Base directory for manual cache.

required
tokenizer Tokenizer

Tokenizer for processors.

required
block_size int

Max sequence length for group processors.

required
is_ddp bool

Whether distributed training is used.

required
rank int

Global rank of this process.

required
world_size int

Total number of processes.

required
num_dataset_workers Optional[int]

Number of workers for prepare_data if cache not used.

None

get_dataloader(type, is_ddp, rank, world_size)

Return a DataLoader or StatefulDataLoader for the given split.

Chooses loader and sampler based on: DDP vs single-GPU, IterableDataset vs map-style. Train: StatefulDataLoader with appropriate sampler. Val/test/predict: standard DataLoader.

Parameters:

Name Type Description Default
type Literal['train', 'val', 'test', 'predict']

Dataloader type (train, val, test, predict).

required
is_ddp bool

Whether distributed training is used.

required
rank int

Global rank of this process.

required
world_size int

Total number of processes.

required

Returns:

Type Description
Union[DataLoader, StatefulDataLoader]

Configured DataLoader or StatefulDataLoader.
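
The selection rule described above can be sketched as a small decision function; this is a hypothetical reconstruction of the logic (class names returned as strings), not the actual implementation:

```python
def choose_loader(split, is_iterable, is_ddp):
    # Train gets the resumable StatefulDataLoader; other splits a plain DataLoader.
    loader = "StatefulDataLoader" if split == "train" else "DataLoader"
    if is_iterable:
        sampler = None  # iterable datasets are sharded, not sampled
    elif is_ddp:
        sampler = "DistributedSampler"
    else:
        sampler = "RandomSampler" if split == "train" else "SequentialSampler"
    return loader, sampler
```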

UnconditionalGenerationDatasetManager

This is used for unconditional generation, where we don't have any input text.

BaseDataModule

Bases: LightningDataModule

Base class for all datamodules.

tokenizer instance-attribute

The tokenizer. [Required]

print_batch(batch, split, dataloader_idx=None)

Prints train and validation batches at the beginning of the epoch. Subclasses are required to implement this.

BaseCollatorInput

Bases: TypedDict

Dict with values that are lists of raw input_ids of variable length.

This is the input to the collator for pre-training.

The elements of the lists can be of different lengths.

Attributes:

Name Type Description
input_ids List[int]

The input ids.

Seq2SeqCollatorInput

Bases: TypedDict

Dict with values that are lists of raw input_ids, attention_mask, and token_type_ids.

This is the input to the collator for pre-training.

The elements of the lists can be of different lengths.

Attributes:

Name Type Description
input_ids List[int]

The input ids.

prompt_ids List[int]

The prompt ids.

DefaultCollator

Bases: Collator

Simply stacks the input_ids, attention_mask, and token_type_ids and returns a batch.

DefaultCollatorWithPadding

Bases: DefaultCollator

Like DefaultCollator, but pads (truncates if needed) the input_ids, attention_mask, and token_type_ids to self.max_length.

DefaultCollatorWithDynamicPadding

Bases: DefaultCollatorWithPadding

Like DefaultCollatorWithPadding, but pads to the maximum length in the batch rather than to self.max_length.
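
The dynamic-padding behaviour can be sketched in plain Python; this is an illustrative stand-in (lists instead of tensors), not the collator's actual code:

```python
def dynamic_pad(input_ids, pad_id=0):
    # Pad every example to the longest length in this batch and
    # build the matching attention mask (1 = real token, 0 = pad).
    max_len = max(len(ids) for ids in input_ids)
    padded, mask = [], []
    for ids in input_ids:
        n_pad = max_len - len(ids)
        padded.append(ids + [pad_id] * n_pad)
        mask.append([1] * len(ids) + [0] * n_pad)
    return padded, mask
```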

add_cyclic_pad_tokens(tokenizer, n, token_template=None)

Add n cyclic pad tokens and set pad_i_token, pad_i_token_id for i in range(n). Modifies tokenizer in-place.

Parameters:

Name Type Description Default
tokenizer PreTrainedTokenizerBase

The tokenizer to modify (BertTokenizer, GPT2Tokenizer, etc.).

required
n int

Number of cyclic pad tokens.

required
token_template Optional[str]

Format string for token names. If None, inferred from pad_token: GPT2-style (pad_token like "<|pad|>") yields "<|pad_{}|>"; BERT-style (pad_token like "[PAD]") yields "[PAD_{}]".

None
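
The template-inference rule described above can be sketched as follows; this is a hypothetical reconstruction of that rule, not the function's actual code:

```python
def infer_pad_template(pad_token):
    # GPT2-style pad tokens look like "<|pad|>"; BERT-style like "[PAD]".
    if pad_token.startswith("<|"):
        return "<|pad_{}|>"
    if pad_token.startswith("["):
        return "[PAD_{}]"
    raise ValueError(f"cannot infer a template from pad_token {pad_token!r}")
```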

get_cyclic_pad_token_ids(tokenizer, n)

Return [pad_0_token_id, ..., pad_{n-1}_token_id]. Raises AttributeError if any pad_i_token_id is missing.
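
The attribute-lookup contract can be illustrated with a stand-in object; the Tok class and its ids below are hypothetical, only the pad_i_token_id naming mirrors the documented behaviour:

```python
class Tok:
    # Stand-in for a tokenizer that add_cyclic_pad_tokens has modified.
    pad_0_token_id = 50257
    pad_1_token_id = 50258

def cyclic_pad_ids(tok, n):
    # Missing pad_i_token_id attributes raise AttributeError, as documented.
    return [getattr(tok, f"pad_{i}_token_id") for i in range(n)]
```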

pack_sequences(examples, tokenizer, block_size, drop_last=True, use_bos=True, **kwargs)

Pack sequences without padding using EOS as separator and optional BOS.

Concatenates sequences with [BOS] seq [EOS] (or seq [EOS] if use_bos=False), then chunks into blocks of exactly block_size. For the last incomplete block: drop or pad according to drop_last.

Parameters:

Name Type Description Default
examples Dict[str, List[List[int]]]

Batched dict with "token_ids" key (list of token id lists).

required
tokenizer PreTrainedTokenizerBase

Tokenizer for BOS, EOS, PAD ids.

required
block_size int

Target block length.

required
drop_last bool

If True, drop incomplete last block; else pad with pad_token_id.

True
use_bos bool

If True, prepend BOS before each sequence.

True
**kwargs Any

Absorbed for interface compatibility.

{}

Returns:

Type Description
Dict[str, List[List[int]]]

Dict with "input_ids", "attention_mask", "token_type_ids" (zeros).
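
The concatenate-then-chunk scheme can be sketched in plain Python; this is an illustrative reconstruction operating on bare id lists (the real function also emits attention_mask and token_type_ids), with bos/eos/pad ids passed explicitly:

```python
def pack_sequences_sketch(token_ids, bos, eos, pad, block_size,
                          drop_last=True, use_bos=True):
    # Concatenate every sequence as [BOS] seq [EOS] (or seq [EOS]),
    # then slice the stream into blocks of exactly block_size.
    flat = []
    for seq in token_ids:
        if use_bos:
            flat.append(bos)
        flat.extend(seq)
        flat.append(eos)
    blocks = [flat[i:i + block_size] for i in range(0, len(flat), block_size)]
    if blocks and len(blocks[-1]) < block_size:
        if drop_last:
            blocks.pop()  # discard the incomplete tail block
        else:
            blocks[-1] = blocks[-1] + [pad] * (block_size - len(blocks[-1]))
    return blocks
```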