
arlm.datamodule_arlm

ARLMEmptyDataset

Bases: IterableDataset

__init__(tokenizer, num_examples)

Initialize the ARLM empty dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tokenizer` | `Tokenizer` | The tokenizer to use. | *required* |
| `num_examples` | `int` | Number of empty examples to generate. | *required* |

__iter__()

Generate empty examples for ARLM training.
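The iteration behavior can be sketched as a plain generator (a hypothetical illustration only; the real class wraps a tokenizer and subclasses `IterableDataset`, and the exact contents of an "empty" example are an assumption here):

```python
from typing import Dict, Iterator, List

def iter_empty_examples(num_examples: int, bos_token_id: int) -> Iterator[Dict[str, List[int]]]:
    # Hypothetical sketch: each "empty" example holds only a BOS token,
    # so downstream collation yields unconditional-generation inputs.
    for _ in range(num_examples):
        yield {"input_ids": [bos_token_id]}

examples = list(iter_empty_examples(3, bos_token_id=1))
```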

DefaultARLMCollator

Bases: Collator

Used for pre-training.

__init__(tokenizer, block_size, noise_schedule, truncate='block', add_eos=False)

Initialize the ARLM collator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tokenizer` | `Tokenizer` | The tokenizer to use. | *required* |
| `block_size` | `int` | Maximum sequence length. | *required* |
| `noise_schedule` | `NoiseSchedule` | Noise schedule (unused in ARLM; kept for interface consistency). | *required* |
| `truncate` | `Literal['max', 'block', None]` | Truncation strategy. | `'block'` |
| `add_eos` | `bool` | Whether to add an EOS token at the end of the sequence. | `False` |

__call__(examples)

Collate examples into a batch for ARLM training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `examples` | `List[BaseCollatorInput]` | List of examples with `input_ids`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ARLMBatch` | `ARLMBatch` with `input_ids`, `attention_mask`, and `target_ids`. |
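A minimal sketch of what this collation might look like, assuming standard autoregressive conventions (truncate to `block_size`, right-pad, and derive `target_ids` by shifting inputs one position left); the real collator returns an `ARLMBatch` of tensors rather than plain lists:

```python
from typing import Dict, List

def collate_arlm(
    examples: List[Dict[str, List[int]]],
    pad_token_id: int,
    block_size: int,
) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch of ARLM pre-training collation."""
    input_ids, attention_mask, target_ids = [], [], []
    for ex in examples:
        ids = ex["input_ids"][:block_size]          # truncate to the block
        pad = block_size - len(ids)
        input_ids.append(ids + [pad_token_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        # Next-token targets: position t predicts token t + 1; pad the tail.
        target_ids.append(ids[1:] + [pad_token_id] * (pad + 1))
    return {"input_ids": input_ids, "attention_mask": attention_mask, "target_ids": target_ids}

batch = collate_arlm(
    [{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}],
    pad_token_id=0,
    block_size=4,
)
```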

ARLMSeq2SeqCollator

__init__(tokenizer, noise_schedule, block_size=None, input_block_size=None, add_bos=None, add_eos=False, truncate='block')

Initialize the ARLM sequence-to-sequence collator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tokenizer` | `Tokenizer` | The tokenizer to use. | *required* |
| `noise_schedule` | `NoiseSchedule` | Noise schedule (unused in ARLM; kept for interface consistency). | *required* |
| `block_size` | `Optional[int]` | Maximum sequence length for the target. | `None` |
| `input_block_size` | `Optional[int]` | Maximum sequence length for the input. | `None` |
| `add_bos` | `Optional[str]` | Where to add the BOS token: `"input"` for the prefix, `"output"` for after the prefix, `None` for no BOS. | `None` |
| `add_eos` | `bool` | Whether to add an EOS token at the end of the suffix. | `False` |
| `truncate` | `Literal['max', 'block', None]` | Truncation strategy. | `'block'` |

__call__(examples)

Collate examples into a batch for ARLM sequence-to-sequence training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `examples` | `List[Seq2SeqCollatorInput]` | List of examples with `prompt_ids` and `input_ids`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ARLMSeq2SeqBatch` | `ARLMSeq2SeqBatch` with `input_ids`, `attention_mask`, and `target_ids`. |
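One plausible shape for this collation, sketched under assumptions: prompt and suffix are joined into a single decoder sequence, and `target_ids` masks the prompt and padding positions (using the conventional `-100` ignore label) so only suffix tokens contribute to the loss. The shift/masking details are illustrative, not the library's confirmed behavior:

```python
from typing import Dict, List

IGNORE_INDEX = -100  # assumed loss-masking label, common in LM training setups

def collate_seq2seq(
    examples: List[Dict[str, List[int]]],
    pad_token_id: int,
    block_size: int,
) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch of ARLM seq2seq collation with prompt masking."""
    input_ids, attention_mask, target_ids = [], [], []
    for ex in examples:
        joined = (ex["prompt_ids"] + ex["input_ids"])[:block_size]
        n_prompt = min(len(ex["prompt_ids"]), len(joined))
        pad = block_size - len(joined)
        input_ids.append(joined + [pad_token_id] * pad)
        attention_mask.append([1] * len(joined) + [0] * pad)
        # Mask prompt and padding positions out of the loss.
        target_ids.append([IGNORE_INDEX] * n_prompt + joined[n_prompt:] + [IGNORE_INDEX] * pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "target_ids": target_ids}

batch = collate_seq2seq(
    [{"prompt_ids": [1, 2], "input_ids": [3, 4]}],
    pad_token_id=0,
    block_size=5,
)
```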

ARLMSeq2SeqPredCollator

Bases: ARLMSeq2SeqCollator

Drops all suffix/target tokens from the model input and returns them separately as `target_ids` of shape `(batch_size, target_seq_len)`.

__call__(examples)

Collate examples into a batch for ARLM sequence-to-sequence prediction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `examples` | `List[Seq2SeqCollatorInput]` | List of examples with `prompt_ids` and `input_ids`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ARLMSeq2SeqBatch` | `ARLMSeq2SeqBatch` with `input_ids`, `attention_mask`, and `target_ids`. |
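The prediction-time variant can be sketched as follows (a hypothetical illustration: the model input keeps only the prompt, while suffix tokens move wholesale into `target_ids` so generated text can be scored against them; padding conventions are assumed):

```python
from typing import Dict, List

def collate_seq2seq_pred(
    examples: List[Dict[str, List[int]]],
    pad_token_id: int,
) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch: prompts go into input_ids, suffixes into target_ids."""
    max_prompt = max(len(ex["prompt_ids"]) for ex in examples)
    max_target = max(len(ex["input_ids"]) for ex in examples)
    input_ids, attention_mask, target_ids = [], [], []
    for ex in examples:
        prompt, suffix = ex["prompt_ids"], ex["input_ids"]
        input_ids.append(prompt + [pad_token_id] * (max_prompt - len(prompt)))
        attention_mask.append([1] * len(prompt) + [0] * (max_prompt - len(prompt)))
        # target_ids has shape (batch_size, target_seq_len) after padding.
        target_ids.append(suffix + [pad_token_id] * (max_target - len(suffix)))
    return {"input_ids": input_ids, "attention_mask": attention_mask, "target_ids": target_ids}

batch = collate_seq2seq_pred(
    [
        {"prompt_ids": [1, 2, 3], "input_ids": [7, 8]},
        {"prompt_ids": [4], "input_ids": [9, 10, 11]},
    ],
    pad_token_id=0,
)
```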

prepare_prefix_ids_arlm(prefix_ids, pad_token_id, bos_token_id=None, eos_token_id=None, max_seq_len=None, truncate='block', add_bos=None, add_eos=False)

Prepare prefix ids for ARLM seq2seq tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prefix_ids` | `List[List[int]]` | List of prefix token sequences. | *required* |
| `pad_token_id` | `int` | Padding token ID. | *required* |
| `bos_token_id` | `Optional[int]` | BOS token ID. | `None` |
| `eos_token_id` | `Optional[int]` | EOS token ID. | `None` |
| `max_seq_len` | `Optional[int]` | Maximum sequence length. | `None` |
| `truncate` | `Literal['max', 'block', None]` | Truncation strategy. | `'block'` |
| `add_bos` | `Optional[str]` | Where to add the BOS token: `"input"` for the prefix, `"output"` for after the prefix, `None` for no BOS. | `None` |
| `add_eos` | `bool` | Whether to add an EOS token at the end of the prefix. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, List[List[int]]]` | Dictionary with `input_ids` and `attention_mask` as lists. |
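A simplified sketch of prefix preparation under assumptions: BOS is prepended when `add_bos="input"`, sequences are truncated to `max_seq_len`, and prefixes are left-padded to a common width (a common convention for decoder-side generation, though the library's actual padding side is not confirmed here):

```python
from typing import Dict, List, Optional

def prepare_prefix(
    prefix_ids: List[List[int]],
    pad_token_id: int,
    bos_token_id: Optional[int] = None,
    max_seq_len: Optional[int] = None,
    add_bos: Optional[str] = None,
) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch of prefix preparation for ARLM seq2seq tasks."""
    seqs = []
    for ids in prefix_ids:
        if add_bos == "input" and bos_token_id is not None:
            ids = [bos_token_id] + ids
        if max_seq_len is not None:
            ids = ids[:max_seq_len]  # simplified stand-in for the truncate strategies
        seqs.append(ids)
    width = max(len(ids) for ids in seqs)
    # Left-pad so all prefixes end at the same position before generation.
    return {
        "input_ids": [[pad_token_id] * (width - len(ids)) + ids for ids in seqs],
        "attention_mask": [[0] * (width - len(ids)) + [1] * len(ids) for ids in seqs],
    }

out = prepare_prefix([[5, 6], [7]], pad_token_id=0, bos_token_id=1, add_bos="input")
```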

prepare_suffix_ids_arlm(suffix_ids, pad_token_id, bos_token_id=None, eos_token_id=None, max_seq_len=None, truncate='block', add_bos=None, add_eos=False)

Prepare suffix ids for ARLM seq2seq tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `suffix_ids` | `List[List[int]]` | List of suffix token sequences. | *required* |
| `pad_token_id` | `int` | Padding token ID. | *required* |
| `bos_token_id` | `Optional[int]` | BOS token ID. | `None` |
| `eos_token_id` | `Optional[int]` | EOS token ID. | `None` |
| `max_seq_len` | `Optional[int]` | Maximum sequence length. | `None` |
| `truncate` | `Literal['max', 'block', None]` | Truncation strategy. | `'block'` |
| `add_bos` | `Optional[str]` | Where to add the BOS token: `"input"` for the prefix, `"output"` for after the prefix, `None` for no BOS. | `None` |
| `add_eos` | `bool` | Whether to add an EOS token at the end of the suffix. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, List[List[int]]]` | Dictionary with `input_ids`, `attention_mask`, and `target_ids` as lists. |
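A simplified sketch of suffix preparation under assumptions: EOS is appended when `add_eos=True`, sequences are right-padded to a common width, and `target_ids` is the one-position-left shift of the inputs for next-token prediction (illustrative conventions, not the library's confirmed implementation):

```python
from typing import Dict, List, Optional

def prepare_suffix(
    suffix_ids: List[List[int]],
    pad_token_id: int,
    eos_token_id: Optional[int] = None,
    max_seq_len: Optional[int] = None,
    add_eos: bool = False,
) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch of suffix preparation for ARLM seq2seq tasks."""
    seqs = []
    for ids in suffix_ids:
        if add_eos and eos_token_id is not None:
            ids = ids + [eos_token_id]
        if max_seq_len is not None:
            ids = ids[:max_seq_len]  # simplified stand-in for the truncate strategies
        seqs.append(ids)
    width = max(len(ids) for ids in seqs)
    out: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "target_ids": []}
    for ids in seqs:
        pad = width - len(ids)
        out["input_ids"].append(ids + [pad_token_id] * pad)
        out["attention_mask"].append([1] * len(ids) + [0] * pad)
        # Next-token targets: shift left by one, pad the tail.
        out["target_ids"].append(ids[1:] + [pad_token_id] * (pad + 1))
    return out

out = prepare_suffix([[3, 4], [5]], pad_token_id=0, eos_token_id=2, add_eos=True)
```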

print_batch_arlm(batch, split, tokenizer, dataloader_name='')

Print batch information for debugging ARLM batches.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Dict[str, Any]` | The batch to print. | *required* |
| `split` | `Literal['train', 'val', 'test', 'predict']` | The split name. | *required* |
| `tokenizer` | `Tokenizer` | The tokenizer to decode tokens. | *required* |
| `dataloader_name` | `str` | Name of the dataloader. | `''` |
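The gist of such a debugging helper can be sketched as follows (a hypothetical illustration: it reports per-key shapes and decodes the first example so misaligned inputs and targets are easy to spot; `decode` stands in for the tokenizer's decoding method, and the output format is an assumption):

```python
from typing import Callable, Dict, List

def print_batch(
    batch: Dict[str, List[List[int]]],
    split: str,
    decode: Callable[[List[int]], str],
    dataloader_name: str = "",
) -> List[str]:
    """Hypothetical sketch of an ARLM batch-debugging printer.

    Returns the printed lines so they can be inspected programmatically.
    """
    lines = [f"[{split}] {dataloader_name}".rstrip()]
    for key, rows in batch.items():
        lines.append(f"  {key}: shape ({len(rows)}, {len(rows[0])})")
    lines.append("  first input_ids decoded: " + decode(batch["input_ids"][0]))
    for line in lines:
        print(line)
    return lines

decode = lambda ids: " ".join(str(i) for i in ids)
lines = print_batch({"input_ids": [[5, 6, 0]], "attention_mask": [[1, 1, 0]]}, "train", decode, "main")
```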