arlm.loss_arlm

ARLMLoss

Bases: LossFunction[ARLMBatch, ARLMLossDict]

Loss function for Auto-Regressive Language Modeling (ARLM).

This loss function implements causal language modeling, where the model predicts each token from the tokens that precede it. The loss is the cross-entropy between the model's predictions and the target sequence, which is already shifted in the batch.

For seq2seq tasks, the loss is computed only on suffix (non-prompt) tokens.
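As an illustration of the description above, here is a minimal, framework-free sketch of masked next-token cross-entropy. It is not the library's implementation: the function names, the plain-list inputs, and the `IGNORE_INDEX = -100` sentinel for prompt positions are all assumptions made for this example.

```python
import math

IGNORE_INDEX = -100  # assumed sentinel for positions excluded from the loss


def log_softmax(logits):
    """Numerically stable log-softmax over one list of vocabulary scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_nll(logits, target_ids):
    """Return (mean loss over scored positions, per-position NLLs).

    logits: one list of vocabulary scores per position
    target_ids: already-shifted targets; IGNORE_INDEX marks prompt tokens
    """
    nlls = []
    for scores, target in zip(logits, target_ids):
        if target == IGNORE_INDEX:
            nlls.append(0.0)  # prompt position contributes nothing
        else:
            nlls.append(-log_softmax(scores)[target])
    n_scored = sum(1 for t in target_ids if t != IGNORE_INDEX)
    loss = sum(nlls) / max(n_scored, 1)  # avoid dividing by zero
    return loss, nlls


# Uniform logits over a 4-token vocabulary: each scored token costs log(4).
logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
target_ids = [IGNORE_INDEX, 2, 1]  # first position is a prompt token
loss, nlls = masked_nll(logits, target_ids)
```

Averaging over scored positions only is one possible reduction; the reduction ARLMLoss actually uses is not stated on this page.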

__init__(model=None, tokenizer=None)

Initialize the ARLM loss function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Optional[ARLMModel] | The ARLM model to use for predictions. | None |
| tokenizer | Optional[Tokenizer] | The tokenizer for processing tokens. | None |

configure(pl_module)

Configure the loss function with the lightning module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pl_module | Harness | The lightning module instance. | required |

__call__(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Compute the loss for the given batch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | ARLMBatch | The input batch containing input_ids, attention_mask, and target_ids. | required |
| batch_idx | Optional[int] | The batch index. | None |
| dataloader_idx | Optional[int] | The dataloader index. | None |
| dataloader_name | Optional[str] | The dataloader name. | None |

Returns:

| Type | Description |
| --- | --- |
| ARLMLossDict | Dictionary containing the loss, batch_loss, and nlls. |
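To make the batch fields named above concrete, the following sketch builds one example with input_ids, attention_mask, and target_ids from a full token sequence. The helper `make_example`, its `prompt_len` argument, and the `-100` ignore value are hypothetical; how the library's own data pipeline shifts and masks targets is not shown on this page.

```python
IGNORE_INDEX = -100  # assumed sentinel; the library's actual value is not shown here


def make_example(token_ids, prompt_len, ignore_index=IGNORE_INDEX):
    """Build one shifted example from a full token sequence (hypothetical helper)."""
    if len(token_ids) < 2:
        raise ValueError("need at least two tokens to form an input/target pair")
    input_ids = token_ids[:-1]   # the model sees every token except the last
    shifted = token_ids[1:]      # the target at position i is token i + 1
    # Keep a target only if the token being predicted lies outside the prompt.
    target_ids = [
        tok if i + 1 >= prompt_len else ignore_index
        for i, tok in enumerate(shifted)
    ]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),  # no padding in this example
        "target_ids": target_ids,
    }


# Two prompt tokens (5, 6) followed by three suffix tokens (7, 8, 9).
example = make_example([5, 6, 7, 8, 9], prompt_len=2)
```

With targets prepared this way, the loss function only has to compare predictions against target_ids position by position and skip the ignored entries.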

loss_fn(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Compute the causal language modeling loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | ARLMBatch | The input batch. | required |
| batch_idx | Optional[int] | The batch index. | None |
| dataloader_idx | Optional[int] | The dataloader index. | None |
| dataloader_name | Optional[str] | The dataloader name. | None |

Returns:

| Type | Description |
| --- | --- |
| ARLMLossDict | Dictionary containing the loss, batch_loss, and nlls. |