# NOTE: hosting-page banner captured with this file ("Spaces: Running on Zero"); not part of the module.
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
# -------------------------------------------------------- | |
# References: | |
# GLIDE: https://github.com/openai/glide-text2im | |
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py | |
# -------------------------------------------------------- | |
import numpy as np | |
import functools | |
from typing import Optional, Tuple, List | |
import torch | |
import torch.nn as nn | |
import warnings | |
import math | |
import torch.nn.functional as F | |
from flash_attn import flash_attn_varlen_func | |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa | |
from einops import rearrange | |
try: | |
from flash_attn import flash_attn_func | |
is_flash_attn = True | |
except: | |
is_flash_attn = False | |
try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        """Pure-PyTorch root-mean-square normalization, used when apex is unavailable."""

        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Build the layer.

            Args:
                dim (int): Size of the last dimension of the tensors to normalize.
                eps (float, optional): Constant added to the mean square before
                    the reciprocal square root. Default is 1e-6.

            Attributes:
                eps (float): The stabilizing constant.
                weight (nn.Parameter): Learnable per-feature scale, initialized to ones.
            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Divide `x` by the root mean square of its last dimension.

            Args:
                x (torch.Tensor): Tensor to normalize.

            Returns:
                torch.Tensor: `x` scaled by `rsqrt(mean(x**2, dim=-1) + eps)`.
            """
            mean_square = x.pow(2).mean(-1, keepdim=True)
            return x * torch.rsqrt(mean_square + self.eps)

        def forward(self, x):
            """
            Normalize `x` and apply the learnable scale.

            Args:
                x (torch.Tensor): Input tensor.

            Returns:
                torch.Tensor: Normalized tensor multiplied by `weight`.
            """
            # The statistics are computed in float32, then cast back to the
            # input dtype before the learnable scale is applied.
            normalized = self._norm(x.float())
            return normalized.type_as(x) * self.weight
def modulate(x, shift, scale):
    """
    Apply a scale-and-shift modulation to `x`.

    `shift` and `scale` each get a new axis inserted at position 1 so that
    they broadcast along dimension 1 of `x`.

    Returns:
        `x * (1 + scale) + shift` with the broadcasting described above.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b
def num_params(model, print_out=True, model_name="model"):
    """
    Count the trainable parameters of `model`, in millions.

    Args:
        model: Object exposing `parameters()` (e.g. an `nn.Module`).
        print_out (bool): When True, print a one-line summary.
        model_name (str): Label used in the printed summary.

    Returns:
        float: Number of elements in parameters with `requires_grad=True`,
        divided by 1,000,000.
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    millions = trainable / 1_000_000
    if print_out:
        # Format entirely inside the f-string: mixing an f-string with
        # %-formatting breaks as soon as `model_name` contains a '%'.
        print(f'| {model_name} Trainable Parameters: {millions:.3f}M')
    return millions
############################################################################# | |
# Core DiT Model # | |
############################################################################# | |
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Timesteps are first expanded into sinusoidal features of size
    `frequency_embedding_size` and then passed through a two-layer MLP that
    outputs `hidden_size` features.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    # Declared static: the function takes no `self`, but `forward` calls it
    # through the instance, which would otherwise bind the module as `t`.
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd sizes with one zero column. Built from the batch size so
            # it is a real column even when `half == 0` (dim == 1).
            pad = embedding.new_zeros((embedding.shape[0], 1))
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        """Return the MLP embedding of timesteps `t`, shape (N, hidden_size)."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # The sinusoidal features are float32; match the MLP weights so the
        # layer also works when the module has been cast to half precision.
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb
class Conv1DFinalLayer(nn.Module):
    """
    The final layer of CrossAttnDiT.

    Applies a 16-group GroupNorm over the channel axis followed by a
    kernel-size-1 Conv1d that maps `hidden_size` channels to `out_channels`.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.GroupNorm(16, hidden_size)
        self.conv1d = nn.Conv1d(hidden_size, out_channels, kernel_size=1)

    def forward(self, x):  # x: (B, C, T)
        """Normalize `x` and project its channels; returns (B, out_channels, T)."""
        return self.conv1d(self.norm_final(x))
class ConditionEmbedder(nn.Module):
    """
    Maps conditioning features of size `context_dim` to `hidden_size`.

    The mapping is Linear -> GELU -> Linear -> LayerNorm, applied over the
    last dimension of the input.
    """

    def __init__(self, hidden_size, context_dim):
        super().__init__()
        stages = [
            nn.Linear(context_dim, hidden_size, bias=True),
            nn.GELU(),  # default (exact) GELU; approximate='tanh' is not enabled
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return the embedded condition, same leading dims as `x`."""
        embedded = self.mlp(x)
        return embedded
