xlm.log_predictions

FilePredictionWriter

Bases: _PredictionWriter

Writer that outputs predictions to JSONL files.

__init__(fields_to_keep_in_output=None, file_path_='from_pl_module')

Initialize FilePredictionWriter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fields_to_keep_in_output | Optional[List[str]] | List of fields to keep in the output. If None, all fields are kept. | None |
| file_path_ | Union[Path, Literal['none', 'from_pl_module']] | Path to the output file, or one of two special values. If "from_pl_module", the pl_module is queried for the predictions_file for the given step and epoch. Set to "none" to disable file writing. | 'from_pl_module' |
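
The effect of fields_to_keep_in_output can be pictured with a minimal sketch. The helper `filter_fields` below is hypothetical and not part of xlm; how the real writer treats keys that are missing from a prediction is an assumption.

```python
from typing import Any, Dict, List, Optional


def filter_fields(
    prediction: Dict[str, Any],
    fields_to_keep: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: keep only the requested keys of one prediction."""
    if fields_to_keep is None:
        # None means "no filtering": every field is kept.
        return dict(prediction)
    # Assumption: keys absent from the prediction are silently skipped.
    return {k: prediction[k] for k in fields_to_keep if k in prediction}


example = {"text": "hello world", "score": 0.9, "ids": [1, 2, 3]}
kept = filter_fields(example, ["text", "score"])
unfiltered = filter_fields(example)
```

Here `kept` holds only the "text" and "score" entries, while `unfiltered` is a copy of the full dictionary.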

write(predictions, ground_truth_text, step, epoch, split, dataloader_name, pl_module, trainer)

Write predictions to a JSONL file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | List[Dict[str, Any]] | List of prediction dictionaries. | required |
| ground_truth_text | Optional[List[str]] | List of ground truth text strings. | required |
| step | int | Current training step. | required |
| epoch | int | Current epoch. | required |
| split | Literal['train', 'val', 'test', 'predict'] | The split name (train/val/test/predict). | required |
| dataloader_name | str | The dataloader name. | required |
| pl_module | LightningModule | The Lightning module. | required |
| trainer | Optional[Trainer] | The Lightning trainer. | required |

read(step, epoch, split, dataloader_name, pl_module)

Read predictions from a JSONL file.
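
For readers unfamiliar with JSONL, the following standard-library sketch shows the one-JSON-object-per-line round trip that write and read imply. The function names and the file name are hypothetical; the actual on-disk layout chosen by FilePredictionWriter is not documented here.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List


def write_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None:
    """Write one JSON object per line."""
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row) + "\n")


def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read the objects back, skipping blank lines."""
    with path.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


rows = [
    {"text": "a prediction", "truth": "a reference"},
    {"text": "another one", "truth": ""},
]
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "predictions.jsonl"  # hypothetical file name
    write_jsonl(path, rows)
    round_tripped = read_jsonl(path)
```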

LoggerPredictionWriter

Bases: _PredictionWriter

Writer that outputs predictions to Lightning loggers.

__init__(n_rows=10, logger_=None, fields_to_keep_in_output=None)

Initialize LoggerPredictionWriter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_rows | int | Number of rows to log to the logger. | 10 |
| logger_ | Optional[List[Logger]] | List of loggers to use. If None, uses pl_module.trainer.loggers. | None |
| fields_to_keep_in_output | Optional[List[str]] | List of fields to keep in the output. If None, all fields are kept. | None |

write(predictions, ground_truth_text, step, epoch, split, dataloader_name, pl_module, trainer)

Write predictions to Lightning loggers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | List[Dict[str, Any]] | List of prediction dictionaries. | required |
| ground_truth_text | List[str] | List of ground truth text strings. | required |
| step | int | Current training step. | required |
| epoch | int | Current epoch. | required |
| split | Literal['train', 'val', 'test', 'predict'] | The split name (train/val/test/predict). | required |
| dataloader_name | str | The dataloader name. | required |
| pl_module | LightningModule | The Lightning module. | required |
| trainer | Optional[Trainer] | The Lightning trainer. | required |
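
As a rough illustration of what n_rows implies, the sketch below truncates a list of predictions and reshapes it into column names plus row data, a shape many table-logging APIs accept. The helper `to_table` is hypothetical; whether LoggerPredictionWriter builds its payload this way is an assumption.

```python
from typing import Any, Dict, List, Tuple


def to_table(
    predictions: List[Dict[str, Any]], n_rows: int = 10
) -> Tuple[List[str], List[List[Any]]]:
    """Hypothetical helper: keep the first n_rows and split into columns/data."""
    head = predictions[:n_rows]
    # Collect column names in first-seen order across the kept rows.
    columns: List[str] = []
    for row in head:
        for key in row:
            if key not in columns:
                columns.append(key)
    # Rows lacking a column get None in that position.
    data = [[row.get(col) for col in columns] for row in head]
    return columns, data


preds = [{"text": f"sample {i}", "step": i} for i in range(25)]
columns, data = to_table(preds, n_rows=3)
```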

ConsolePredictionWriter

Bases: _PredictionWriter

Writer that outputs predictions to console.

__init__(fields_to_keep_in_output=None)

Initialize ConsolePredictionWriter.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fields_to_keep_in_output | Optional[List[str]] | List of fields to keep in the output. If None, all fields are kept. | None |

write(predictions, ground_truth_text, step, epoch, split, dataloader_name, pl_module, trainer)

Write predictions to console.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | List[Dict[str, Any]] | List of prediction dictionaries. | required |
| ground_truth_text | List[str] | List of ground truth text strings. | required |
| step | int | Current training step. | required |
| epoch | int | Current epoch. | required |
| split | Literal['train', 'val', 'test', 'predict'] | The split name (train/val/test/predict). | required |
| dataloader_name | str | The dataloader name. | required |
| pl_module | LightningModule | The Lightning module. | required |
| trainer | Optional[Trainer] | The Lightning trainer. | required |
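
The shared write signature can be illustrated with a stand-in class. `PrintingWriter` below is hypothetical: it does not subclass _PredictionWriter, its line format is invented, and it returns the printed lines only so the example can be inspected.

```python
from typing import Any, Dict, List, Optional


class PrintingWriter:
    """Hypothetical stand-in mirroring the documented write() signature."""

    def __init__(self, fields_to_keep_in_output: Optional[List[str]] = None):
        self.fields_to_keep_in_output = fields_to_keep_in_output

    def write(
        self,
        predictions: List[Dict[str, Any]],
        ground_truth_text: List[str],
        step: int,
        epoch: int,
        split: str,
        dataloader_name: str,
        pl_module: Any,
        trainer: Any,
    ) -> List[str]:
        lines = []
        for pred, truth in zip(predictions, ground_truth_text):
            if self.fields_to_keep_in_output is not None:
                # Drop every field that was not explicitly requested.
                pred = {
                    k: v
                    for k, v in pred.items()
                    if k in self.fields_to_keep_in_output
                }
            line = (
                f"[{split}/{dataloader_name} step={step} epoch={epoch}] "
                f"{pred} | truth={truth!r}"
            )
            print(line)
            lines.append(line)
        return lines  # returned for demonstration only


writer = PrintingWriter(fields_to_keep_in_output=["text"])
lines = writer.write(
    predictions=[{"text": "hi", "score": 0.5}],
    ground_truth_text=["hello"],
    step=100,
    epoch=1,
    split="val",
    dataloader_name="default",
    pl_module=None,  # unused in this sketch
    trainer=None,  # unused in this sketch
)
```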

LogPredictions

Main logging class that handles the shared pipeline and delegates to writers.

__init__(writers=None, inject_target=None, additional_fields_from_batch=None, fields_to_keep_in_output=None)

Initialize LogPredictions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| writers | Optional[Union[List[_PredictionWriter], List[Literal['file', 'logger', 'console']]]] | List of prediction writers. If None, creates default writers for backward compatibility. | None |
| inject_target | Optional[str] | Key in batch to use as ground truth. If None, empty strings are used. | None |
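
The inject_target behaviour described above can be sketched as follows. The helper `ground_truth_from_batch` and the batch keys are hypothetical, and the sketch assumes the batch entry already holds text; the real pipeline may need to decode token IDs first.

```python
from typing import Any, Dict, List, Optional


def ground_truth_from_batch(
    batch: Dict[str, Any], inject_target: Optional[str], batch_size: int
) -> List[str]:
    """Hypothetical helper: pick ground truth from the batch, or use blanks."""
    if inject_target is None:
        # No target key configured: one empty string per example.
        return [""] * batch_size
    # Assumption: the batch entry is already a sequence of text-like items.
    return [str(t) for t in batch[inject_target]]


batch = {"input_ids": [[1, 2], [3, 4]], "target_text": ["cat", "dog"]}
with_target = ground_truth_from_batch(batch, "target_text", batch_size=2)
without_target = ground_truth_from_batch(batch, None, batch_size=2)
```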

__call__(pl_module, trainer, batch, preds, split, dataloader_name)

Log predictions using the shared pipeline and delegate to writers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pl_module | LightningModule | The Lightning module. | required |
| trainer | Optional[Trainer] | The Lightning trainer. | required |
| batch | Dict[str, Any] | The input batch. | required |
| preds | Dict[str, Any] | The predictions. | required |
| split | Literal['train', 'val', 'test', 'predict'] | The split name (train/val/test/predict). | required |
| dataloader_name | str | The dataloader name. | required |

read(step, epoch, split, dataloader_name, pl_module)

Read predictions from the writers.
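
The delegation pattern, one shared pipeline handing the same rows to every writer, might look like the sketch below. `ListWriter` and `fan_out` are hypothetical names; this shows the general pattern under those assumptions, not the xlm implementation.

```python
from typing import Any, Dict, List


class ListWriter:
    """Hypothetical writer that collects rows in memory."""

    def __init__(self) -> None:
        self.rows: List[Dict[str, Any]] = []

    def write(self, predictions, ground_truth_text, step, epoch, split,
              dataloader_name, pl_module, trainer) -> None:
        for pred, truth in zip(predictions, ground_truth_text):
            self.rows.append({**pred, "truth": truth, "split": split})


def fan_out(writers, predictions, ground_truth_text, **context) -> None:
    """Hand the same predictions and context to every writer in turn."""
    for writer in writers:
        writer.write(predictions, ground_truth_text, **context)


first, second = ListWriter(), ListWriter()
fan_out(
    [first, second],
    predictions=[{"text": "hi"}],
    ground_truth_text=["hello"],
    step=1,
    epoch=0,
    split="val",
    dataloader_name="default",
    pl_module=None,  # unused in this sketch
    trainer=None,  # unused in this sketch
)
```

Each writer receives identical inputs, so adding a new output target only requires another object with a compatible write method.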