arlm.loss_arlm
ARLMLoss
Bases: LossFunction[ARLMBatch, ARLMLossDict]
Loss function for Auto-Regressive Language Modeling (ARLM).
This loss function implements causal language modeling: the model predicts each token given the preceding tokens, and the loss is computed as cross-entropy against the target sequence (which is already shifted in the batch).
For seq2seq tasks, the loss is computed only on suffix tokens (i.e. non-prompt tokens).
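The shifted-target cross-entropy with suffix-only masking can be sketched in plain Python. This is an illustrative sketch, not the library's implementation (which operates on tensors); the function name, argument layout, and the explicit prompt mask are assumptions for exposition:

```python
import math

def causal_lm_loss(logits, target_ids, prompt_mask):
    """Mean per-token NLL for next-token prediction, skipping prompt tokens.

    logits: one list of vocabulary scores per position.
    target_ids: the already-shifted targets, one token id per position.
    prompt_mask: True where the position belongs to the prompt
                 (excluded from the loss), False for suffix tokens.
    """
    nlls = []
    for scores, target, is_prompt in zip(logits, target_ids, prompt_mask):
        if is_prompt:
            continue  # seq2seq: loss only on suffix (non-prompt) tokens
        # log-softmax over the vocabulary, then negative log-prob of the target
        log_z = math.log(sum(math.exp(s) for s in scores))
        nlls.append(log_z - scores[target])
    return sum(nlls) / len(nlls)  # mean cross-entropy over suffix tokens
```

Because the targets are pre-shifted, position `i` of `target_ids` is the token the model should predict from the inputs up to and including position `i`.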
__init__(model=None, tokenizer=None)
Initialize the ARLM loss function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Optional[ARLMModel]` | The ARLM model to use for predictions. | `None` |
| `tokenizer` | `Optional[Tokenizer]` | The tokenizer for processing tokens. | `None` |
configure(pl_module)
Configure the loss function with the Lightning module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pl_module` | `Harness` | The Lightning module instance. | *required* |
__call__(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)
Compute the loss for the given batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `ARLMBatch` | The input batch containing `input_ids`, `attention_mask`, and `target_ids`. | *required* |
| `batch_idx` | `Optional[int]` | The batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | The dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | The dataloader name. | `None` |
Returns:
| Type | Description |
|---|---|
| `ARLMLossDict` | Dictionary containing the `loss`, `batch_loss`, and `nlls`. |
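A plausible shape for the returned dictionary can be sketched as follows. The key names (`loss`, `batch_loss`, `nlls`) come from the documentation above, but the aggregation details (per-sequence means, then a batch mean) are assumptions here, not the library's confirmed behavior:

```python
def aggregate_losses(per_token_nlls):
    """Assemble an ARLMLossDict-style result from per-sequence token NLLs.

    per_token_nlls: one list of token-level negative log-likelihoods
    per sequence in the batch (suffix tokens only).
    """
    # Per-sequence loss: mean NLL over that sequence's scored tokens.
    batch_loss = [sum(seq) / len(seq) for seq in per_token_nlls]
    # Scalar training loss: mean over sequences in the batch.
    loss = sum(batch_loss) / len(batch_loss)
    return {"loss": loss, "batch_loss": batch_loss, "nlls": per_token_nlls}
```

Keeping the raw `nlls` alongside the scalar `loss` lets downstream code compute perplexity or per-example diagnostics without re-running the model.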
loss_fn(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None)
Compute the causal language modeling loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `ARLMBatch` | The input batch. | *required* |
| `batch_idx` | `Optional[int]` | The batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | The dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | The dataloader name. | `None` |
Returns:
| Type | Description |
|---|---|
| `ARLMLossDict` | Dictionary containing the `loss`, `batch_loss`, and `nlls`. |