# TransPixar / CogVideoX / rgba_utils.py
# Author: wileewang — commit 7dc9494 ("Upload 6 files", verified), 13.7 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict, Optional, Tuple, Union
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from safetensors.torch import load_file
# Module-level logger from diffusers' logging utilities; used by the patched RGBA forward
# to warn about an ineffective `scale` entry in `attention_kwargs`.
logger = logging.get_logger(__name__)
@torch.no_grad()
def decode_latents(pipe, latents):
    """Decode ``latents`` through ``pipe`` and post-process the result as NumPy frames.

    Runs without autograd. The latents are passed to ``pipe.decode_latents`` and the
    decoded output is handed to ``pipe.video_processor.postprocess_video`` with
    ``output_type="np"``; whatever that call returns is returned unchanged.
    """
    decoded = pipe.decode_latents(latents)
    processor = pipe.video_processor
    return processor.postprocess_video(video=decoded, output_type="np")
def create_attention_mask(text_length: int, seq_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create an attention mask that blocks text queries from attending to alpha keys.

    The joint token sequence is ``[text (text_length) | video (seq_length)]``; the video part
    is split at ``seq_length // 2`` into an RGB half followed by an alpha half. Rows index
    queries and columns index keys: every (text query, alpha key) pair is blocked and every
    other pair is allowed.

    Args:
        text_length: Length of the text sequence.
        seq_length: Length of the video sequence (RGB and alpha halves together).
        device: The device where the mask will be stored.
        dtype: Mask dtype. ``torch.bool`` yields a boolean mask (``True`` = may attend).
            A floating dtype yields an additive mask (``0`` = may attend, ``-inf`` =
            blocked), which is how ``F.scaled_dot_product_attention`` interprets float
            masks. Any other dtype receives the boolean mask cast to that dtype.

    Returns:
        A ``(text_length + seq_length, text_length + seq_length)`` mask on ``device``.
    """
    total_length = text_length + seq_length
    alpha_start = text_length + seq_length // 2

    if dtype.is_floating_point:
        # SDPA *adds* a float mask to the attention logits. Casting the boolean mask to a
        # float dtype (1.0 / 0.0) would merely add +1 to allowed logits and block nothing,
        # so build a real additive mask instead.
        # NOTE(review): LoRA weights trained with the old 1/0 float mask may have adapted to
        # that non-blocking behaviour — confirm generation quality after this change.
        mask = torch.zeros((total_length, total_length), dtype=dtype, device=device)
        mask[:text_length, alpha_start:] = float("-inf")
        return mask

    mask = torch.ones((total_length, total_length), dtype=torch.bool, device=device)
    mask[:text_length, alpha_start:] = False
    return mask.to(dtype=dtype)
class RGBALoRACogVideoXAttnProcessor:
    r"""
    CogVideoX scaled dot-product attention processor with LoRA adapters on the alpha section.

    The text (encoder) tokens are prepended to ``hidden_states``. The video tokens are split
    at ``seq_len // 2`` into an RGB section and a trailing alpha section. LoRA deltas for the
    query, key, value and output projections are added only to the alpha section. Rotary
    embeddings, when given, are applied to the RGB and alpha sections separately with the
    same ``image_rotary_emb``. Attention always uses the mask supplied at construction time.
    No spatial normalization is performed.
    """

    def __init__(self, device, dtype, attention_mask, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0 or later.")

        # LoRA hyper-parameters; each update is scaled by lora_alpha / lora_rank at call time.
        self.lora_alpha = lora_alpha
        self.lora_rank = lora_rank

        def make_adapter():
            # Bias-free low-rank pair: latent_dim -> lora_rank -> latent_dim.
            down = nn.Linear(latent_dim, lora_rank, bias=False, device=device, dtype=dtype)
            up = nn.Linear(lora_rank, latent_dim, bias=False, device=device, dtype=dtype)
            return nn.Sequential(down, up)

        # Creation order (q, k, v, out) is fixed so random initialisation is reproducible.
        self.to_q_lora = make_adapter()
        self.to_k_lora = make_adapter()
        self.to_v_lora = make_adapter()
        self.to_out_lora = make_adapter()

        # Used for every call; the per-call `attention_mask` argument is ignored.
        self.attention_mask = attention_mask

    def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
        """Add scaled LoRA deltas, in place, to the alpha rows of query, key and value."""
        # (-seq_len) // 2 == -ceil(seq_len / 2): the alpha rows start at
        # text_len + seq_len // 2, the same boundary used for the rotary split.
        start = -seq_len // 2
        pairs = ((query, self.to_q_lora), (key, self.to_k_lora), (value, self.to_v_lora))
        for target, adapter in pairs:
            delta = adapter(hidden_states).to(target.device)
            target[:, start:, :] += delta[:, start:, :] * scaling
        return query, key, value

    def _apply_rotary_embedding(self, query, key, image_rotary_emb, seq_len, text_seq_length, attn):
        """Rotate the RGB and alpha sections of query (and key, unless cross-attention)."""
        from diffusers.models.embeddings import apply_rotary_emb

        boundary = text_seq_length + seq_len // 2
        sections = (slice(text_seq_length, boundary), slice(boundary, None))
        tensors = (query,) if attn.is_cross_attention else (query, key)
        for tensor in tensors:
            for section in sections:
                tensor[:, :, section] = apply_rotary_emb(tensor[:, :, section], image_rotary_emb)
        return query, key

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run joint text/video attention with alpha-only LoRA updates.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``, ``to_out[0]``
                (projection), ``to_out[1]`` (applied after it), ``heads``, ``norm_q``,
                ``norm_k`` and ``is_cross_attention``.
            hidden_states: Video tokens (RGB section followed by alpha section).
            encoder_hidden_states: Text tokens, prepended before attention.
            attention_mask: Ignored; ``self.attention_mask`` is used instead.
            image_rotary_emb: Optional rotary embedding applied to both video sections.

        Returns:
            A tuple ``(video_tokens, text_tokens)`` split back out of the joint output.
        """
        text_seq_length = encoder_hidden_states.size(1)
        joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        batch_size, total_length, _ = joint.shape
        seq_len = total_length - text_seq_length
        scaling = self.lora_alpha / self.lora_rank

        base_q = attn.to_q(joint)
        base_k = attn.to_k(joint)
        base_v = attn.to_v(joint)
        query, key, value = self._apply_lora(joint, seq_len, base_q, base_k, base_v, scaling)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(t):
            return t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query, key = self._apply_rotary_embedding(
                query, key, image_rotary_emb, seq_len, text_seq_length, attn
            )

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=self.attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Output projection plus the alpha-only LoRA delta (same boundary as _apply_lora).
        projected = attn.to_out[0](attended)
        out_delta = self.to_out_lora(attended).to(attended.device)
        cut = -seq_len // 2
        projected[:, cut:, :] += out_delta[:, cut:, :] * scaling

        projected = attn.to_out[1](projected)

        text_out, video_out = projected.split(
            [text_seq_length, projected.size(1) - text_seq_length], dim=1
        )
        return video_out, text_out
def prepare_for_rgba_inference(
    model, rgba_weights_path: str, device: torch.device, dtype: torch.dtype,
    lora_rank: int = 128, lora_alpha: float = 1.0, text_length: int = 226, seq_length: int = 35100
):
    """
    Patch a CogVideoX transformer in place for joint RGB + alpha generation.

    Loads the RGBA safetensors checkpoint and replaces every attention processor of ``model``
    with an ``RGBALoRACogVideoXAttnProcessor`` carrying the checkpoint's LoRA weights. It then
    replaces ``model.forward`` with a variant that adds the learned ``domain_emb`` to the
    second (alpha) half of the video tokens before the transformer blocks.

    Args:
        model: Transformer to patch. It must expose ``attn_processors`` /
            ``set_attn_processor`` and the submodules used by the patched forward
            (``time_proj``, ``time_embedding``, ``patch_embed``, ``transformer_blocks``, ...).
        rgba_weights_path: Safetensors file holding ``domain_emb`` and LoRA weights keyed as
            ``transformer.transformer_blocks.{i}.attn1.{to_q,to_k,to_v,to_out.0}.lora_{A,B}.weight``.
        device: Device for the LoRA layers and the attention mask.
        dtype: Dtype for the LoRA layers and the attention mask.
        lora_rank: LoRA rank; the update scale is ``lora_alpha / lora_rank``.
        lora_alpha: LoRA alpha.
        text_length: Number of text tokens the attention mask is built for.
        seq_length: Number of video tokens (RGB + alpha) the attention mask is built for.

    Returns:
        None. ``model`` is modified in place.
    """
    # Explicit import so the checkpoint submodule is loaded without rebinding `torch`.
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    def load_lora_sequential_weights(lora_layer, lora_layers, prefix):
        # Sequential index 0 is the down-projection (lora_A), index 1 the up-projection (lora_B).
        lora_layer[0].load_state_dict({'weight': lora_layers[f"{prefix}.lora_A.weight"]})
        lora_layer[1].load_state_dict({'weight': lora_layers[f"{prefix}.lora_B.weight"]})

    rgba_weights = load_file(rgba_weights_path)
    aux_emb = rgba_weights['domain_emb']

    attention_mask = create_attention_mask(text_length, seq_length, device, dtype)
    attn_procs = {}

    for name in model.attn_processors.keys():
        index = name.split('.')[1]
        base_prefix = f'transformer.transformer_blocks.{index}.attn1'
        # Size the LoRA layers from the checkpoint: lora_A is the weight of
        # Linear(latent_dim, rank), i.e. shape (rank, latent_dim). This lets hidden sizes
        # other than the processor's 3072 default load instead of failing in load_state_dict.
        latent_dim = rgba_weights[f'{base_prefix}.to_q.lora_A.weight'].shape[1]
        attn_processor = RGBALoRACogVideoXAttnProcessor(
            device=device, dtype=dtype, attention_mask=attention_mask,
            lora_rank=lora_rank, lora_alpha=lora_alpha, latent_dim=latent_dim,
        )
        for lora_layer, prefix in [
            (attn_processor.to_q_lora, f'{base_prefix}.to_q'),
            (attn_processor.to_k_lora, f'{base_prefix}.to_k'),
            (attn_processor.to_v_lora, f'{base_prefix}.to_v'),
            (attn_processor.to_out_lora, f'{base_prefix}.to_out.0'),
        ]:
            load_lora_sequential_weights(lora_layer, rgba_weights, prefix)
        attn_procs[name] = attn_processor

    model.set_attn_processor(attn_procs)

    def make_rgba_forward(self):
        # `self` is the patched model; the returned closure replaces its `forward`.
        def forward(
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            timestep: Union[int, float, torch.LongTensor],
            timestep_cond: Optional[torch.Tensor] = None,
            ofs: Optional[Union[int, float, torch.LongTensor]] = None,
            image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            attention_kwargs: Optional[Dict[str, Any]] = None,
            return_dict: bool = True,
        ):
            # Checked before `scale` is popped below; otherwise the warning could never fire.
            scale_requested = attention_kwargs is not None and attention_kwargs.get("scale", None) is not None
            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
            else:
                lora_scale = 1.0

            if USE_PEFT_BACKEND:
                # weight the lora layers by setting `lora_scale` for each PEFT layer
                scale_lora_layers(self, lora_scale)
            elif scale_requested:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

            batch_size, num_frames, channels, height, width = hidden_states.shape

            # 1. Time embedding
            t_emb = self.time_proj(timestep)
            # time_proj has no weights and returns float32, while time_embedding may run in
            # reduced precision, so cast to the activation dtype first.
            t_emb = t_emb.to(dtype=hidden_states.dtype)
            emb = self.time_embedding(t_emb, timestep_cond)

            if self.ofs_embedding is not None:
                ofs_emb = self.ofs_proj(ofs)
                ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
                ofs_emb = self.ofs_embedding(ofs_emb)
                emb = emb + ofs_emb

            # 2. Patch embedding
            hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
            hidden_states = self.embedding_dropout(hidden_states)

            text_seq_length = encoder_hidden_states.shape[1]
            encoder_hidden_states = hidden_states[:, :text_seq_length]
            hidden_states = hidden_states[:, text_seq_length:]

            # Add the learned domain embedding to the second (alpha) half of the video tokens.
            hidden_states[:, hidden_states.size(1) // 2:, :] += aux_emb.expand(batch_size, -1, -1).to(
                hidden_states.device, dtype=hidden_states.dtype
            )

            # 3. Transformer blocks
            for block in self.transformer_blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    def create_custom_forward(module):
                        def run_block(*inputs):
                            return module(*inputs)
                        return run_block

                    # `use_reentrant=False` is supported on every PyTorch this module can run
                    # on: RGBALoRACogVideoXAttnProcessor already requires PyTorch 2.0+.
                    hidden_states, encoder_hidden_states = torch_checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        use_reentrant=False,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                    )

            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 4. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb)
            hidden_states = self.proj_out(hidden_states)

            # 5. Unpatchify
            p = self.config.patch_size
            p_t = self.config.patch_size_t

            if p_t is None:
                output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
                output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
            else:
                output = hidden_states.reshape(
                    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
                )
                output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

            if USE_PEFT_BACKEND:
                # remove `lora_scale` from each PEFT layer
                unscale_lora_layers(self, lora_scale)

            if not return_dict:
                return (output,)
            return Transformer2DModelOutput(sample=output)

        return forward

    model.forward = make_rgba_forward(model)