xlm.backbones.dream.modeling_dream

PyTorch Dream model.

RMSNormModulations

Bases: Module

Produces the modulation parameters for RMSNorm.

__init__(cond_dim, dim, num_modulation_parameters=4)

Initializes the RMSNormModulations module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cond_dim` | `int` | The dimension of the conditioning input. | *required* |
| `dim` | `int` | The hidden size. | *required* |
| `num_modulation_parameters` | `int` | The number of modulation parameter tensors to produce. | `4` |

forward(c)

Forward pass of the RMSNormModulations module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `c` | `Tensor` | The conditioning input tensor. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `List[Tensor]` | The modulation parameters for RMSNorm. Each tensor has shape (bsz, 1, dim). When num_modulation_parameters=4, these are the shift and scale parameters for the MHA and MLP sub-layers: scale_msa, shift_msa, scale_mlp, shift_mlp. |
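
The exact layout of this module is not spelled out beyond the docstring, but the description matches the common adaLN-style pattern: a single projection of the conditioning vector into num_modulation_parameters chunks of size dim. A minimal sketch under that assumption (class name, SiLU activation, and the single-Linear layout are illustrative, not the verbatim Dream code):

```python
import torch
import torch.nn as nn
from typing import List

class RMSNormModulationsSketch(nn.Module):
    """Hypothetical adaLN-style modulation producer; layout is an assumption."""

    def __init__(self, cond_dim: int, dim: int, num_modulation_parameters: int = 4):
        super().__init__()
        self.num_modulation_parameters = num_modulation_parameters
        # One projection emits all modulation parameters at once.
        self.proj = nn.Linear(cond_dim, num_modulation_parameters * dim)

    def forward(self, c: torch.Tensor) -> List[torch.Tensor]:
        # c: (bsz, cond_dim) -> (bsz, num_modulation_parameters * dim)
        params = self.proj(nn.functional.silu(c))
        # Split into tensors of shape (bsz, 1, dim), e.g.
        # scale_msa, shift_msa, scale_mlp, shift_mlp when the count is 4.
        return [p.unsqueeze(1) for p in params.chunk(self.num_modulation_parameters, dim=-1)]
```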

rms_norm_modulate(x, shift, scale) staticmethod

Applies RMSNorm modulation to the input tensor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The input tensor of shape (bsz, seq_len, dim). | *required* |
| `shift` | `Tensor` | The shift parameter tensor of shape (bsz, 1, dim). | *required* |
| `scale` | `Tensor` | The scale parameter tensor of shape (bsz, 1, dim). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The modulated output tensor of shape (bsz, seq_len, dim). |
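
What the modulation computes is not fully pinned down by the table above; a minimal sketch assuming the common formulation (RMS-normalize over the hidden dimension, then apply the broadcast scale and shift) is shown below. The (1 + scale) convention and the absence of a learned weight are assumptions:

```python
import torch

def rms_norm_modulate_sketch(x: torch.Tensor, shift: torch.Tensor,
                             scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMS-normalize: x / sqrt(mean(x^2) + eps), computed over the hidden dimension.
    x_norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Modulate: (bsz, 1, dim) parameters broadcast over seq_len.
    return x_norm * (1 + scale) + shift
```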

DreamRMSNorm

Bases: Module

__init__(hidden_size, eps=1e-06)

DreamRMSNorm is equivalent to T5LayerNorm
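
For reference, the T5LayerNorm / RMSNorm pattern referred to here is typically implemented as follows (a sketch of the standard recipe, not necessarily the exact Dream code):

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """RMSNorm in the T5LayerNorm style: no mean subtraction, no bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Compute the mean of squares in float32 for numerical stability.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```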

DreamRotaryEmbedding

Bases: Module

DreamAttention

Bases: Module

Multi-headed attention from the 'Attention Is All You Need' paper, modified to use sliding-window attention as described in Longformer and "Generating Long Sequences with Sparse Transformers".
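
To illustrate what sliding-window attention restricts, here is a conceptual mask builder; the window size, causality, and how Dream actually constructs or applies its mask are assumptions made for the example only:

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask where True marks key positions a query may attend to:
    only the last `window` positions (inclusive of itself), never the future."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    return (j <= i) & (j > i - window)

# Example: with seq_len=6 and window=3, query position 5 may attend to keys 3, 4, 5.
print(sliding_window_causal_mask(6, 3).int())
```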

DreamSdpaAttention

Bases: DreamAttention

Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from DreamAttention, as the weights of the module stay untouched. The only changes are on the forward pass, to adapt to the SDPA API.
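
The heart of such a forward pass is the PyTorch SDPA call itself. A rough illustration (tensor shapes and the mask handling are illustrative, not copied from the Dream implementation):

```python
import torch
import torch.nn.functional as F

# query, key, value laid out as (batch, num_heads, seq_len, head_dim)
batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# attn_mask may be a boolean or additive float mask broadcastable to
# (batch, num_heads, seq_len, seq_len); None means unrestricted attention.
attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])
```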

DreamDecoderLayer

Bases: Module

forward(hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, cache_position=None, position_embeddings=None, c=None, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hidden_states` | `torch.FloatTensor` | Input to the layer of shape (batch, seq_len, embed_dim). | *required* |
| `attention_mask` | `torch.FloatTensor`, *optional* | Attention mask of size (batch, sequence_length) where padding elements are indicated by 0. | `None` |
| `output_attentions` | `bool`, *optional* | Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail. | `False` |
| `use_cache` | `bool`, *optional* | If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). | `False` |
| `past_key_value` | `Tuple(torch.FloatTensor)`, *optional* | Cached past key and value projection states. | `None` |
| `cache_position` | `torch.LongTensor` of shape `(sequence_length)`, *optional* | Indices depicting the position of the input sequence tokens in the sequence. | `None` |
| `position_embeddings` | `Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional* | Tuple containing the cosine and sine positional embeddings of shape (batch_size, seq_len, head_dim), with head_dim being the embedding dimension of each attention head. | `None` |
| `kwargs` | `dict`, *optional* | Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model. | `{}` |
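
A decoder layer of this kind usually follows the pre-norm residual pattern, with the optional conditioning c feeding the RMSNormModulations described above. The sketch below assumes that pattern; the attribute names (input_layernorm, self_attn, mlp, modulations), call signatures, and the exact placement of the modulation are assumptions, not the verbatim Dream code:

```python
# Illustrative pre-norm decoder layer with optional adaLN-style modulation.
def decoder_layer_sketch(layer, hidden_states, position_embeddings, c=None):
    if c is not None:
        # Assumed helper producing (bsz, 1, dim) modulation tensors from c.
        scale_msa, shift_msa, scale_mlp, shift_mlp = layer.modulations(c)

    # Self-attention block with residual connection.
    residual = hidden_states
    normed = layer.input_layernorm(hidden_states)
    if c is not None:
        normed = normed * (1 + scale_msa) + shift_msa
    attn_out = layer.self_attn(normed, position_embeddings=position_embeddings)
    hidden_states = residual + attn_out

    # MLP block with residual connection.
    residual = hidden_states
    normed = layer.post_attention_layernorm(hidden_states)
    if c is not None:
        normed = normed * (1 + scale_mlp) + shift_mlp
    hidden_states = residual + layer.mlp(normed)
    return hidden_states
```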

DreamBaseModel

Bases: DreamPreTrainedModel

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [DreamDecoderLayer]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `DreamConfigBase` | The model configuration. | *required* |
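
Structurally, the docstring implies an embedding table, a stack of config.num_hidden_layers decoder layers, and a final norm. A skeleton under the usual Hugging Face layout (constructor signatures, config attribute names, and the per-layer return handling are assumptions):

```python
import torch.nn as nn
from xlm.backbones.dream.modeling_dream import DreamDecoderLayer, DreamRMSNorm

class BaseModelSketch(nn.Module):
    """Illustrative skeleton of a decoder stack; not the actual DreamBaseModel."""

    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )
        self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, input_ids, **layer_kwargs):
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            layer_outputs = layer(hidden_states, **layer_kwargs)
            # HF-style layers typically return a tuple whose first element is the hidden states.
            hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
        return self.norm(hidden_states)
```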

rotate_half(x)

Rotates half the hidden dims of the input.
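
Its usual implementation is two slices and a concatenation; a sketch consistent with the one-line description above:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
```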

apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)

Applies Rotary Position Embedding to the query and key tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `q` | `torch.Tensor` | The query tensor. | *required* |
| `k` | `torch.Tensor` | The key tensor. | *required* |
| `cos` | `torch.Tensor` | The cosine part of the rotary embedding. | *required* |
| `sin` | `torch.Tensor` | The sine part of the rotary embedding. | *required* |
| `position_ids` | `torch.Tensor`, *optional* | Deprecated and unused. | `None` |
| `unsqueeze_dim` | `int`, *optional*, defaults to 1 | Specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, if cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim] and q and k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2. | `1` |

Returns:

| Type | Description |
| --- | --- |
| `tuple(torch.Tensor)` | The query and key tensors rotated using the Rotary Position Embedding. |
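
The rotation itself reduces to a few lines; a sketch of the standard formulation the docstring describes, reusing the rotate_half helper sketched above:

```python
def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # Insert a broadcast dimension so cos/sin line up with the heads dimension of q and k.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # Rotate channel pairs by the position-dependent angle encoded in cos/sin.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```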

repeat_kv(hidden_states, n_rep)

This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
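
A sketch matching that description, i.e. the standard grouped-query key/value expansion (illustrative, not necessarily the verbatim Dream code):

```python
import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads so each is shared by n_rep attention heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat dimension, expand it (no copy), then fold it into the head dimension.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```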