class Attention(nn.Module):
    """
    Self-attention with rotary position embeddings, plus an optional gated
    cross-attention branch over a conditioning sequence `y`.

    The cross-attention branch (`wk_y`, `wv_y`, `gate`) exists only when
    `y_dim > 0`; its output is scaled per head by `tanh(gate)`, which is
    initialized to zero, and added to the self-attention output.
    """

    def __init__(self, dim: int, n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, y_dim: int):
        """
        Args:
            dim (int): Model width; each head has `dim // n_heads` channels.
            n_heads (int): Number of query heads.
            n_kv_heads (Optional[int]): Number of key/value heads; `None`
                means the same as `n_heads`.
            qk_norm (bool): When truthy, LayerNorm is applied to the projected
                queries and keys (and to the cross-attention keys).
            y_dim (int): Feature size of the conditioning sequence; `0`
                disables the cross-attention branch.
        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        model_parallel_size = 1
        self.n_local_heads = n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        # Read by `forward` on the flash-attention path. They were never
        # initialized, so that path raised AttributeError. Callers may enable
        # proportional attention by setting both attributes after construction;
        # `base_seqlen` must then be a number usable as a logarithm base.
        self.proportional_attn = False
        self.base_seqlen = None

        self.wq = nn.Linear(
            dim, n_heads * self.head_dim, bias=False
        )
        self.wk = nn.Linear(
            dim, self.n_kv_heads * self.head_dim, bias=False
        )
        self.wv = nn.Linear(
            dim, self.n_kv_heads * self.head_dim, bias=False
        )
        if y_dim > 0:
            self.wk_y = nn.Linear(
                y_dim, self.n_kv_heads * self.head_dim, bias=False
            )
            self.wv_y = nn.Linear(
                y_dim, self.n_kv_heads * self.head_dim, bias=False
            )
            # Per-head gate for the cross-attention output; tanh(0) == 0, so
            # the branch contributes nothing at initialization.
            self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))

        self.wo = nn.Linear(
            n_heads * self.head_dim, dim, bias=False
        )

        if qk_norm:
            self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim)
            self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
            if y_dim > 0:
                self.ky_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
            else:
                self.ky_norm = nn.Identity()
        else:
            self.q_norm = self.k_norm = nn.Identity()
            self.ky_norm = nn.Identity()

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        """
        Reshape frequency tensor for broadcasting it with another tensor.

        The result keeps the sizes of `x` at dimension 1 and the last
        dimension and has size 1 everywhere else.

        Args:
            freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
            x (torch.Tensor): Target tensor for broadcasting compatibility.

        Returns:
            torch.Tensor: Reshaped frequency tensor.

        Raises:
            AssertionError: If `x` has fewer than two dimensions.
            AssertionError: If `freqs_cis.shape` is not
                `(x.shape[1], x.shape[-1])`.
        """
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1
                 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    @staticmethod
    def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to the query and key tensors.

        Consecutive pairs along the last dimension of `xq` / `xk` are viewed
        as complex numbers, multiplied by `freqs_cis`, and converted back to
        real tensors of the original dtype.

        Args:
            xq (torch.Tensor): Query tensor to apply rotary embeddings.
            xk (torch.Tensor): Key tensor to apply rotary embeddings.
            freqs_cis (torch.Tensor): Precomputed complex exponentials of
                shape `(xq.shape[1], xq.shape[-1] // 2)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
        """
        # The complex arithmetic is done in float32 with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
            freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_)
            xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
            return xq_out.type_as(xq), xk_out.type_as(xk)

    @staticmethod
    def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat each head of a (B, S, H_kv, D) tensor `n_rep` times -> (B, S, H_kv * n_rep, D)."""
        if n_rep == 1:
            return x
        return x.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

    # copied from huggingface modeling_llama.py
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """
        Remove padded positions so the tensors can be fed to the variable-length
        flash-attention kernel.

        Args:
            query_layer, key_layer, value_layer: (batch, seq, heads, head_dim) tensors.
            attention_mask: (batch, kv_seq) mask whose non-zero entries mark
                positions to keep.
            query_length (int): Sequence length of the queries.

        Returns:
            tuple: `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_q, max_seqlen_k))`.
        """
        def _get_unpad_data(attention_mask):
            # Per-sample kept lengths, flat indices of kept positions, and
            # cumulative lengths with a leading 0.
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return (
                indices,
                cu_seqlens,
                max_seqlen_in_batch,
            )

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def forward(
        self, x: torch.Tensor, x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        y: torch.Tensor, y_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input of shape (bsz, seqlen, dim).
            x_mask (torch.Tensor): (bsz, seqlen) mask; non-zero entries are
                attended to.
            freqs_cis (torch.Tensor): Precomputed rotary frequencies of shape
                (seqlen, head_dim // 2).
            y (torch.Tensor): Conditioning sequence of shape (bsz, tokens, y_dim);
                only read when the cross-attention branch exists.
            y_mask (torch.Tensor): (bsz, tokens) attention mask for `y`.

        Returns:
            torch.Tensor: Output tensor of shape (bsz, seqlen, dim).
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        dtype = xq.dtype

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        xq, xk = xq.to(dtype), xk.to(dtype)

        if is_flash_attn and dtype in [torch.float16, torch.bfloat16]:
            # begin var_len flash attn
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                xq, xk, xv, x_mask, seqlen
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if self.proportional_attn:
                softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim)
            else:
                softmax_scale = math.sqrt(1 / self.head_dim)

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=0.,
                causal=False,
                softmax_scale=softmax_scale
            )
            output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
            # end var_len_flash_attn
        else:
            # The key/value heads are expanded to the query head count here
            # so the fallback also works when n_kv_heads < n_heads; with equal
            # head counts this is a no-op.
            keys = Attention._repeat_kv(xk, self.n_rep)
            values = Attention._repeat_kv(xv, self.n_rep)
            output = F.scaled_dot_product_attention(
                xq.permute(0, 2, 1, 3),
                keys.permute(0, 2, 1, 3),
                values.permute(0, 2, 1, 3),
                attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
            ).permute(0, 2, 1, 3).to(dtype)

        if hasattr(self, "wk_y"):  # cross-attention
            yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
            yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
            n_rep = self.n_local_heads // self.n_local_kv_heads
            if n_rep >= 1:
                yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            output_y = F.scaled_dot_product_attention(
                xq.permute(0, 2, 1, 3),
                yk.permute(0, 2, 1, 3),
                yv.permute(0, 2, 1, 3),
                y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1)
            ).permute(0, 2, 1, 3)
            # Per-head gating of the conditioning contribution.
            output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
            output = output + output_y

        output = output.flatten(-2)

        return self.wo(output)
class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    A parameter-free LayerNorm whose output is shifted and scaled by values
    predicted from the conditioning vector, followed by a linear projection
    from `hidden_size` to `out_channels`.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        # Predicts 2 * hidden_size values, later split into shift and scale.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        Args:
            x: Features whose last dimension is `hidden_size`.
            c: Conditioning tensor; the modulation output is split in two
                along dimension 1.

        Returns:
            The projected tensor with last dimension `out_channels`.
        """
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm_final(x)
        return self.linear(modulate(normalized, shift, scale))
