
ilm.nn

remove_tokens(token_ids, ids_to_remove, pad_token_id)

Remove all ids_to_remove (e.g. mask tokens) from token_ids and shift the non-mask tokens to fill the gap. The resulting tensor has the same shape as token_ids. The extra (empty) slots at the end of each row are filled with pad_token_id.

Parameters:

Name Type Description Default
token_ids Tensor

Tensor of shape (batch, seq_len) containing token ids.

required
ids_to_remove int

The id of the mask token that should be removed or a tensor of shape (n,) containing the ids to remove.

required
pad_token_id int

The id to use for padding the empty positions.

required
hold_mask bool

Positions where this is True are treated as tokens (kept in place) even if their id appears in ids_to_remove.

required

Returns:

Type Description
Integer[Tensor, 'batch seq_len']

torch.Tensor: A tensor of the same shape as token_ids with ids_to_remove removed.
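The shift-and-pad behavior can be sketched in plain PyTorch. This is a hypothetical re-implementation for illustration (`remove_tokens_sketch` is a name introduced here, not the library function), shown for the single-id case:

```python
import torch

def remove_tokens_sketch(token_ids, id_to_remove, pad_token_id):
    # For each row: drop the removed id, left-shift the surviving tokens,
    # and fill the freed slots at the end with pad_token_id.
    out = torch.full_like(token_ids, pad_token_id)
    for row_in, row_out in zip(token_ids, out):
        kept = row_in[row_in != id_to_remove]
        row_out[: kept.numel()] = kept
    return out

ids = torch.tensor([[5, 0, 7, 0],
                    [0, 3, 0, 9]])
print(remove_tokens_sketch(ids, id_to_remove=0, pad_token_id=1))
# rows become [5, 7, 1, 1] and [3, 9, 1, 1]
```

The output keeps the input shape, which is what lets callers reuse downstream buffers without re-allocating per batch.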

masked_ce_last_two_dims(logits, target, mask, min_value, inplace=False)

Computes cross entropy from logits and targets.

The mask entries are ignored by (effectively) setting the corresponding logits to -inf. Ideally, PyTorch would handle -inf logits, which represent a predicted probability of 0, but it currently does not: https://github.com/pytorch/pytorch/issues/49844

Note: If inplace is True, the logits will not be usable after this call.

Parameters:

Name Type Description Default
logits Float[Tensor, 'batch seq vocab']

Unnormalized logits of shape (*batch, seq, vocab).

required
target Integer[Tensor, 'batch seq']

Target indices of shape (*batch, seq).

required
mask Bool[Tensor, 'batch seq vocab']

Mask of shape (*batch, seq, vocab).

required
min_value float

The finite value used in place of -inf for masked logits.

required
inplace bool

If True, the logits will be modified in place.

False
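The role of min_value can be illustrated directly: replacing masked logits with a large finite negative number, rather than -inf, gives the masked categories zero probability while keeping gradients finite (the workaround for the PyTorch issue linked above). A minimal illustration, independent of the library's actual implementation:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
mask = torch.tensor([[False, True, False]])   # suppress category 1
min_value = torch.finfo(logits.dtype).min     # finite stand-in for -inf

masked = logits.masked_fill(mask, min_value)
loss = F.cross_entropy(masked, torch.tensor([0]))
loss.backward()

print(F.softmax(masked, dim=-1))  # masked category gets probability 0
print(logits.grad)                # gradients stay finite
```

With a literal -inf in the logits, the same backward pass can produce NaN gradients, which is exactly what the min_value parameter avoids.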

topk_over_last_two_dims(tensor, k)

Compute top-k values and their indices over dimensions 1 and 2 of a 3D tensor.

Parameters:

Name Type Description Default
tensor Tensor

Input tensor of shape (batch_size, dim1, dim2).

required
k int

Number of top elements to retrieve.

required

Returns:

Type Description
Float[Tensor, '*batch k']

torch.Tensor: Top-k values, shape (batch_size, k).

Float[Tensor, '*batch k 2']

torch.Tensor: Unraveled indices of top-k values, shape (batch_size, k, 2).
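Flattening the last two dimensions, taking top-k there, and unraveling the flat indices reproduces the described behavior. A hypothetical sketch (`topk_over_last_two_dims_sketch` is not the library's code):

```python
import torch

def topk_over_last_two_dims_sketch(tensor, k):
    # Flatten dims 1 and 2 into one axis, take top-k over it,
    # then unravel the flat indices back into (dim1, dim2) pairs.
    b, d1, d2 = tensor.shape
    values, flat_idx = tensor.reshape(b, d1 * d2).topk(k, dim=-1)
    indices = torch.stack((flat_idx // d2, flat_idx % d2), dim=-1)  # (b, k, 2)
    return values, indices

x = torch.arange(12.0).reshape(1, 3, 4)
vals, idx = topk_over_last_two_dims_sketch(x, k=2)
print(vals)  # tensor([[11., 10.]])
print(idx)   # positions (2, 3) and (2, 2)
```

The `(flat // d2, flat % d2)` pair is the standard unraveling of a flat index into row and column of the original 2D slice.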

max_over_last_two_dims(x)

Compute the maximum values and their indices over dimensions 1 and 2 of a 3D tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, dim1, dim2).

required

Returns:

Type Description
Float[Tensor, '*batch']

torch.Tensor: Maximum values, shape (batch_size,).

Tuple[Integer[Tensor, '*batch'], Integer[Tensor, '*batch']]

torch.Tensor: Unraveled indices of maximum values, shape (batch_size, 2).
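The max case is the k=1 specialization of the same flatten-and-unravel pattern; a hypothetical sketch:

```python
import torch

def max_over_last_two_dims_sketch(x):
    # Reduce jointly over dims 1 and 2 by flattening, then unravel the argmax.
    b, d1, d2 = x.shape
    values, flat_idx = x.reshape(b, d1 * d2).max(dim=-1)
    return values, (flat_idx // d2, flat_idx % d2)

x = torch.arange(12.0).reshape(1, 3, 4)
values, (rows, cols) = max_over_last_two_dims_sketch(x)
print(values, rows, cols)  # max 11. at position (2, 3)
```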

sample_over_last_two_dims(logits, sampling_function)

Sample values and their indices over dimensions 1 and 2 of a 3D tensor.

Parameters:

Name Type Description Default
logits Tensor

Input tensor of shape (batch_size, dim1, dim2). It can represent probabilities or unnormalized logits.

required
sampling_function Callable

Function mapping the flattened tensor of shape (batch_size, dim1 * dim2) to one sampled flat index per batch row.

required

Returns:

Type Description
Tuple[Integer[Tensor, ' *batch'], Integer[Tensor, ' *batch']]

Tuple of sampled values, shape (batch_size,), and unraveled indices of the sampled values, shape (2, batch_size).
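One plausible reading of this signature, sketched in plain PyTorch (an assumption about the internals, not the library's code; a deterministic argmax stands in for a stochastic sampler so the demo is reproducible):

```python
import torch

def sample_over_last_two_dims_sketch(logits, sampling_function):
    # Flatten dims 1 and 2, let sampling_function pick one flat index per
    # batch row, then return the sampled values and the unraveled indices.
    b, d1, d2 = logits.shape
    flat = logits.reshape(b, d1 * d2)
    flat_idx = sampling_function(flat)                            # (b,)
    values = flat.gather(-1, flat_idx.unsqueeze(-1)).squeeze(-1)  # (b,)
    return values, torch.stack((flat_idx // d2, flat_idx % d2))   # (2, b)

x = torch.arange(12.0).reshape(1, 3, 4)
values, indices = sample_over_last_two_dims_sketch(x, lambda l: l.argmax(dim=-1))
print(values, indices)  # value 11. at position (2, 3)
```

A real sampler here could be, e.g., `lambda l: torch.multinomial(l.softmax(-1), 1).squeeze(-1)`.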

general_sample_over_last_two_dims(logits, sampling_function, second_sampling_function)

Parameters:

Name Type Description Default
logits Float[Tensor, '*batch seq vocab']

Joint logits of shape (*batch, seq, vocab).

required
sampling_function Callable[[Float[Tensor, '*batch cat']], Integer[Tensor, '*batch']]

If second_sampling_function is None, this will be used to jointly sample the sequence and vocabulary dimensions. If second_sampling_function is not None, this will be used to sample from the vocab dimension.

required
second_sampling_function Optional[Callable[[Float[Tensor, '*batch cat']], Integer[Tensor, '*batch']]]

If given, it will be used for the sequence dimension.

required

Returns: sequence_indices, vocabulary_indices
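A hypothetical sketch of the two modes. The combination rule for the two-function case, choosing the sequence position from the per-position log-marginals (`logsumexp` over the vocab axis) before sampling a token at that position, is an assumption introduced here, not confirmed by the source:

```python
import torch

def general_sample_sketch(logits, sampling_function, second_sampling_function=None):
    # One function: sample jointly over the flattened (seq, vocab) grid.
    # Two functions: pick a sequence position first (here from the per-position
    # log-marginals -- an assumed combination rule), then sample a token there.
    b, s, v = logits.shape
    if second_sampling_function is None:
        flat_idx = sampling_function(logits.reshape(b, s * v))
        return flat_idx // v, flat_idx % v
    seq_idx = second_sampling_function(logits.logsumexp(dim=-1))     # (b,)
    vocab_idx = sampling_function(logits[torch.arange(b), seq_idx])  # (b,)
    return seq_idx, vocab_idx

greedy = lambda l: l.argmax(dim=-1)
x = torch.arange(12.0).reshape(1, 3, 4)
print(general_sample_sketch(x, greedy))          # joint: position (2, 3)
print(general_sample_sketch(x, greedy, greedy))  # two-stage: same here
```

Splitting the sampler in two lets callers use, say, greedy selection over positions but temperature sampling over the vocabulary.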