# xlm.datamodule

## BaseBatch

Bases: `TypedDict`

Dict with the keys that are present in input batches for all models.
Attributes:

| Name | Type | Description |
|---|---|---|
| `input_ids` | `Integer[Tensor, 'batch seq_len']` | The input ids to the model. |
| `attention_mask` | `Integer[Tensor, 'batch seq_len']` | 1 for tokens that are not padding. |
| `token_type_ids` | `Integer[Tensor, 'batch seq_len']` | Can depend on the model type. For ILM and IDLM: 0 for CLS, 1 for BOS and prefix, 2 for other tokens. |
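As an illustration, a batch with these keys could be sketched as follows. Note this is a simplified stand-in: the real class annotates the fields as `Integer[Tensor, 'batch seq_len']` (jaxtyping), not as nested lists, and the concrete ids here are made up.

```python
from typing import TypedDict, List

class BaseBatchSketch(TypedDict):
    """Illustrative stand-in for BaseBatch; the real fields are
    Integer[Tensor, 'batch seq_len'] tensors, not nested lists."""
    input_ids: List[List[int]]
    attention_mask: List[List[int]]
    token_type_ids: List[List[int]]

batch: BaseBatchSketch = {
    "input_ids": [[4, 17, 23, 3, 0]],       # toy ids, right-padded
    "attention_mask": [[1, 1, 1, 1, 0]],    # 1 for non-padding tokens
    "token_type_ids": [[0, 1, 2, 2, 2]],    # ILM/IDLM: 0=CLS, 1=BOS/prefix, 2=other
}
assert set(batch) == {"input_ids", "attention_mask", "token_type_ids"}
```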
## SimpleSpaceTokenizer

Bases: `PreTrainedTokenizer`

Splits on spaces.
### __init__(vocab, **kwargs)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `vocab` | `Sequence[str]` | List of desired tokens. The special tokens and their corresponding ids are: `"[PAD]": 0`, `"[UNK]": 1`, `"[MASK]": 2`, `"[EOS]": 3`, `"[BOS]": 4`; each remaining vocabulary entry is assigned an id starting at 5. | required |
| `model_max_length` | `int` | Model maximum sequence length. | required |
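The id assignment described above can be reproduced in a few lines. This is an illustrative reconstruction of the documented scheme, not the tokenizer's actual code:

```python
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[EOS]", "[BOS]"]  # ids 0..4

def build_vocab_ids(vocab):
    """Special tokens take ids 0-4; each vocab entry gets the next id from 5."""
    token_to_id = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for offset, tok in enumerate(vocab):
        token_to_id[tok] = 5 + offset
    return token_to_id

ids = build_vocab_ids(["hello", "world"])
assert ids["[PAD]"] == 0 and ids["[BOS]"] == 4
assert ids["hello"] == 5 and ids["world"] == 6
```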
## DatasetManager

Manages a single dataset through its lifecycle: download, preprocess, cache, setup, and dataloader creation.

Used per split (train/val/test/predict) by `TextDataModule`, which orchestrates multiple `DatasetManager` instances.
### __init__(collator, full_name, full_name_debug, dataloader_kwargs, preprocess_function=None, preprocess_function_kwargs=None, on_the_fly_processor=None, on_the_fly_processor_kwargs=None, on_the_fly_group_processor=None, model_name=None, columns_to_remove=None, stages=None, iterable_dataset_shards=None, shuffle_buffer_size=None, shuffle_seed=42, split_by_node=True, rewrite_manual_cache=False, use_manual_cache=True)

Initialize the dataset manager.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `collator` | `Collator` | Collates examples into batches; model-specific. | required |
| `full_name` | `str` | Full dataset path (e.g., `"repo/ds_name/split"`). | required |
| `full_name_debug` | `str` | Used when `DEBUG_OVERFIT` is True; typically the train split path so val/test managers overfit on train data. | required |
| `dataloader_kwargs` | `DataLoaderKwargs` | `batch_size`, `num_workers`, `shuffle`, `pin_memory`. | required |
| `preprocess_function` | `Optional[str]` | Dotted path to the preprocessing function (e.g., tokenization). | `None` |
| `preprocess_function_kwargs` | `Optional[Dict[str, Any]]` | Kwargs passed to the preprocess function. | `None` |
| `on_the_fly_processor` | `Optional[str]` | Dotted path to a per-example processor (e.g., `ids_to_example_fn`). For iterable (streaming) datasets, this is applied to each example on the fly, before the group processor. | `None` |
| `on_the_fly_processor_kwargs` | `Optional[Dict[str, Any]]` | Kwargs for the on-the-fly processor. | `None` |
| `on_the_fly_group_processor` | `Optional[str]` | Dotted path to the group processor function, which receives large batches of examples and applies on-the-fly processing that requires a big chunk of data, for example, packing sequences without padding. | `None` |
| `model_name` | `Optional[str]` | Used for model-specific cache subdirectories. | `None` |
| `columns_to_remove` | `Optional[List[str]]` | Columns to drop during preprocessing (e.g., `["text"]`). | `None` |
| `stages` | `Optional[List[Literal['fit', 'validate', 'test', 'predict']]]` | Lightning stages this manager participates in. | `None` |
| `iterable_dataset_shards` | `Optional[int]` | If set, the dataset is converted to an `IterableDataset` with this many shards. | `None` |
| `shuffle_buffer_size` | `Optional[int]` | Buffer size for `dataset.shuffle()` (`IterableDataset`). | `None` |
| `shuffle_seed` | `Optional[int]` | Seed for the shuffle. | `42` |
| `split_by_node` | `bool` | Whether to split the `IterableDataset` by rank in DDP. | `True` |
| `rewrite_manual_cache` | `bool` | If True, re-download and overwrite cached data. | `False` |
| `use_manual_cache` | `bool` | If True, use a manual cache dir created using `save_to_disk()`; otherwise let HF Datasets do automatic caching. | `True` |
### set_epoch(epoch)

Set the epoch for `IterableDataset` shuffle-buffer reproducibility.

For an `IterableDataset`, this calls `dataset.set_epoch(epoch)`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `epoch` | `int` | Current training epoch. | required |

Warning

Note for future extensions: for map-style datasets with a `DistributedSampler`, `set_epoch` must be called on the sampler separately.
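Why per-epoch seeding matters: without it, a seeded shuffle buffer would replay the same order every epoch. A toy illustration (not the library's code) of deriving a reproducible but epoch-dependent order from `(seed, epoch)`:

```python
import random

def epoch_shuffled(items, seed, epoch):
    """Shuffle deterministically from (seed, epoch), mimicking how an
    IterableDataset reseeds its shuffle buffer after set_epoch()."""
    rng = random.Random(seed + epoch)
    out = list(items)
    rng.shuffle(out)
    return out

data = list(range(8))
# Same (seed, epoch) -> identical order across restarts; a different epoch
# reseeds the shuffle, so consecutive epochs almost surely differ.
assert epoch_shuffled(data, 42, 0) == epoch_shuffled(data, 42, 0)
```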
### prepare_data(manual_cache_dir, tokenizer, num_proc=None, load=False)

Download, preprocess, and cache the dataset.

If `use_manual_cache`: checks the cache first; downloads and caches (using `save_to_disk()`) if missing. Note this differs from HF Datasets' automatic caching. If `rewrite_manual_cache`: re-downloads and overwrites the cache. Called before `setup()` by `TextDataModule.prepare_data()`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `manual_cache_dir` | `str` | Base directory for the manual cache. | required |
| `tokenizer` | `Tokenizer` | Tokenizer for preprocessing. | required |
| `num_proc` | `Optional[int]` | Number of processes for parallel map operations. | `None` |
| `load` | `bool` | If True and the cache exists, load and return the dataset; else return None. | `False` |

Returns:

| Type | Description |
|---|---|
| `Optional[Dataset]` | The dataset if `load=True` and the cache was found/created; None otherwise. |

Note

This method is only called on rank 0 and therefore does not set any state on the instance. `setup()` is called on all ranks.
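The manual-cache logic can be sketched roughly as follows. This is a simplified illustration using plain JSON files; the real implementation works with HF `Dataset.save_to_disk()` / `load_from_disk()`:

```python
import json
import os
import shutil
import tempfile

def prepare_data_sketch(cache_dir, build_fn, rewrite=False, load=False):
    """Check the manual cache; (re)build and cache if missing or rewrite=True."""
    cache_file = os.path.join(cache_dir, "dataset.json")
    if rewrite and os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)           # force a fresh download/preprocess
    if not os.path.exists(cache_file):
        os.makedirs(cache_dir, exist_ok=True)
        data = build_fn()                  # stands in for download + preprocess
        with open(cache_file, "w") as f:
            json.dump(data, f)             # stands in for save_to_disk()
    if load:
        with open(cache_file) as f:
            return json.load(f)
    return None

cache = os.path.join(tempfile.mkdtemp(), "my_ds")
assert prepare_data_sketch(cache, lambda: [1, 2, 3]) is None            # cached, nothing returned
assert prepare_data_sketch(cache, lambda: [9, 9], load=True) == [1, 2, 3]  # cache hit, builder not used
```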
### setup(stage, manual_cache_dir, tokenizer, block_size, is_ddp, rank, world_size, num_dataset_workers=None)

Load the dataset (which should already have been downloaded and preprocessed by `prepare_data()`), optionally convert it to an `IterableDataset`, apply the on-the-fly and group processors, and split by node if using DDP.

Called by `TextDataModule.setup()`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `stage` | `Literal['fit', 'validate', 'test', 'predict']` | Lightning stage (fit, validate, test, predict). | required |
| `manual_cache_dir` | `str` | Base directory for the manual cache. | required |
| `tokenizer` | `Tokenizer` | Tokenizer for the processors. | required |
| `block_size` | `int` | Max sequence length for group processors. | required |
| `is_ddp` | `bool` | Whether distributed training is used. | required |
| `rank` | `int` | Global rank of this process. | required |
| `world_size` | `int` | Total number of processes. | required |
| `num_dataset_workers` | `Optional[int]` | Number of workers for `prepare_data` if the cache is not used. | `None` |
### get_dataloader(type, is_ddp, rank, world_size)

Return a `DataLoader` or `StatefulDataLoader` for the given split.

Chooses the loader and sampler based on: DDP vs single-GPU, `IterableDataset` vs map-style. Train: `StatefulDataLoader` with the appropriate sampler. Val/test/predict: standard `DataLoader`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `type` | `Literal['train', 'val', 'test', 'predict']` | Dataloader type (train, val, test, predict). | required |
| `is_ddp` | `bool` | Whether distributed training is used. | required |
| `rank` | `int` | Global rank of this process. | required |
| `world_size` | `int` | Total number of processes. | required |

Returns:

| Type | Description |
|---|---|
| `Union[DataLoader, StatefulDataLoader]` | Configured `DataLoader` or `StatefulDataLoader`. |
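The documented selection policy can be summarized as a small decision table. A schematic sketch; the class and sampler names on the right-hand side are plausible choices for this policy, not verified internals:

```python
def choose_loader(split, is_ddp, is_iterable):
    """Mirror the documented policy: train -> StatefulDataLoader with the
    appropriate sampler; val/test/predict -> a standard DataLoader."""
    loader = "StatefulDataLoader" if split == "train" else "DataLoader"
    if is_iterable:
        sampler = None  # an IterableDataset shards itself (split_by_node)
    elif is_ddp:
        sampler = "DistributedSampler"
    else:
        sampler = "RandomSampler" if split == "train" else "SequentialSampler"
    return loader, sampler

assert choose_loader("train", is_ddp=True, is_iterable=False) == ("StatefulDataLoader", "DistributedSampler")
assert choose_loader("val", is_ddp=False, is_iterable=False) == ("DataLoader", "SequentialSampler")
```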
## UnconditionalGenerationDatasetManager

Used for unconditional generation, where there is no input text.
## BaseDataModule

Bases: `LightningDataModule`

Base class for all datamodules.

### tokenizer

*instance-attribute*

The tokenizer. [Required]

### print_batch(batch, split, dataloader_idx=None)

Required in order to print train and validation batches at the beginning of the epoch.
## BaseCollatorInput

Bases: `TypedDict`

Dict with values that are lists of raw input_ids of variable length.

This is the input to the collator for pre-training. The elements of the lists can be of different lengths.

Attributes:

| Name | Type | Description |
|---|---|---|
| `input_ids` | `List[int]` | The input ids. |
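For illustration, a pre-training collator input might look like the following (toy ids; the point is that rows have different lengths before collation):

```python
# A toy BaseCollatorInput-style dict: each row is one example's raw ids.
batch_input = {
    "input_ids": [
        [4, 17, 23, 3],         # short example: BOS ... EOS
        [4, 9, 11, 12, 56, 3],  # longer example
    ]
}
lengths = [len(ids) for ids in batch_input["input_ids"]]
assert lengths == [4, 6]  # variable-length before collation
```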
## Seq2SeqCollatorInput

Bases: `TypedDict`

Dict with values that are lists of raw input_ids, attention_mask, and token_type_ids.

This is the input to the collator for pre-training. The elements of the lists can be of different lengths.

Attributes:

| Name | Type | Description |
|---|---|---|
| `input_ids` | `List[int]` | The input ids. |
| `prompt_ids` | `List[int]` | The target ids. |
## DefaultCollator

Bases: `Collator`

Simply stacks the input_ids, attention_mask, and token_type_ids and returns a batch.

## DefaultCollatorWithPadding

Bases: `DefaultCollator`

Like `DefaultCollator`, but pads (truncating if needed) the input_ids, attention_mask, and token_type_ids to `self.max_length`.
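A minimal sketch of the pad-and-truncate behaviour, written with plain lists rather than tensors. The real collator stacks the results into tensors; `pad_id=0` matches the `[PAD]` id used by `SimpleSpaceTokenizer` but is an assumption here:

```python
def pad_or_truncate(ids, max_length, pad_id=0):
    """Right-pad with pad_id up to max_length, truncating longer sequences."""
    ids = ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))

def collate_with_padding(examples, max_length, pad_id=0):
    """Pad each example's input_ids and derive the matching attention_mask."""
    input_ids = [pad_or_truncate(ex["input_ids"], max_length, pad_id) for ex in examples]
    attention_mask = [
        [1] * min(len(ex["input_ids"]), max_length)
        + [0] * max(0, max_length - len(ex["input_ids"]))
        for ex in examples
    ]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = collate_with_padding(
    [{"input_ids": [4, 7, 3]}, {"input_ids": [4, 8, 9, 10, 3]}], max_length=4
)
assert batch["input_ids"] == [[4, 7, 3, 0], [4, 8, 9, 10]]       # padded / truncated
assert batch["attention_mask"] == [[1, 1, 1, 0], [1, 1, 1, 1]]   # 1 = real token
```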