# arlm.predictor_arlm

## ARLMStepResults

Bases: `TypedDict`

Step results for ARLM prediction.
Attributes:

| Name | Type | Description |
|---|---|---|
| `x` | `Integer[Tensor, ' batch seq_len']` | Current predicted sequence. |
| `attention_mask` | `Bool[Tensor, ' batch seq_len']` | Mask of the current sequence. |
| `logits` | `Optional[Float[Tensor, ' batch seq_len vocab_size']]` | Logits of the current sequence. |
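A minimal sketch of how this TypedDict might be declared and used. The shape-annotated tensor types from the table are replaced with `Any` so the sketch carries no torch/jaxtyping dependency; at runtime a `TypedDict` is a plain `dict`:

```python
from typing import Any, Optional, TypedDict


class ARLMStepResults(TypedDict):
    """Sketch of the step-results container (tensor types elided)."""

    x: Any                 # Integer[Tensor, ' batch seq_len']
    attention_mask: Any    # Bool[Tensor, ' batch seq_len']
    logits: Optional[Any]  # Float[Tensor, ' batch seq_len vocab_size']


# Constructed like any dict; lists stand in for tensors here.
step: ARLMStepResults = {
    "x": [[1, 2, 3]],
    "attention_mask": [[True, True, True]],
    "logits": None,
}
```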
## ARLMPredictor

Bases: `Module`, `Predictor[Dict[str, Any], ARLMPredictionDict]`

### __init__(max_steps, max_length, tokenizer=None, noise_schedule=None, sampling_method='sample', top=1000, p=0.9, model=None)

Constructor for ARLMPredictor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `max_steps` | `int` | Maximum number of prediction steps. | required |
| `max_length` | `int` | Maximum sequence length. | required |
| `tokenizer` | `Optional[Tokenizer]` | The tokenizer to use. | `None` |
| `noise_schedule` | `Optional[NoiseSchedule]` | Noise schedule (not used in ARLM but kept for interface consistency). | `None` |
| `sampling_method` | `Literal['sample', 'sample_top_k', 'sample_top_p']` | Sampling method to use. | `'sample'` |
| `top` | `int` | Top-k parameter for sampling. | `1000` |
| `p` | `float` | Top-p parameter for sampling. | `0.9` |
| `model` | `Optional[ARLMModel]` | The ARLM model to use for predictions. | `None` |
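A sketch of what the three sampling methods typically mean, using a hypothetical `sample_next` helper over plain logit lists (this helper is an illustration, not the predictor's actual implementation): `'sample'` draws from the full softmax, `'sample_top_k'` keeps only the `top` highest logits, and `'sample_top_p'` keeps the smallest nucleus whose probability mass reaches `p`.

```python
import math
import random


def sample_next(logits, method="sample", top=1000, p=0.9, rng=None):
    """Pick a token id from a list of logits (hypothetical helper)."""
    rng = rng or random.Random(0)
    # Rank token ids by logit, highest first.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if method == "sample_top_k":
        ranked = ranked[:top]
    # Softmax over the surviving candidates.
    probs = [math.exp(logits[i]) for i in ranked]
    total = sum(probs)
    probs = [q / total for q in probs]
    if method == "sample_top_p":
        kept, mass = [], 0.0
        for i, q in zip(ranked, probs):
            kept.append((i, q))
            mass += q
            if mass >= p:  # smallest prefix whose mass reaches p
                break
        total = sum(q for _, q in kept)
        ranked = [i for i, _ in kept]
        probs = [q / total for _, q in kept]
    # Draw from the renormalized distribution.
    r, acc = rng.random(), 0.0
    for i, q in zip(ranked, probs):
        acc += q
        if r <= acc:
            return i
    return ranked[-1]
```

With a sharply peaked distribution, both truncated methods collapse onto the top token, e.g. `sample_next([10.0, 0.0, 0.0, 0.0], "sample_top_p", p=0.9)` returns `0`.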
### predict_single_step(step_results, current_step, final_step)

Predict the next token in the sequence.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `step_results` | `ARLMStepResults` | Current step results containing `x`, `attention_mask`, and `logits`. | required |
| `current_step` | `int` | Current prediction step. | required |
| `final_step` | `bool` | Whether this is the final step. | required |

Returns:

| Type | Description |
|---|---|
| `ARLMStepResults` | Updated step results with the next token predicted. |
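A single autoregressive step can be sketched as: append the sampled token to each sequence in the batch and extend the mask accordingly. The `step_once` helper below is hypothetical (plain lists stand in for the batch tensors, and the model call that would refresh `logits` is elided):

```python
def step_once(step_results, next_token_id):
    """Hypothetical single-step update on list-backed step results."""
    return {
        "x": [seq + [next_token_id] for seq in step_results["x"]],
        "attention_mask": [m + [True] for m in step_results["attention_mask"]],
        "logits": step_results["logits"],  # would be refreshed by the model
    }


state = {"x": [[101]], "attention_mask": [[True]], "logits": None}
state = step_once(state, 7)
# state["x"] is now [[101, 7]]
```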
### stop(step_results, current_length)

Check if prediction should stop.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `step_results` | `ARLMStepResults` | Current step results. | required |
| `current_length` | `int` | Current sequence length. | required |

Returns:

| Type | Description |
|---|---|
| `bool` | `True` if prediction should stop, `False` otherwise. |
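A typical stop rule for autoregressive generation, sketched under two assumptions not stated in the source: an end-of-sequence token id and a length budget mirroring `max_length`:

```python
EOS_ID = 2      # assumed end-of-sequence token id
MAX_LENGTH = 8  # stands in for the constructor's max_length


def should_stop(step_results, current_length):
    """Hypothetical stop rule: halt when the length budget is spent
    or every sequence in the batch has emitted EOS."""
    if current_length >= MAX_LENGTH:
        return True
    return all(EOS_ID in seq for seq in step_results["x"])
```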
### decode(results)

Decode the predicted sequence.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `results` | `ARLMStepResults` | Step results containing the predicted sequence. | required |

Returns:

| Type | Description |
|---|---|
| `Tuple[List[str], List[str], Integer[Tensor, ' batch seq_len']]` | Tuple of `(decoded_text, decoded_text_with_special_tokens, token_ids)`. |
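The two text outputs differ only in whether special tokens survive decoding. A sketch with a toy vocabulary in place of the real tokenizer (the tables and the function below are illustrative assumptions):

```python
SPECIAL = {0: "<pad>", 1: "<bos>", 2: "<eos>"}  # assumed special tokens
VOCAB = {3: "hello", 4: "world"}                # toy vocabulary


def decode(results):
    """Hypothetical decode returning the documented triple
    (decoded_text, decoded_text_with_special_tokens, token_ids)."""
    texts, texts_special = [], []
    for seq in results["x"]:
        pieces = [SPECIAL.get(t) or VOCAB.get(t, "<unk>") for t in seq]
        texts_special.append(" ".join(pieces))
        # The plain text drops every special-token position.
        texts.append(" ".join(w for t, w in zip(seq, pieces) if t not in SPECIAL))
    return texts, texts_special, results["x"]


text, text_special, ids = decode({"x": [[1, 3, 4, 2]]})
# text == ["hello world"], text_special == ["<bos> hello world <eos>"]
```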
### predict(batch, batch_idx=None, dataloader_idx=None, dataloader_name=None, max_len=0)

Predict the complete sequence.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch containing `input_ids` and `attention_mask`. | required |
| `batch_idx` | `Optional[int]` | Batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | Dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | Dataloader name. | `None` |
| `max_len` | `int` | Maximum length for prediction. | `0` |

Returns:

| Type | Description |
|---|---|
| `ARLMPredictionDict` | Prediction results containing text, token IDs, and attention mask. |
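End to end, prediction is a loop of the step/stop contract above: extend every sequence from the prompt until EOS or the length budget. A self-contained sketch with a toy stand-in for the model (the counting rule `id + 1` is purely illustrative):

```python
def predict(batch, max_len=6, eos=2):
    """Hypothetical end-to-end loop over list-backed batches."""
    x = [list(seq) for seq in batch["input_ids"]]
    mask = [list(m) for m in batch["attention_mask"]]
    # Stop on length budget or once every sequence contains EOS.
    while len(x[0]) < max_len and not all(eos in s for s in x):
        for s, m in zip(x, mask):
            s.append(min(s[-1] + 1, eos))  # toy "model": count up to EOS
            m.append(True)
    return {"x": x, "attention_mask": mask}


out = predict({"input_ids": [[0]], "attention_mask": [[True]]})
# out["x"] == [[0, 1, 2]]: generation halts as soon as EOS appears
```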
### to_dict(batch, preds, batch_idx=None, dataloader_idx=None, dataloader_name=None)

Convert predictions to dictionary format.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `Dict[str, Any]` | Input batch. | required |
| `preds` | `ARLMPredictionDict` | Prediction results. | required |
| `batch_idx` | `Optional[int]` | Batch index. | `None` |
| `dataloader_idx` | `Optional[int]` | Dataloader index. | `None` |
| `dataloader_name` | `Optional[str]` | Dataloader name. | `None` |

Returns:

| Type | Description |
|---|---|
| `List[Dict[str, Any]]` | List of dictionaries containing prediction results. |
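The batch-to-rows conversion can be sketched as flattening the prediction dict into one record per sample; the field names (`text`, `token_ids`, `sample_idx`) are assumptions for illustration:

```python
def to_dict(batch, preds, batch_idx=None):
    """Hypothetical flattening: one dict per sample, pairing each
    prediction with its row index within the batch."""
    return [
        {"batch_idx": batch_idx, "sample_idx": i, "text": text, "token_ids": ids}
        for i, (text, ids) in enumerate(zip(preds["text"], preds["token_ids"]))
    ]


rows = to_dict({}, {"text": ["a", "b"], "token_ids": [[1], [2]]}, batch_idx=0)
# rows is a 2-element list, one record per sample
```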