arlm.predictor_arlm

ARLMStepResults

Bases: TypedDict

Step results for ARLM prediction.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| x | Integer[Tensor, "batch seq_len"] | Current predicted sequence. |
| attention_mask | Bool[Tensor, "batch seq_len"] | Mask of the current sequence. |
| logits | Optional[Float[Tensor, "batch seq_len vocab_size"]] | Logits of the current sequence. |

ARLMPredictor

Bases: Module, Predictor[Dict[str, Any], ARLMPredictionDict]

__init__(max_steps, max_length, tokenizer=None, noise_schedule=None, sampling_method='sample', top=1000, p=0.9, model=None)

Constructor for ARLMPredictor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| max_steps | int | Maximum number of prediction steps. | required |
| max_length | int | Maximum sequence length. | required |
| tokenizer | Optional[Tokenizer] | The tokenizer to use. | None |
| noise_schedule | Optional[NoiseSchedule] | Noise schedule (not used in ARLM but kept for interface consistency). | None |
| sampling_method | Literal['sample', 'sample_top_k', 'sample_top_p'] | Sampling method to use. | 'sample' |
| top | int | Top-k parameter for sampling. | 1000 |
| p | float | Top-p parameter for sampling. | 0.9 |
| model | Optional[ARLMModel] | The ARLM model to use for predictions. | None |
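The `sampling_method`, `top`, and `p` parameters control how the next token is drawn from the model's logits. As a rough, hypothetical sketch of the three strategies in plain Python (a stand-in for illustration only, not the actual ARLM implementation):

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_next(logits, method="sample", top=1000, p=0.9, rng=None):
    """Draw one token id from `logits` using the named strategy.

    Hypothetical stand-in for ARLMPredictor's sampling step; the real
    class may differ in detail.
    """
    rng = rng or random.Random()
    probs = softmax(logits)
    # Sort token ids by probability, most likely first.
    indexed = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)
    if method == "sample_top_k":
        indexed = indexed[:top]  # keep only the k most probable tokens
    elif method == "sample_top_p":
        kept, cum = [], 0.0
        for idx, pr in indexed:
            kept.append((idx, pr))
            cum += pr
            if cum >= p:  # smallest prefix whose mass reaches p
                break
        indexed = kept
    # Renormalize over the kept tokens and sample from them.
    total = sum(pr for _, pr in indexed)
    r = rng.random() * total
    for idx, pr in indexed:
        r -= pr
        if r <= 0:
            return idx
    return indexed[-1][0]
```

With `method="sample"` the full distribution is used; `top` and `p` only take effect under their respective methods, which matches the constructor exposing all three knobs at once.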

predict_single_step(step_results, current_step)

Predict the next token in the sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| step_results | ARLMStepResults | Current step results containing x, attention_mask, and logits. | required |
| current_step | int | Current prediction step. | required |
| final_step | bool | Whether this is the final step. | required |

Returns:

| Type | Description |
| --- | --- |
| ARLMStepResults | Updated step results with the next token predicted. |

stop(step_results, current_length)

Check if prediction should stop.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| step_results | ARLMStepResults | Current step results. | required |
| current_length | int | Current sequence length. | required |

Returns:

| Type | Description |
| --- | --- |
| bool | True if prediction should stop, False otherwise. |
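A typical stopping rule for autoregressive decoding combines a length cap with an end-of-sequence check. A hypothetical sketch (the exact criterion ARLMPredictor uses, and its EOS token id, are assumptions here):

```python
def should_stop(last_token, current_length, max_length, eos_token_id=2):
    """Hypothetical stop check: halt at the length cap or on EOS.

    `eos_token_id=2` is an assumed placeholder, not ARLM's real value.
    """
    if current_length >= max_length:
        return True
    return last_token == eos_token_id
```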

decode(results)

Decode the predicted sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| results | ARLMStepResults | Step results containing the predicted sequence. | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[List[str], List[str], Integer[Tensor, "batch seq_len"]] | Tuple of (decoded_text, decoded_text_with_special_tokens, token_ids). |
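The two text outputs differ only in whether special tokens are kept. A minimal sketch of that distinction with a toy vocabulary (the vocabulary and token names are stand-ins, not the real tokenizer's):

```python
def decode_ids(token_ids, id_to_token, special_tokens=("<bos>", "<eos>", "<pad>")):
    """Hypothetical decode: return (text, text_with_special_tokens).

    Mirrors the shape of the first two values decode() returns; the
    vocabulary is an illustrative stand-in.
    """
    toks = [id_to_token[i] for i in token_ids]
    with_special = " ".join(toks)
    plain = " ".join(t for t in toks if t not in special_tokens)
    return plain, with_special
```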

predict(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None, max_len=0)

Predict the complete sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Input batch containing input_ids and attention_mask. | required |
| batch_idx | Optional[int] | Batch index. | None |
| dataloader_idx | Optional[int] | Dataloader index. | None |
| dataloader_name | Optional[str] | Dataloader name. | None |
| max_len | int | Maximum length for prediction. | 0 |

Returns:

| Type | Description |
| --- | --- |
| ARLMPredictionDict | Prediction results containing text, token IDs, and attention mask. |
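`predict` drives the step/stop/decode cycle: it repeatedly extends the sequence via `predict_single_step` until `stop` fires or `max_steps` is exhausted, then decodes. A generic pure-Python sketch of that control flow, with toy stand-ins for the step and stop functions:

```python
def predict_loop(step_fn, stop_fn, initial, max_steps):
    """Generic autoregressive driver: extend `initial` one token at a
    time via `step_fn` until `stop_fn` fires or `max_steps` is hit.

    An illustrative sketch of the loop shape, not ARLM's actual code.
    """
    seq = list(initial)
    for step in range(max_steps):
        seq.append(step_fn(seq, step))
        if stop_fn(seq, len(seq)):
            break
    return seq

# Toy stand-in "model": always emit previous token + 1, stop at token 5.
out = predict_loop(
    step_fn=lambda seq, step: seq[-1] + 1,
    stop_fn=lambda seq, length: seq[-1] >= 5,
    initial=[0],
    max_steps=100,
)
```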

to_dict(batch, preds, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Convert predictions to dictionary format.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch | Dict[str, Any] | Input batch. | required |
| preds | ARLMPredictionDict | Prediction results. | required |
| batch_idx | Optional[int] | Batch index. | None |
| dataloader_idx | Optional[int] | Dataloader index. | None |
| dataloader_name | Optional[str] | Dataloader name. | None |

Returns:

| Type | Description |
| --- | --- |
| List[Dict[str, Any]] | List of dictionaries containing prediction results. |
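The return shape suggests an unbatching step: a dict of batched fields becomes one dict per example. A hypothetical sketch of that conversion (the actual keys `to_dict` emits are not shown in this page):

```python
def preds_to_dicts(preds):
    """Hypothetical unbatching: turn a dict of parallel per-example
    lists into a list of one dict per example."""
    keys = list(preds)
    n = len(preds[keys[0]])
    return [{k: preds[k][i] for k in keys} for i in range(n)]
```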