class FeedForward(nn.Module):
    """SiLU-gated feed-forward network: `w2(silu(w1(x)) * w3(x))`."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Build the three bias-free projections.

        The inner width is derived from `hidden_dim`: it is first reduced to
        two thirds, optionally multiplied by `ffn_dim_multiplier`, and then
        rounded up to the next multiple of `multiple_of`.

        Args:
            dim (int): Input and output dimension.
            hidden_dim (int): Nominal hidden dimension before the adjustments
                described above.
            multiple_of (int): The inner width is rounded up to a multiple of
                this value.
            ffn_dim_multiplier (float, optional): Extra multiplier for the
                inner width; skipped when None.

        Attributes:
            w1 (nn.Linear): dim -> inner width, fed through SiLU.
            w2 (nn.Linear): inner width -> dim, the output projection.
            w3 (nn.Linear): dim -> inner width, the multiplicative gate.
        """
        super().__init__()
        inner_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # Round up to the next multiple of `multiple_of`.
        inner_dim += (-inner_dim) % multiple_of

        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    # NOTE: a `@torch.compile` decorator is intentionally left disabled here.
    def _forward_silu_gating(self, x1, x3):
        """Return `silu(x1) * x3`."""
        gated = F.silu(x1)
        return gated * x3

    def forward(self, x):
        """Apply the gated feed-forward network to the last dimension of `x`."""
        hidden = self._forward_silu_gating(self.w1(x), self.w3(x))
        return self.w2(hidden)
