ilm.nn
remove_tokens(token_ids, ids_to_remove, pad_token_id)
Remove all occurrences of ids_to_remove (e.g. mask tokens) from token_ids and shift the remaining tokens left to fill the gaps. The resulting tensor has the same shape as token_ids; the freed slots at the end of each row are filled with pad_token_id.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| token_ids | Tensor | Tensor of shape (batch, seq_len) containing token ids. | required |
| ids_to_remove | int | The id of the token to remove (e.g. the mask token), or a tensor of shape (n,) containing the ids to remove. | required |
| pad_token_id | int | The id used to fill the empty positions at the end of each row. | required |
| hold_mask | bool | Positions where this is True are kept as tokens even if their id is in ids_to_remove. | required |
Returns:
| Type | Description |
|---|---|
| Integer[Tensor, 'batch seq_len'] | torch.Tensor: A tensor of the same shape as token_ids with ids_to_remove removed. |
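A minimal sketch of the behaviour described above, written as a hypothetical standalone function `remove_tokens_sketch` (not the library's implementation). It assumes `hold_mask` is an optional boolean tensor shaped like `token_ids`, and it uses a stable sort so the surviving tokens keep their original relative order.

```python
import torch


def remove_tokens_sketch(token_ids, ids_to_remove, pad_token_id, hold_mask=None):
    """Hypothetical re-implementation of remove_tokens, for illustration only."""
    # Accept either a single id or a 1-D tensor of ids.
    ids = torch.as_tensor(ids_to_remove, device=token_ids.device).reshape(-1)
    remove = torch.isin(token_ids, ids)  # (batch, seq_len) bool
    if hold_mask is not None:
        # Assumption: hold_mask is a bool tensor with the same shape as token_ids.
        remove = remove & ~hold_mask
    # Stable sort on 0 (keep) / 1 (remove) moves kept tokens to the front
    # while preserving their original order.
    order = torch.sort(remove.long(), dim=-1, stable=True).indices
    shifted = torch.gather(token_ids, -1, order)
    # Every slot at or beyond the number of kept tokens becomes padding.
    n_keep = (~remove).sum(dim=-1, keepdim=True)  # (batch, 1)
    positions = torch.arange(token_ids.shape[-1], device=token_ids.device)
    return shifted.masked_fill(positions >= n_keep, pad_token_id)


tokens = torch.tensor([[5, 0, 6, 0, 7]])
print(remove_tokens_sketch(tokens, 0, pad_token_id=-1).tolist())  # [[5, 6, 7, -1, -1]]
```

The stability of the sort is what matters here: an unstable sort could permute the kept tokens.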
masked_ce_last_two_dims(logits, target, mask, min_value, inplace=False)
Computes cross entropy using logits and target probabilities.
Masked entries are ignored by setting the corresponding logits to min_value (effectively -infinity).
Ideally, PyTorch would handle -inf logits, which represent zero predicted probability,
but it currently does not: https://github.com/pytorch/pytorch/issues/49844
Note: If inplace is True, the logits are modified in place and should not be used after this call.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logits | Float[Tensor, 'batch seq vocab'] | Unnormalized logits of shape (*batch, seq, vocab). | required |
| target | Integer[Tensor, 'batch seq'] | Target probabilities of shape (*batch, seq). | required |
| mask | Bool[Tensor, 'batch seq vocab'] | Mask of shape (*batch, seq). | required |
| min_value | float | The minimum value to use for the logits. | required |
| inplace | bool | If True, the logits will be modified in place. | False |
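The reason for a finite min_value can be seen in a small sketch. With probability targets, a masked entry contributes `target * log_prob`; if its logit is -inf, that term is `0 * (-inf) = nan` and poisons the whole sum. The function `masked_ce_sketch` below is hypothetical and assumes that `target` holds probabilities with the same shape as `logits`, that `True` in `mask` marks entries to ignore, and that the softmax is taken jointly over the last two dimensions (as the function name suggests).

```python
import torch


def masked_ce_sketch(logits, target, mask, min_value, inplace=False):
    """Hypothetical sketch of masked cross entropy over the last two dims.

    Assumptions: target holds probabilities with the same shape as logits,
    mask is True where an entry should be ignored, and the softmax is taken
    jointly over the flattened (seq, vocab) dimensions.
    """
    if not inplace:
        logits = logits.clone()
    # A large finite negative value instead of -inf keeps 0 * log_prob finite.
    logits.masked_fill_(mask, min_value)
    log_probs = torch.log_softmax(logits.flatten(-2), dim=-1)
    return -(target.flatten(-2) * log_probs).sum(dim=-1)  # shape (*batch,)


logits = torch.zeros(1, 2, 2)
target = torch.tensor([[[1.0, 0.0], [0.0, 0.0]]])
mask = torch.tensor([[[False, True], [False, True]]])
print(masked_ce_sketch(logits, target, mask, min_value=-1e9))  # approximately log(2) = 0.6931
print(masked_ce_sketch(logits, target, mask, min_value=float("-inf")))  # nan: 0 * (-inf) at masked entries
```

With `inplace=True` the sketch writes min_value into the caller's tensor, which is why the logits should not be reused afterwards.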
topk_over_last_two_dims(tensor, k)
Compute top-k values and their indices over dimensions 1 and 2 of a 3D tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Input tensor of shape (batch_size, dim1, dim2). | required |
| k | int | Number of top elements to retrieve. | required |
Returns:
| Type | Description |
|---|---|
| Float[Tensor, '*batch k'] | torch.Tensor: Top-k values, shape (batch_size, k). |
| Float[Tensor, '*batch k 2'] | torch.Tensor: Unraveled indices of top-k values, shape (batch_size, k, 2). |
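One way to get this behaviour is to flatten the last two dimensions, run `torch.topk`, and convert each flat index back to a (dim1, dim2) pair. The sketch below (`topk_over_last_two_dims_sketch`, a hypothetical name) illustrates that approach; it returns the indices as an integer tensor rather than following the Float annotation above.

```python
import torch


def topk_over_last_two_dims_sketch(tensor, k):
    """Hypothetical sketch: top-k over the flattened (dim1, dim2) plane."""
    batch, dim1, dim2 = tensor.shape
    values, flat_idx = torch.topk(tensor.reshape(batch, dim1 * dim2), k, dim=-1)
    # Convert flat indices back to (dim1, dim2) coordinates.
    rows = torch.div(flat_idx, dim2, rounding_mode="floor")
    cols = flat_idx % dim2
    return values, torch.stack([rows, cols], dim=-1)  # (batch, k), (batch, k, 2)


x = torch.tensor([[[1.0, 5.0, 2.0], [7.0, 0.0, 3.0]]])
values, indices = topk_over_last_two_dims_sketch(x, k=2)
print(values.tolist(), indices.tolist())  # [[7.0, 5.0]] [[[1, 0], [0, 1]]]
```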
max_over_last_two_dims(x)
Compute the maximum values and their indices over dimensions 1 and 2 of a 3D tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (batch_size, dim1, dim2). | required |
Returns:
| Type | Description |
|---|---|
| Float[Tensor, '*batch'] | torch.Tensor: Maximum values, shape (batch_size,). |
| Tuple[Integer[Tensor, '*batch'], Integer[Tensor, '*batch']] | torch.Tensor: Unraveled indices of maximum values, shape (batch_size, 2). |
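A short sketch of the same flatten-and-unravel idea for the maximum, under a hypothetical name. It follows the return annotation and yields the two index tensors (dim1 index, dim2 index) as a tuple, each of shape (batch_size,).

```python
import torch


def max_over_last_two_dims_sketch(x):
    """Hypothetical sketch: argmax over the flattened (dim1, dim2) plane."""
    batch, dim1, dim2 = x.shape
    values, flat_idx = x.reshape(batch, dim1 * dim2).max(dim=-1)
    # Unravel each flat index into its (dim1, dim2) coordinates.
    rows = torch.div(flat_idx, dim2, rounding_mode="floor")
    cols = flat_idx % dim2
    return values, (rows, cols)  # (batch,), ((batch,), (batch,))


x = torch.tensor([[[1.0, 5.0, 2.0], [7.0, 0.0, 3.0]],
                  [[9.0, 1.0, 1.0], [0.0, 0.0, 0.0]]])
values, (rows, cols) = max_over_last_two_dims_sketch(x)
print(values.tolist(), rows.tolist(), cols.tolist())  # [7.0, 9.0] [1, 0] [0, 0]
```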
sample_over_last_two_dims(logits, sampling_function)
Sample values and their indices over dimensions 1 and 2 of a 3D tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logits | Tensor | Input tensor of shape (batch_size, dim1, dim2). It can represent probabilities or unnormalized logits. | required |
Returns:
| Type | Description |
|---|---|
| Tuple[Integer[Tensor, ' *batch'], Integer[Tensor, ' *batch']] | Tuple[Integer[TT, "batch"], Integer[TT, "batch"]]: sampled values, shape (batch_size,); unraveled indices of the sampled values, shape (2, batch_size). |
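A minimal sketch of sampling over the last two dimensions, assuming `sampling_function` maps (batch, categories) logits to (batch,) indices (the signature documented for `general_sample_over_last_two_dims`) and that the returned pair is (dim1 index, dim2 index). Both `sample_over_last_two_dims_sketch` and `categorical_sample` are hypothetical names.

```python
import torch


def sample_over_last_two_dims_sketch(logits, sampling_function):
    """Hypothetical sketch: sample one (dim1, dim2) cell per batch row."""
    batch, dim1, dim2 = logits.shape
    # sampling_function maps (batch, categories) logits to (batch,) indices.
    flat_idx = sampling_function(logits.reshape(batch, dim1 * dim2))
    rows = torch.div(flat_idx, dim2, rounding_mode="floor")
    cols = flat_idx % dim2
    return rows, cols


def categorical_sample(logits):
    # Example sampling_function: draw from softmax(logits) along the last dim.
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)


logits = torch.full((1, 2, 3), float("-inf"))
logits[0, 1, 2] = 0.0  # only one cell has non-zero probability
rows, cols = sample_over_last_two_dims_sketch(logits, categorical_sample)
print(rows.tolist(), cols.tolist())  # [1] [2]
```

Because the sampler only sees a flat (batch, dim1 * dim2) tensor, any categorical sampler (greedy, top-k, temperature) can be plugged in unchanged.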
general_sample_over_last_two_dims(logits, sampling_function, second_sampling_function)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| logits | Float[Tensor, '*batch seq vocab'] | Joint logits of shape (*batch, seq, vocab). | required |
| sampling_function | Callable[[Float[Tensor, '*batch cat']], Integer[Tensor, '*batch']] | If second_sampling_function is None, this is used to jointly sample the sequence and vocabulary dimensions. Otherwise, it is used to sample the vocabulary dimension. | required |
| second_sampling_function | Optional[Callable[[Float[Tensor, '*batch cat']], Integer[Tensor, '*batch']]] | If given, it is used to sample the sequence dimension. | required |
Returns: sequence_indices, vocabulary_indices
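The two modes can be sketched as follows. The joint mode samples once over all seq * vocab cells. For the two-stage mode, this sketch assumes a factorization p(seq, vocab) = p(seq) p(vocab | seq): the position is sampled from the marginal logits `logsumexp` over the vocabulary, then the token from the logits at that position. That factorization, and the name `general_sample_sketch`, are assumptions for illustration, not necessarily what the library does.

```python
import torch


def general_sample_sketch(logits, sampling_function, second_sampling_function=None):
    """Hypothetical sketch of joint vs. two-stage sampling over (seq, vocab)."""
    batch, seq, vocab = logits.shape
    if second_sampling_function is None:
        # Joint: one draw over all seq * vocab cells.
        flat_idx = sampling_function(logits.reshape(batch, seq * vocab))
        return torch.div(flat_idx, vocab, rounding_mode="floor"), flat_idx % vocab
    # Two-stage (assumption): position from the marginal logsumexp over vocab,
    # then the token from the logits at the chosen position.
    seq_idx = second_sampling_function(torch.logsumexp(logits, dim=-1))
    vocab_logits = logits[torch.arange(batch), seq_idx]  # (batch, vocab)
    return seq_idx, sampling_function(vocab_logits)


def argmax(t):
    return t.argmax(dim=-1)


logits = torch.tensor([[[0.0, 3.0], [2.8, 2.9]]])
print(general_sample_sketch(logits, argmax))          # joint: position 0, token 1
print(general_sample_sketch(logits, argmax, argmax))  # two-stage: position 1, token 1
```

With greedy functions the two modes can disagree: the single largest cell is (0, 1), but row 1 carries more total probability mass, so the two-stage version picks position 1.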