xlm.datamodule

BaseBatch

Bases: TypedDict

Dict with the keys that are present in input batches for all models.

Attributes:

Name Type Description
input_ids Integer[Tensor, ' batch seq_len']

The input ids to the model.

attention_mask Integer[Tensor, ' batch seq_len']

1 for tokens that are not padding.

token_type_ids Integer[Tensor, ' batch seq_len']

Can depend on the model type. For ILM and IDLM: 0 for CLS, 1 for BOS and prefix, 2 for other tokens.
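As a concrete illustration of the ILM/IDLM token-type scheme above, here is a sketch of a BaseBatch-shaped dict (plain lists stand in for tensors; the ids themselves are made up):

```python
# token_type_ids follow the ILM/IDLM convention described above:
# 0 for CLS, 1 for BOS and prefix tokens, 2 for all remaining tokens.
def make_token_type_ids(seq_len: int, prefix_len: int) -> list:
    """Position 0 is CLS (0); the next `prefix_len` positions are BOS/prefix (1);
    everything after is a regular token (2)."""
    return [0] + [1] * prefix_len + [2] * (seq_len - 1 - prefix_len)

batch = {
    "input_ids": [[101, 4, 17, 42, 0, 0]],         # one sequence, padded to length 6
    "attention_mask": [[1, 1, 1, 1, 0, 0]],        # 1 for real tokens, 0 for padding
    "token_type_ids": [make_token_type_ids(6, 2)], # -> [0, 1, 1, 2, 2, 2]
}
```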

SimpleSpaceTokenizer

Bases: PreTrainedTokenizer

Splits text on spaces.

__init__(vocab, **kwargs)

Parameters:

Name Type Description Default
vocab Sequence[str]

List of desired tokens. The special tokens and their corresponding ids are: "[PAD]": 0, "[UNK]": 1, "[MASK]": 2, "[EOS]": 3, "[BOS]": 4; an id (starting at 5) is assigned to each vocabulary token.

required
model_max_length int

Model maximum sequence length.

required
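A minimal stand-alone sketch of the id scheme described above (not the actual SimpleSpaceTokenizer implementation, which subclasses PreTrainedTokenizer):

```python
SPECIAL_TOKENS = {"[PAD]": 0, "[UNK]": 1, "[MASK]": 2, "[EOS]": 3, "[BOS]": 4}

def build_vocab(tokens):
    """Map each desired token to an id, starting right after the special tokens."""
    vocab = dict(SPECIAL_TOKENS)
    for i, tok in enumerate(tokens):
        vocab[tok] = 5 + i
    return vocab

def encode(text, vocab):
    """Split on spaces; tokens outside the vocab map to [UNK]."""
    return [vocab.get(tok, vocab["[UNK]"]) for tok in text.split(" ")]

vocab = build_vocab(["a", "b", "c"])   # a -> 5, b -> 6, c -> 7
ids = encode("a c z", vocab)           # "z" is out of vocabulary -> [UNK]
```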

DatasetManager

Manages a single dataset through its lifecycle: download, preprocess, cache, setup, and dataloader creation.

Used per split (train/val/test/predict) by TextDataModule, which orchestrates multiple DatasetManager instances.

__init__(collator, full_name, full_name_debug, dataloader_kwargs, preprocess_function=None, preprocess_function_kwargs=None, on_the_fly_processor=None, on_the_fly_processor_kwargs=None, on_the_fly_group_processor=None, model_name=None, columns_to_remove=None, stages=None, iterable_dataset_shards=None, shuffle_buffer_size=None, shuffle_seed=42, split_by_node=True, rewrite_manual_cache=False, use_manual_cache=True)

Initialize the dataset manager.

Parameters:

Name Type Description Default
collator Collator

Collates examples into batches; model-specific.

required
full_name str

Full dataset path (e.g., "repo/ds_name/split").

required
full_name_debug str

Used when DEBUG_OVERFIT is True; typically the train split path so val/test managers overfit on train data.

required
dataloader_kwargs DataLoaderKwargs

batch_size, num_workers, shuffle, pin_memory.

required
preprocess_function Optional[str]

Dotted path to preprocessing function (e.g., tokenization).

None
preprocess_function_kwargs Optional[Dict[str, Any]]

Kwargs passed to the preprocess function.

None
on_the_fly_processor Optional[str]

Dotted path to per-example processor (e.g., ids_to_example_fn). For iterable (streaming) datasets, this is applied to each example on the fly, and before the group processor.

None
on_the_fly_processor_kwargs Optional[Dict[str, Any]]

Kwargs for the on-the-fly processor.

None
on_the_fly_group_processor Optional[str]

Dotted path to the group processor function, which receives large batches of examples, and applies on-the-fly processors that require a big chunk of data, for example, packing sequences without padding.

None
model_name Optional[str]

Used for model-specific cache subdirectories.

None
columns_to_remove Optional[List[str]]

Columns to drop during preprocessing (e.g., ["text"]).

None
stages Optional[List[Literal['fit', 'validate', 'test', 'predict']]]

Lightning stages this manager participates in.

None
iterable_dataset_shards Optional[int]

If set, dataset is converted to IterableDataset with this many shards.

None
shuffle_buffer_size Optional[int]

Buffer size for dataset.shuffle() (IterableDataset).

None
shuffle_seed Optional[int]

Seed for shuffle.

42
split_by_node bool

Whether to split IterableDataset by rank in DDP.

True
rewrite_manual_cache bool

If True, re-download and overwrite cached data.

False
use_manual_cache bool

If True, use manual cache dir created using save_to_disk(), else let HF Datasets do automatic caching.

True
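To illustrate what a group processor might do with a large chunk of examples, here is a hypothetical sketch of packing variable-length sequences into fixed-size blocks without padding (the function referenced by on_the_fly_group_processor is library-specific; pack_sequences is an assumed name):

```python
def pack_sequences(batch_of_ids, block_size):
    """Concatenate many variable-length id lists and re-slice them into
    fixed-size blocks, so no padding tokens are needed."""
    flat = [tok for ids in batch_of_ids for tok in ids]
    n_blocks = len(flat) // block_size  # the trailing remainder is dropped
    return [flat[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]

blocks = pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4)
# -> [[1, 2, 3, 4], [5, 6, 7, 8]]; the leftover token (9) is dropped
```

This is why the group processor needs a big chunk of data at once: packing efficiency improves when many sequences can be concatenated before slicing.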

set_epoch(epoch)

Set epoch for IterableDataset shuffle buffer reproducibility.

For IterableDataset, calls dataset.set_epoch(epoch).

Parameters:

Name Type Description Default
epoch int

Current training epoch.

required
Warning

Note for future extensions. For map-style datasets with DistributedSampler, set_epoch must be called on the sampler separately.
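Why set_epoch matters can be sketched with a seed-per-epoch shuffle (a simplified stand-in for the IterableDataset shuffle buffer; shuffle_seed=42 mirrors the default above):

```python
import random

def epoch_shuffle(items, epoch, shuffle_seed=42):
    """Deterministic shuffle whose order changes each epoch but is
    reproducible across runs, like dataset.set_epoch() plus a shuffle buffer."""
    rng = random.Random(shuffle_seed + epoch)
    shuffled = list(items)
    rng.shuffle(shuffled)
    return shuffled

a = epoch_shuffle(range(10), epoch=0)
b = epoch_shuffle(range(10), epoch=0)  # same epoch -> identical order
c = epoch_shuffle(range(10), epoch=1)  # new epoch -> order drawn from a different seed
```

Without the epoch in the seed, every epoch would replay the same order.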

prepare_data(manual_cache_dir, tokenizer, num_proc=None, load=False)

Download, preprocess, and cache the dataset.

If use_manual_cache is True, checks the cache first and, if it is missing, downloads the dataset and caches it using save_to_disk(). Note that this is different from HF Datasets' automatic caching. If rewrite_manual_cache is True, re-downloads and overwrites the cache. Called before setup() by TextDataModule.prepare_data().

Parameters:

Name Type Description Default
manual_cache_dir str

Base directory for manual cache.

required
tokenizer Tokenizer

Tokenizer for preprocessing.

required
num_proc Optional[int]

Number of processes for parallel map operations.

None
load bool

If True and cache exists, load and return the dataset; else return None.

False

Returns:

Type Description
Optional[Dataset]

The dataset if load=True and cache was found/created; None otherwise.

Note

This method is only called on rank 0 and therefore does not set any state on the instance. setup() is called on all ranks.
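The cache discipline described above can be sketched generically (build_fn, load_fn, and save_fn are hypothetical stand-ins for the HF download/preprocess step, load_from_disk, and save_to_disk):

```python
from pathlib import Path

def prepare_with_manual_cache(cache_dir, build_fn, load_fn, save_fn,
                              rewrite=False, load=False):
    """Mirror of the prepare_data() contract: build and cache on a miss
    (or when rewrite is set); return the dataset only when load=True."""
    cache = Path(cache_dir)
    if rewrite or not cache.exists():
        dataset = build_fn()                 # download + preprocess
        cache.mkdir(parents=True, exist_ok=True)
        save_fn(dataset, cache)              # e.g. dataset.save_to_disk(cache)
    elif load:
        dataset = load_fn(cache)             # e.g. load_from_disk(cache)
    else:
        return None
    return dataset if load else None
```

On the second run the expensive build step is skipped and the cached copy is loaded instead.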

setup(stage, manual_cache_dir, tokenizer, block_size, is_ddp, rank, world_size, num_dataset_workers=None)

Load dataset (should have already been downloaded and preprocessed by prepare_data()), optionally convert to IterableDataset, apply on-the-fly and group processors, and split by node if DDP.

Called by

TextDataModule.setup().

Parameters:

Name Type Description Default
stage Literal['fit', 'validate', 'test', 'predict']

Lightning stage (fit, validate, test, predict).

required
manual_cache_dir str

Base directory for manual cache.

required
tokenizer Tokenizer

Tokenizer for processors.

required
block_size int

Max sequence length for group processors.

required
is_ddp bool

Whether distributed training is used.

required
rank int

Global rank of this process.

required
world_size int

Total number of processes.

required
num_dataset_workers Optional[int]

Number of workers for prepare_data if cache not used.

None

get_dataloader(type, is_ddp, rank, world_size)

Return a DataLoader or StatefulDataLoader for the given split.

Chooses loader and sampler based on: DDP vs single-GPU, IterableDataset vs map-style. Train: StatefulDataLoader with appropriate sampler. Val/test/predict: standard DataLoader.

Parameters:

Name Type Description Default
type Literal['train', 'val', 'test', 'predict']

Dataloader type (train, val, test, predict).

required
is_ddp bool

Whether distributed training is used.

required
rank int

Global rank of this process.

required
world_size int

Total number of processes.

required

Returns:

Type Description
Union[DataLoader, StatefulDataLoader]

Configured DataLoader or StatefulDataLoader.
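The selection rule above can be summarized as a pure decision function. This is a hypothetical sketch: the real method constructs the loaders and samplers directly, and the exact sampler choice per case is an assumption here.

```python
def choose_loader(split, is_ddp, is_iterable):
    """Return (loader_kind, sampler_kind) following the rules above:
    train uses a StatefulDataLoader so loading state can be checkpointed;
    map-style data under DDP needs a DistributedSampler, while iterable
    datasets shard themselves (see split_by_node)."""
    loader = "StatefulDataLoader" if split == "train" else "DataLoader"
    if is_iterable or not is_ddp:
        sampler = None
    else:
        sampler = "DistributedSampler"
    return loader, sampler
```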

UnconditionalGenerationDatasetManager

This is used for unconditional generation, where we don't have any input text.

BaseDataModule

Bases: LightningDataModule

Base class for all datamodules.

tokenizer instance-attribute

The tokenizer. [Required]

print_batch(batch, split, dataloader_idx=None)

Required method that prints train and validation batches at the beginning of the epoch.

BaseCollatorInput

Bases: TypedDict

Dict with values that are lists of raw input_ids of variable length.

This is the input to the collator for pre-training.

The elements of the lists can be of different lengths.

Attributes:

Name Type Description
input_ids List[int]

The input ids.

Seq2SeqCollatorInput

Bases: TypedDict

Dict with values that are lists of raw input_ids and prompt_ids.

This is the input to the collator for pre-training.

The elements of the lists can be of different lengths.

Attributes:

Name Type Description
input_ids List[int]

The input ids.

prompt_ids List[int]

The prompt ids.

DefaultCollator

Bases: Collator

Simply stacks the input_ids, attention_mask, and token_type_ids and returns a batch.

DefaultCollatorWithPadding

Bases: DefaultCollator

Like DefaultCollator, but pads (truncates if needed) the input_ids, attention_mask, and token_type_ids to self.max_length.

DefaultCollatorWithDynamicPadding

Bases: DefaultCollatorWithPadding

Like DefaultCollatorWithPadding, but pads to the max length in the batch.
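A minimal sketch of the dynamic-padding behavior (plain lists instead of tensors; pad id 0 matches the "[PAD]" id used by SimpleSpaceTokenizer):

```python
def collate_dynamic_padding(examples, pad_id=0):
    """Pad every input_ids list to the longest sequence in this batch,
    producing input_ids and a matching attention_mask."""
    max_len = max(len(ids) for ids in examples)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in examples]
    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in examples]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = collate_dynamic_padding([[5, 6, 7], [8]])
# input_ids -> [[5, 6, 7], [8, 0, 0]]; attention_mask -> [[1, 1, 1], [1, 0, 0]]
```

Padding to the batch max (rather than a fixed max_length) wastes fewer tokens when sequence lengths vary.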