class MoE(nn.Module):
    """
    Two-stage mixture of feed-forward experts.

    Stage 1 ("time" experts): every token of a sample is routed to the expert
    whose index equals `time // time_bucket_size` for that sample.
    Stage 2 ("frequency" experts): the feature axis is split into
    `num_experts` equal bands; expert `i` receives the stage-1 output with all
    channels outside band `i` zeroed and contributes only band `i` of its
    output.
    """

    LOAD_BALANCING_LOSSES = []
    # Class-level default so instances created without `__init__` keep the
    # historical bucket width.
    time_bucket_size = 250

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        time_bucket_size: int = 250,
    ):
        """
        Args:
            dim (int): Feature size of the tokens.
            hidden_dim (int): Hidden size handed to each `FeedForward` expert.
            num_experts (int): Number of experts in each of the two stages.
            multiple_of (int): Forwarded to `FeedForward`.
            ffn_dim_multiplier (float): Forwarded to `FeedForward`.
            time_bucket_size (int): Width of the timestep interval served by
                one time expert. Defaults to 250.
        """
        super().__init__()
        self.num_freq_experts = num_experts
        self.time_bucket_size = time_bucket_size
        self.local_experts = [str(i) for i in range(num_experts)]
        self.time_experts = nn.ModuleDict({
            i: FeedForward(dim, hidden_dim, multiple_of=multiple_of,
                           ffn_dim_multiplier=ffn_dim_multiplier) for i in self.local_experts
        })
        self.freq_experts = nn.ModuleDict({
            i: FeedForward(dim, hidden_dim, multiple_of=multiple_of,
                           ffn_dim_multiplier=ffn_dim_multiplier) for i in self.local_experts
        })

    def forward(self, x, time):
        """
        Args:
            x: Tokens of shape (B, T, dim).
            time: (B,) tensor of timesteps, one per sample.

        Returns:
            Tensor with the same shape as `x`. Tokens whose time index matches
            no expert are left at zero by the first stage, and trailing
            channels not covered by any band (when `dim` is not divisible by
            the number of experts) stay zero.
        """
        orig_shape = x.shape  # [B, T, dim]
        x = x.reshape(-1, x.shape[-1])  # [N, dim], flattened sample by sample

        # Expert index of every token: all tokens of a sample share its timestep.
        expert_indices = (time // self.time_bucket_size).unsqueeze(1).repeat(1, orig_shape[1])
        flat_expert_indices = expert_indices.view(-1)  # [N]

        # time-moe
        y = torch.zeros_like(x)
        for str_i, expert in self.time_experts.items():
            # Every expert is invoked, even on an empty selection, so all
            # parameters take part in the autograd graph on each step.
            token_mask = flat_expert_indices == int(str_i)
            y[token_mask] = expert(x[token_mask])
        y = y.view(*orig_shape)

        # frequency-moe
        z = torch.zeros_like(y)
        band_width = orig_shape[-1] // self.num_freq_experts
        for str_i, expert in self.freq_experts.items():
            idx = int(str_i)
            band = slice(band_width * idx, band_width * (idx + 1))
            region = torch.zeros_like(z)
            region[:, :, band] = True  # 1 inside this expert's band, 0 elsewhere
            z[:, :, band] = expert(y * region)[:, :, band]
        return z
class TransformerBlock(nn.Module):
    """
    One transformer layer: attention followed by a mixture-of-experts
    feed-forward, each wrapped in a residual connection and optionally
    modulated by a conditioning vector (adaLN).
    """

    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int,
                 multiple_of: int, ffn_dim_multiplier: float, norm_eps: float,
                 qk_norm: bool, y_dim: int, num_experts) -> None:
        """
        Args:
            layer_id (int): Index of this layer; stored on the instance.
            dim (int): Model width.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Number of key/value heads, forwarded to `Attention`.
            multiple_of (int): Forwarded to the MoE feed-forward.
            ffn_dim_multiplier (float): Forwarded to the MoE feed-forward.
            norm_eps (float): Epsilon for all RMSNorm layers.
            qk_norm (bool): Forwarded to `Attention`.
            y_dim (int): Feature size of the conditioning sequence.
            num_experts: Number of experts in the MoE feed-forward.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim)
        self.feed_forward = MoE(
            dim=dim,
            hidden_dim=4 * dim,
            num_experts=num_experts,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        # Produces 6 * dim values: shift/scale/gate for the attention branch
        # and shift/scale/gate for the feed-forward branch.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),
        )
        self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        y: torch.Tensor,
        y_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
        time=None
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Attention mask for `x`.
            y (torch.Tensor): Conditioning sequence; RMS-normalized before
                being handed to the attention module.
            y_mask (torch.Tensor): Attention mask for `y`.
            freqs_cis (torch.Tensor): Precomputed rotary frequencies.
            adaln_input (torch.Tensor, optional): Conditioning vector for the
                adaLN modulation. When None, both branches run unmodulated
                and ungated.
            time: Timesteps forwarded to the MoE feed-forward for routing.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.
        """
        y_normed = self.attention_y_norm(y)

        if adaln_input is None:
            # Plain pre-norm residual block.
            h = x + self.attention(self.attention_norm(x), x_mask, freqs_cis, y_normed, y_mask)
            return h + self.feed_forward(self.ffn_norm(h), time)

        modulation = self.adaLN_modulation(adaln_input)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        attn_in = modulate(self.attention_norm(x), shift_msa, scale_msa)
        attn_out = self.attention(attn_in, x_mask, freqs_cis, y_normed, y_mask)
        h = x + gate_msa.unsqueeze(1) * attn_out

        ffn_in = modulate(self.ffn_norm(h), shift_mlp, scale_mlp)
        ffn_out = self.feed_forward(ffn_in, time)
        out = h + gate_mlp.unsqueeze(1) * ffn_out
        return out
class VideoFlagLargeDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The input latent is projected to `hidden_size`, processed by `depth`
    transformer blocks that attend to the embedded context and are modulated
    by the sum of the timestep embedding and the pooled context embedding,
    and finally projected back to `in_channels`.
    """

    def __init__(
        self,
        in_channels,
        context_dim,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_len=1000,
        n_kv_heads=None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps=1e-5,
        qk_norm=None,
        rope_scaling_factor: float = 1.,
        ntk_factor: float = 1,
        num_experts=8,
    ):
        """
        Args:
            in_channels: Channel count of the input latent; also the output
                channel count.
            context_dim: Feature size of the conditioning context.
            hidden_size: Transformer width.
            depth: Number of transformer blocks.
            num_heads: Number of attention heads.
            max_len: Number of positions for which rotary frequencies are
                precomputed.
            n_kv_heads: Number of key/value heads (None: same as `num_heads`).
            multiple_of: Rounding granularity of the feed-forward inner width.
            ffn_dim_multiplier: Optional multiplier of the feed-forward inner width.
            norm_eps: Epsilon of the RMSNorm layers inside the blocks.
            qk_norm: Whether queries/keys are layer-normalized in attention.
            rope_scaling_factor: Divisor applied to the rotary positions.
            ntk_factor: Multiplier applied to the rotary base `theta`.
            num_experts: Number of experts per MoE stage in every block.
        """
        super().__init__()
        self.in_channels = in_channels  # vae dim
        self.out_channels = in_channels
        self.num_heads = num_heads
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.c_embedder = ConditionEmbedder(hidden_size, context_dim)
        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)

        self.cap_embedder = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

        self.blocks = nn.ModuleList([
            TransformerBlock(layer_id, hidden_size, num_heads, n_kv_heads, multiple_of,
                             ffn_dim_multiplier, norm_eps, qk_norm, hidden_size, num_experts)
            for layer_id in range(depth)
        ])

        # Kept as a plain attribute (not a registered buffer) and moved to the
        # input's device in `forward`.
        self.freqs_cis = VideoFlagLargeDiT.precompute_freqs_cis(
            hidden_size // num_heads, max_len,
            rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor)

        self.final_layer = FinalLayer(hidden_size, self.out_channels)
        self.rope_scaling_factor = rope_scaling_factor
        self.ntk_factor = ntk_factor
        num_params(self.blocks, model_name='transformer block')

    def forward(self, x, t, context):
        """
        Forward pass of DiT.

        x: (N, C, T) tensor of temporal inputs (latent representations of melspec)
        t: (N,) tensor of diffusion timesteps
        context: (N, tokens, context_dim) tensor of conditioning features

        Returns an (N, C, T) tensor.

        Raises:
            ValueError: If T exceeds the number of precomputed rotary positions.
        """
        self.freqs_cis = self.freqs_cis.to(x.device)

        x = rearrange(x, 'b c t -> b t c')
        seq_len = x.size(1)
        if seq_len > self.freqs_cis.size(0):
            # Without this check the short slice of `freqs_cis` only fails
            # later on an opaque shape assertion inside the attention module.
            raise ValueError(
                f"sequence length {seq_len} exceeds the {self.freqs_cis.size(0)} "
                f"precomputed rotary positions (max_len)"
            )
        x = self.proj_in(x)

        # [B, tokens]; for video conditioning every context token is kept (all ones).
        cap_mask = torch.ones((context.shape[0], context.shape[1]), dtype=torch.int32, device=x.device)
        mask = torch.ones((x.shape[0], x.shape[1]), dtype=torch.int32, device=x.device)

        t_embedding = self.t_embedder(t)  # [B, hidden_size]
        c = self.c_embedder(context)  # [B, tokens, hidden_size]

        # get pooling feature: masked mean over the context tokens
        cap_mask_float = cap_mask.float().unsqueeze(-1)
        cap_feats_pool = (c * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1)
        cap_feats_pool = cap_feats_pool.to(c)  # [B, hidden_size]
        cap_emb = self.cap_embedder(cap_feats_pool)  # [B, hidden_size]

        adaln_input = t_embedding + cap_emb
        cap_mask = cap_mask.bool()

        for block in self.blocks:
            x = block(
                x, mask, c, cap_mask, self.freqs_cis[:seq_len],
                adaln_input=adaln_input, time=t
            )

        x = self.final_layer(x, adaln_input)  # (N, T, out_channels)
        x = rearrange(x, 'b t c -> b c t')
        return x

    @staticmethod
    def precompute_freqs_cis(
        dim: int,
        end: int,
        theta: float = 10000.0,
        rope_scaling_factor: float = 1.0,
        ntk_factor: float = 1.0
    ):
        """
        Precompute the frequency tensor for complex exponentials (cis) with
        given dimensions.

        The tensor is built on the current CUDA device when CUDA is
        available and on the CPU otherwise.

        Args:
            dim (int): Dimension of the frequency tensor.
            end (int): End index for precomputing frequencies.
            theta (float, optional): Base of the frequency computation.
                Defaults to 10000.0.
            rope_scaling_factor (float, optional): Positions are divided by
                this value. Defaults to 1.0.
            ntk_factor (float, optional): `theta` is multiplied by this value.
                Defaults to 1.0.

        Returns:
            torch.Tensor: complex64 tensor of shape (end, dim // 2) holding
                unit-magnitude complex exponentials.
        """
        theta = theta * ntk_factor

        print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
        # Prefer CUDA when present, but fall back to the CPU so the model
        # can be constructed on hosts without a GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float().to(device) / dim
        freqs = 1.0 / (theta ** exponents)
        t = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
        t = t / rope_scaling_factor
        freqs = torch.outer(t, freqs).float()  # type: ignore
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
        return freqs_cis