Data Pipeline
TextDataModule is a Lightning DataModule that orchestrates the full data
lifecycle: download, preprocessing, caching, and dataloader creation. It holds
DatasetManager instances for each split (train, val, test, predict).
Each manager is registered under a dataloader name (e.g. "lm",
"prediction"), and that same name is used by the Harness to look up the
metric group that should be evaluated on the corresponding dataloader.
Understanding this naming link is the single most important thing for working
with the pipeline.
Architecture Overview
The following diagram shows a typical setup. Each box on the left is a
DatasetManager registered under a dataloader name inside TextDataModule.
The arrows show how that same name is used by the Harness to find the
matching metric group on the right.
```mermaid
flowchart TB
    subgraph TDM [TextDataModule]
        direction TB
        train["train: #quot;lm#quot; -> DatasetManager"]
        val_lm["val: #quot;lm#quot; -> DatasetManager"]
        val_pred["val: #quot;prediction#quot; -> DatasetManager"]
        test_lm["test: #quot;lm#quot; -> DatasetManager"]
        test_pred["test: #quot;prediction#quot; -> DatasetManager"]
        predict["predict: #quot;unconditional#quot; -> UnconditionalGenerationDM"]
    end
    subgraph H [Harness]
        direction TB
        rm_train["reported_metrics.train.lm"]
        rm_val_lm["reported_metrics.val.lm"]
        rm_val_pred["reported_metrics.val.prediction"]
        rm_test_lm["reported_metrics.test.lm"]
        rm_test_pred["reported_metrics.test.prediction"]
    end
    train -- "dataloader_name = #quot;lm#quot;" --> rm_train
    val_lm -- "dataloader_name = #quot;lm#quot;" --> rm_val_lm
    val_pred -- "dataloader_name = #quot;prediction#quot;" --> rm_val_pred
    test_lm -- "dataloader_name = #quot;lm#quot;" --> rm_test_lm
    test_pred -- "dataloader_name = #quot;prediction#quot;" --> rm_test_pred
```
Internally, TextDataModule stores its managers as a nested dictionary:
```python
dataset_managers: Dict[
    Literal["train", "val", "test", "predict"],
    Dict[str, DatasetManager],  # key = dataloader name
]
```
The train split must contain exactly one DatasetManager; val, test, and
predict can each contain zero or more, so each of those splits yields a list of
dataloaders (one per manager).
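A minimal sketch of that nested structure and the split constraint, assuming a training setup in which the train split is present. FakeDatasetManager and validate_managers are hypothetical stand-ins for illustration, not part of xlm.

```python
from typing import Dict, Literal

Split = Literal["train", "val", "test", "predict"]


class FakeDatasetManager:
    """Hypothetical stand-in for DatasetManager; it only records a dataset name."""

    def __init__(self, full_name: str) -> None:
        self.full_name = full_name


def validate_managers(managers: Dict[Split, Dict[str, FakeDatasetManager]]) -> None:
    """Hypothetical check of the split rules described above."""
    unknown = set(managers) - {"train", "val", "test", "predict"}
    if unknown:
        raise ValueError(f"unknown split(s): {sorted(unknown)}")
    # train: exactly one manager; val/test/predict: zero or more.
    if len(managers.get("train", {})) != 1:
        raise ValueError("train split must contain exactly one DatasetManager")


dataset_managers: Dict[Split, Dict[str, FakeDatasetManager]] = {
    "train": {"lm": FakeDatasetManager("billion-word-benchmark/lm1b/train")},
    "val": {  # two managers -> two validation dataloaders
        "lm": FakeDatasetManager("billion-word-benchmark/lm1b/test"),
        "prediction": FakeDatasetManager("unconditional"),
    },
    "test": {},
    "predict": {},
}
validate_managers(dataset_managers)  # passes silently
```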
On construction, TextDataModule assigns each dataloader name a numeric index
(the iteration order). During a training or evaluation step the Harness
receives a dataloader_idx from the Lightning Trainer and resolves it back to
the dataloader name to find the right metrics:
```text
dataloader_idx (from Lightning Trainer)
        |
        v
dataloader_names[stage][idx]  -->  dataloader_name (e.g. "lm")
        |
        v
reported_metrics["metrics_{stage}"][dataloader_name]  -->  list of MetricWrapper
```
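The resolution chain above can be sketched with plain dictionaries. This is an illustration under assumptions, not the Harness code: the index is taken to be the position of the name in dict insertion order, strings stand in for MetricWrapper objects, and metrics_for and the metric labels are hypothetical.

```python
from typing import Dict, List

# Hypothetical managers; only the key order matters here.
dataset_managers: Dict[str, Dict[str, object]] = {
    "train": {"lm": object()},
    "val": {"lm": object(), "prediction": object()},
}

# Index = position in iteration order (Python dicts preserve insertion order).
dataloader_names: Dict[str, List[str]] = {
    stage: list(managers) for stage, managers in dataset_managers.items()
}

# Strings stand in for MetricWrapper objects.
reported_metrics: Dict[str, Dict[str, List[str]]] = {
    "metrics_train": {"lm": ["train/lm/accumulated_loss"]},
    "metrics_val": {
        "lm": ["val/lm/accumulated_loss"],
        "prediction": ["val/prediction/example_metric"],
    },
}


def metrics_for(stage: str, dataloader_idx: int) -> List[str]:
    """dataloader_idx -> dataloader name -> metric group."""
    dataloader_name = dataloader_names[stage][dataloader_idx]
    return reported_metrics[f"metrics_{stage}"][dataloader_name]


print(metrics_for("val", 1))  # ['val/prediction/example_metric']
```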
Because the same name appears in both the datamodule config and the metrics
config, adding a new evaluation dataset with its own metrics is just a matter of
registering a new DatasetManager under a new name and adding a matching entry
in the metrics config. For details on how metrics themselves are structured, see
the Metrics guide.
The TextDataModule lifecycle follows the standard Lightning pattern:
- prepare_data() -- rank 0 only; delegates to each manager's prepare_data() (download + preprocess + cache).
- setup(stage) -- all ranks; delegates to each manager's setup() (load from cache, apply processors, DDP splitting).
- {train,val,test,predict}_dataloader() -- returns the configured DataLoader(s).
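The delegation in that lifecycle can be sketched without Lightning. ToyManager and ToyDataModule are hypothetical; in a real run Lightning itself decides which process calls prepare_data(), and the explicit global_rank check below only emulates that.

```python
from typing import Dict, Iterator, List


class ToyManager:
    """Hypothetical stand-in for DatasetManager that records which hooks ran."""

    def __init__(self, name: str, events: List[str]) -> None:
        self.name = name
        self.events = events

    def prepare_data(self) -> None:
        # Real manager: download + preprocess + cache.
        self.events.append(f"prepare_data:{self.name}")

    def setup(self, stage: str) -> None:
        # Real manager: load from cache, apply processors, DDP splitting.
        self.events.append(f"setup:{self.name}:{stage}")


class ToyDataModule:
    """Delegation pattern only; not the real TextDataModule."""

    def __init__(self, managers: Dict[str, Dict[str, ToyManager]], global_rank: int) -> None:
        self.managers = managers
        self.global_rank = global_rank

    def _all_managers(self) -> Iterator[ToyManager]:
        for per_split in self.managers.values():
            yield from per_split.values()

    def prepare_data(self) -> None:
        # Emulates "rank 0 only"; Lightning normally enforces this itself.
        if self.global_rank != 0:
            return
        for manager in self._all_managers():
            manager.prepare_data()

    def setup(self, stage: str) -> None:
        # setup() runs on every rank.
        for manager in self._all_managers():
            manager.setup(stage)


events: List[str] = []
dm = ToyDataModule({"train": {"lm": ToyManager("train/lm", events)}}, global_rank=0)
dm.prepare_data()
dm.setup("fit")
print(events)  # ['prepare_data:train/lm', 'setup:train/lm:fit']
```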
The following config example shows how the pieces fit together in practice.
Real configs live under configs/lightning_train/datamodule/ and
configs/lightning_train/model_type/.
Datamodule config (datamodule/lm1b.yaml, simplified):
```yaml
datamodule:
  _target_: xlm.datamodule.TextDataModule
  dataset_managers:
    train:                        # <-- split name
      lm:                         # <-- dataloader name
        _target_: xlm.datamodule.DatasetManager
        full_name: billion-word-benchmark/lm1b/train
        preprocess_function: xlm.tasks.lm1b.preprocess_fn
        on_the_fly_processor: xlm.datamodule.token_ids_to_input_ids
        columns_to_remove: [text]
        stages: [fit]
        # ...
    val:                          # <-- split name
      lm:                         # <-- same dataloader name
        _target_: xlm.datamodule.DatasetManager
        full_name: billion-word-benchmark/lm1b/test
        stages: [fit, validate]
        # ...
      unconditional_prediction:   # <-- another dataloader name
        _target_: xlm.datamodule.UnconditionalGenerationDatasetManager
        # ...
```
Metrics config (model_type/mdlm.yaml, simplified):
```yaml
reported_metrics:
  train:
    lm:                  # <-- must match dataloader name above
      accumulated_loss:
        prefix: train/lm
        update_fn: mdlm.metrics_mdlm.mean_metric_update_fn
  val:
    lm:                  # <-- must match
      accumulated_loss:
        prefix: val/lm
        update_fn: mdlm.metrics_mdlm.mean_metric_update_fn
```
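The dataloader names (here lm and unconditional_prediction) are the glue: a
dataloader's metrics are found by looking up its exact name in the metrics
config. The simplified metrics config above shows only the lm entries.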
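Because the link is purely by name, a quick consistency check between the two configs can catch typos. This is a hypothetical helper, not an xlm utility. It assumes both configs have already been turned into plain dicts keyed by split and then by dataloader name (e.g. via OmegaConf.to_container on datamodule.dataset_managers and reported_metrics). Whether a dataloader without a metrics entry is an error in the real Harness is not stated here, so the helper only reports mismatches in both directions.

```python
from typing import Dict, List


def name_mismatches(
    dataset_managers: Dict[str, Dict[str, dict]],
    reported_metrics: Dict[str, Dict[str, dict]],
) -> Dict[str, Dict[str, List[str]]]:
    """Per split, list dataloader names lacking metrics and metric groups lacking a dataloader."""
    report: Dict[str, Dict[str, List[str]]] = {}
    for split in sorted(set(dataset_managers) | set(reported_metrics)):
        loaders = set(dataset_managers.get(split, {}))
        metrics = set(reported_metrics.get(split, {}))
        if loaders != metrics:
            report[split] = {
                "no_metrics": sorted(loaders - metrics),
                "no_dataloader": sorted(metrics - loaders),
            }
    return report


# Shapes mirror the two simplified configs above (leaf contents omitted).
datamodule_cfg = {"train": {"lm": {}}, "val": {"lm": {}, "unconditional_prediction": {}}}
metrics_cfg = {"train": {"lm": {}}, "val": {"lm": {}}}
print(name_mismatches(datamodule_cfg, metrics_cfg))
# {'val': {'no_metrics': ['unconditional_prediction'], 'no_dataloader': []}}
```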
DatasetManager Lifecycle
Each DatasetManager owns a single dataset and drives it through a multi-stage
pipeline:
```mermaid
flowchart LR
    subgraph prepare_data ["prepare_data() -- rank 0 only"]
        direction LR
        A[Download] --> B[Preprocess]
        B --> C[Cache]
    end
    subgraph setup ["setup() -- all ranks"]
        direction LR
        D[Load from cache] --> E["On-the-fly\nProcessors"]
        E --> F["Group\nProcessors"]
    end
    prepare_data --> setup
    setup --> G["Dataloader\n+ Collator"]
```
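The prepare_data() half of this pipeline (download, preprocess, cache, and skip everything on a cache hit) can be sketched with the standard library. According to the table that follows, the real manager uses datasets.load_dataset(), dataset.map(), and save_to_disk(); this sketch swaps in plain lists and a JSON file so it runs without dependencies. prepare_cached, fake_download, and fake_tokenize are hypothetical.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict, List

Row = Dict[str, object]


def prepare_cached(
    full_name: str,
    manual_cache_dir: Path,
    download: Callable[[], List[Row]],
    preprocess: Callable[[Row], Row],
    columns_to_remove: List[str],
) -> Path:
    """Download -> preprocess -> cache, skipping all of it if the cache already exists."""
    cache_path = manual_cache_dir / full_name
    if cache_path.exists():
        return cache_path  # cache hit: no download, no preprocessing
    rows = download()  # stands in for datasets.load_dataset()
    # Like dataset.map(): the function's output columns are merged into each row.
    rows = [{**row, **preprocess(row)} for row in rows]
    rows = [{k: v for k, v in row.items() if k not in columns_to_remove} for row in rows]
    cache_path.mkdir(parents=True)
    (cache_path / "data.json").write_text(json.dumps(rows))  # stands in for save_to_disk()
    return cache_path


calls: List[int] = []


def fake_download() -> List[Row]:
    calls.append(1)
    return [{"text": "hello big world"}]


def fake_tokenize(row: Row) -> Row:
    # Hypothetical "tokenizer": one id per word, equal to the word's length.
    return {"token_ids": [len(w) for w in str(row["text"]).split()]}


cache_root = Path(tempfile.mkdtemp())
for _ in range(2):
    path = prepare_cached("lm1b/train", cache_root, fake_download, fake_tokenize, ["text"])
print(len(calls))  # 1: the second call hit the cache
print(json.loads((path / "data.json").read_text()))  # [{'token_ids': [5, 3, 5]}]
```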
| Stage | When | What happens |
|---|---|---|
| Download | prepare_data() (rank 0) | datasets.load_dataset() fetches from HuggingFace Hub. LocalDatasetManager overrides this to load local CSV files instead. |
| Preprocess | prepare_data() (rank 0) | A configurable function (dotted-path string, e.g. a tokenization function) is applied via dataset.map(). Unwanted columns are dropped. |
| Cache | prepare_data() (rank 0) | The preprocessed dataset is saved to disk (manual_cache_dir / full_name) with save_to_disk(). Subsequent runs skip download + preprocess. |
| Setup | setup(stage) (all ranks) | Loads from cache. Optionally converts to IterableDataset (sharded). Applies on-the-fly and group processors. Splits by node in DDP. |
| On-the-fly Processors | Lazy, per-example | Per-example transforms set via dataset.set_transform(). Example: token_ids_to_input_ids converts raw token IDs into input_ids, attention_mask, token_type_ids. |
| Group Processors | Lazy, batched | Operates on large chunks of examples, e.g. sequence packing to block_size without padding. |
| Collation | DataLoader | A Collator converts a list of examples into a BaseBatch tensor dict (input_ids, attention_mask, token_type_ids). |
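The last three rows can be illustrated together in plain Python. Everything here is a hypothetical sketch: to_model_inputs only mirrors the role of token_ids_to_input_ids (the real one may add special tokens), the packing drops the trailing partial block (the real group processor may carry it over), and collate returns nested lists where a real Collator would build tensors for a BaseBatch.

```python
from typing import Dict, List

Example = Dict[str, List[int]]


def to_model_inputs(example: Example) -> Example:
    """Per-example processor in the spirit of token_ids_to_input_ids (hypothetical)."""
    ids = example["token_ids"]
    return {
        "input_ids": list(ids),
        "attention_mask": [1] * len(ids),  # every real token is attended to
        "token_type_ids": [0] * len(ids),  # assumed: a single segment
    }


def pack_to_block_size(examples: List[Example], block_size: int) -> List[Example]:
    """Group processor: concatenate every field across a chunk, then cut into full blocks."""
    keys = list(examples[0])
    flat = {k: [v for ex in examples for v in ex[k]] for k in keys}
    n_full = len(flat["input_ids"]) // block_size  # trailing partial block is dropped
    return [{k: flat[k][i * block_size:(i + 1) * block_size] for k in keys} for i in range(n_full)]


def collate(examples: List[Example], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad each field to the longest example in the batch."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
    for ex in examples:
        pad = max_len - len(ex["input_ids"])
        batch["input_ids"].append(ex["input_ids"] + [pad_id] * pad)
        batch["attention_mask"].append(ex["attention_mask"] + [0] * pad)
        batch["token_type_ids"].append(ex["token_type_ids"] + [0] * pad)
    return batch


examples = [to_model_inputs({"token_ids": ids}) for ids in ([5, 6, 7], [8, 9], [10, 11, 12])]
blocks = pack_to_block_size(examples, block_size=4)
print([b["input_ids"] for b in blocks])  # [[5, 6, 7, 8], [9, 10, 11, 12]]
batch = collate(blocks)
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 1]]
```

Packed blocks all have length block_size, so collating them needs no padding; the padding path only matters for unpacked, variable-length examples.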
Dataloader and Sampler Selection
get_dataloader() picks the loader class and sampler based on three axes:
split, DDP, and dataset type.
| Split | DDP | Iterable | Loader | Sampler | Notes |
|---|---|---|---|---|---|
| train | yes | yes | StatefulDataLoader | None | Shuffling handled by the IterableDataset shuffle buffer; explicit shuffle kwarg is ignored. Shards split across workers during setup(). |
| train | yes | no | StatefulDataLoader | StatefulDistributedSampler | Partitions data across ranks and shuffles (disabled under DEBUG_OVERFIT). |
| train | no | no | StatefulDataLoader | RandomSampler | Single-GPU. Falls back to SequentialSampler under DEBUG_OVERFIT. |
| train | no | yes | StatefulDataLoader | None | Single-GPU with iterable dataset; no sampler needed. |
| val / test / predict | any | any | DataLoader | None | Never shuffled. Standard (non-stateful) DataLoader. |
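The selection logic in the table can be read as a small decision function. This sketch returns class names as strings rather than constructing the torch/torchdata objects, and select_loader_and_sampler is a hypothetical name, not the signature of get_dataloader().

```python
from typing import Optional, Tuple


def select_loader_and_sampler(
    split: str, ddp: bool, iterable: bool, debug_overfit: bool = False
) -> Tuple[str, Optional[str]]:
    """Return (loader class name, sampler class name or None) per the table above."""
    if split != "train":
        # val / test / predict: standard DataLoader, never shuffled.
        return "DataLoader", None
    if iterable:
        # Order comes from the IterableDataset's shuffle buffer, so no sampler.
        return "StatefulDataLoader", None
    if ddp:
        # The distributed sampler's shuffling is switched off under DEBUG_OVERFIT.
        return "StatefulDataLoader", "StatefulDistributedSampler"
    return "StatefulDataLoader", ("SequentialSampler" if debug_overfit else "RandomSampler")


print(select_loader_and_sampler("train", ddp=False, iterable=False))
# ('StatefulDataLoader', 'RandomSampler')
```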
All train dataloaders use StatefulDataLoader (from torchdata) so that
iteration state can be checkpointed and resumed mid-epoch.
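StatefulDataLoader exposes state_dict() and load_state_dict() for this purpose. The toy class below is not torchdata; it only mimics that pattern over a list to show what mid-epoch resumption means. A real stateful loader also has to capture sampler and worker state, which the toy ignores.

```python
from typing import Any, Dict, Iterator, List


class ToyStatefulLoader:
    """Toy batch iterator mimicking the state_dict()/load_state_dict() pattern."""

    def __init__(self, data: List[Any], batch_size: int) -> None:
        self.data = data
        self.batch_size = batch_size
        self._next_index = 0  # position within the current epoch

    def __iter__(self) -> Iterator[List[Any]]:
        while self._next_index < len(self.data):
            start = self._next_index
            self._next_index += self.batch_size  # advance before yielding the batch
            yield self.data[start:start + self.batch_size]
        self._next_index = 0  # epoch finished; the next pass starts from the beginning

    def state_dict(self) -> Dict[str, int]:
        return {"next_index": self._next_index}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self._next_index = state["next_index"]


loader = ToyStatefulLoader(list(range(10)), batch_size=3)
first = next(iter(loader))  # [0, 1, 2]
ckpt = loader.state_dict()  # {'next_index': 3}, saved with the model checkpoint

resumed = ToyStatefulLoader(list(range(10)), batch_size=3)
resumed.load_state_dict(ckpt)
rest = list(resumed)
print(rest)  # [[3, 4, 5], [6, 7, 8], [9]]
```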