# Spaces: Running on Zero
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from typing import Any, Dict, Optional, Tuple, Union | |
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers | |
from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
from safetensors.torch import load_file | |
# Module-level logger created through diffusers' logging utilities.
logger = logging.get_logger(name=__name__)
def decode_latents(pipe, latents):
    """Decode ``latents`` with ``pipe`` and post-process the result into NumPy video frames.

    Args:
        pipe: Pipeline object exposing ``decode_latents`` and
            ``video_processor.postprocess_video``.
        latents: Passed unchanged to ``pipe.decode_latents``.

    Returns:
        The value of ``pipe.video_processor.postprocess_video(video=..., output_type="np")``.
    """
    decoded_frames = pipe.decode_latents(latents)
    return pipe.video_processor.postprocess_video(video=decoded_frames, output_type="np")
def create_attention_mask(text_length: int, seq_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create an attention mask that blocks text tokens from attending to alpha tokens.

    The joint sequence is laid out as ``[text | RGB | alpha]``. The alpha part is the
    last ``seq_length - seq_length // 2`` tokens. Rows ``[:text_length]`` (text queries)
    are prevented from attending to columns ``[text_length + seq_length // 2:]``
    (alpha keys). Every other query/key pair is allowed.

    ``F.scaled_dot_product_attention`` interprets a boolean mask as "True = attend",
    but adds a floating mask to the attention scores. Casting a boolean mask to a
    float dtype would therefore only add +1/0 and block nothing. For floating dtypes
    this function returns an additive mask instead: 0 where attention is allowed and
    the dtype's most negative finite value where it is blocked.

    Args:
        text_length: Length of the text sequence.
        seq_length: Length of the video sequence (RGB half followed by alpha half).
        device: The device where the mask will be stored.
        dtype: The data type of the mask tensor.

    Returns:
        A ``(text_length + seq_length, text_length + seq_length)`` tensor:
        - ``dtype`` is ``torch.bool``: a boolean keep-mask.
        - ``dtype`` is floating point: an additive bias.
        - any other ``dtype``: the keep-mask cast to ``dtype`` (legacy behavior).
    """
    total_length = text_length + seq_length
    alpha_start = text_length + seq_length // 2
    # Build directly on the target device to avoid a large temporary on the CPU.
    keep = torch.ones((total_length, total_length), dtype=torch.bool, device=device)
    keep[:text_length, alpha_start:] = False
    if dtype == torch.bool:
        return keep
    if dtype.is_floating_point:
        bias = torch.zeros((total_length, total_length), dtype=dtype, device=device)
        return bias.masked_fill(~keep, torch.finfo(dtype).min)
    return keep.to(dtype=dtype)
class RGBALoRACogVideoXAttnProcessor:
    r"""
    Scaled dot-product attention processor for CogVideoX with RGBA LoRA branches.

    The processor concatenates ``encoder_hidden_states`` (text) in front of
    ``hidden_states`` (video), giving the layout ``[text | RGB | alpha]``.

    - The base ``to_q``/``to_k``/``to_v``/``to_out`` projections of ``attn`` run on
      every token.
    - LoRA deltas (scaled by ``lora_alpha / lora_rank``) are added only to the trailing
      alpha half of the video tokens.
    - When rotary embeddings are given, the same ``image_rotary_emb`` is applied
      separately to the RGB half and to the alpha half.
    - No spatial normalization is applied.

    The LoRA layers are plain attributes. This class is not an ``nn.Module``, so they
    are not registered as submodules of the model that uses the processor.

    Args:
        device: Device on which the LoRA layers are created.
        dtype: dtype of the LoRA layers.
        attention_mask: Passed as ``attn_mask`` to ``F.scaled_dot_product_attention``
            on every call.
        lora_rank: Inner rank of each LoRA pair.
        lora_alpha: LoRA scale numerator.
        latent_dim: Input and output width of each LoRA pair.
    """

    def __init__(self, device, dtype, attention_mask, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0 or later.")
        self.lora_alpha = lora_alpha
        self.lora_rank = lora_rank

        def create_lora_layer(in_dim, mid_dim, out_dim):
            # Two bias-free Linear layers: in_dim -> mid_dim, then mid_dim -> out_dim.
            return nn.Sequential(
                nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype),
                nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype),
            )

        self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.attention_mask = attention_mask

    @staticmethod
    def _alpha_length(seq_len):
        """Return the number of trailing video tokens in the alpha half, i.e. ``ceil(seq_len / 2)``."""
        return seq_len - seq_len // 2

    def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
        """Add scaled LoRA deltas in place to the alpha-token rows of ``query``, ``key`` and ``value``.

        The LoRA branches run only on the alpha tokens. Any other output would be
        discarded, and the layers are position-wise, so this matches evaluating them
        on the full sequence and then slicing.
        """
        alpha_len = self._alpha_length(seq_len)
        if alpha_len == 0:
            # No video tokens. A `-0:` slice would otherwise select the whole sequence.
            return query, key, value
        alpha_states = hidden_states[:, -alpha_len:, :]
        for target, lora in ((query, self.to_q_lora), (key, self.to_k_lora), (value, self.to_v_lora)):
            delta = lora(alpha_states).to(target.device)
            target[:, -alpha_len:, :] += delta * scaling
        return query, key, value

    def _apply_rotary_embedding(self, query, key, image_rotary_emb, seq_len, text_seq_length, attn):
        """Apply ``image_rotary_emb`` separately to the RGB and alpha halves of ``query`` (and ``key``).

        ``key`` is skipped when ``attn.is_cross_attention`` is true.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        split = text_seq_length + seq_len // 2
        halves = (slice(text_seq_length, split), slice(split, None))
        targets = (query,) if attn.is_cross_attention else (query, key)
        for tensor in targets:
            for half in halves:
                tensor[:, :, half] = apply_rotary_emb(tensor[:, :, half], image_rotary_emb)
        return query, key

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run joint text/video attention with RGBA LoRA.

        Returns:
            ``(hidden_states, encoder_hidden_states)``: the video part and the text part
            of the output, in that order.

        Note:
            ``attention_mask`` is accepted for interface compatibility but ignored.
            ``self.attention_mask`` is always used.
        """
        text_seq_length = encoder_hidden_states.size(1)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        batch_size, sequence_length, _ = hidden_states.shape
        seq_len = sequence_length - text_seq_length
        scaling = self.lora_alpha / self.lora_rank

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query, key, value = self._apply_lora(hidden_states, seq_len, query, key, value, scaling)

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query, key = self._apply_rotary_embedding(query, key, image_rotary_emb, seq_len, text_seq_length, attn)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=self.attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Output projection, plus the LoRA delta on the alpha tokens only.
        projected = attn.to_out[0](hidden_states)
        alpha_len = self._alpha_length(seq_len)
        if alpha_len:
            out_delta = self.to_out_lora(hidden_states[:, -alpha_len:, :]).to(hidden_states.device)
            projected[:, -alpha_len:, :] += out_delta * scaling

        # to_out[1] is the dropout stage of the output projection.
        hidden_states = attn.to_out[1](projected)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
def _load_lora_sequential_weights(lora_layer, weights, prefix):
    """Load ``{prefix}.lora_A.weight`` and ``{prefix}.lora_B.weight`` into the two Linear layers of ``lora_layer``.

    Raises:
        KeyError: If a required key is missing from ``weights``. The message names the key.
    """
    for linear, part in zip(lora_layer, ("lora_A", "lora_B")):
        key = f"{prefix}.{part}.weight"
        if key not in weights:
            raise KeyError(f"RGBA checkpoint is missing LoRA weight '{key}'")
        linear.load_state_dict({"weight": weights[key]})


def _build_rgba_forward(model, aux_emb):
    """Build a replacement ``forward`` for ``model`` that adds ``aux_emb`` to the alpha video tokens.

    The returned function closes over ``model`` and ``aux_emb``. It is meant to be
    assigned as a plain attribute (``model.forward = ...``), not bound as a method.
    """
    from torch.utils.checkpoint import checkpoint

    def _positional_call(module):
        # torch.utils.checkpoint passes its inputs to the callable positionally.
        def run(*inputs):
            return module(*inputs)

        return run

    def forward(
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        # Check for `scale` before it is popped below; checking afterwards always saw None.
        scale_requested = attention_kwargs is not None and attention_kwargs.get("scale", None) is not None
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0
        if USE_PEFT_BACKEND:
            # Weight the PEFT LoRA layers by `lora_scale`.
            scale_lora_layers(model, lora_scale)
        elif scale_requested:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding.
        # time_proj returns float32; cast to the activation dtype before time_embedding.
        t_emb = model.time_proj(timestep)
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = model.time_embedding(t_emb, timestep_cond)
        if model.ofs_embedding is not None:
            ofs_emb = model.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
            ofs_emb = model.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding. The output is [text | video] along dim 1.
        hidden_states = model.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = model.embedding_dropout(hidden_states)
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]
        # Add the learned domain embedding to the second (alpha) half of the video tokens.
        alpha_start = hidden_states.size(1) // 2
        hidden_states[:, alpha_start:, :] += aux_emb.expand(batch_size, -1, -1).to(
            hidden_states.device, dtype=hidden_states.dtype
        )

        # 3. Transformer blocks.
        for block in model.transformer_blocks:
            if torch.is_grad_enabled() and model.gradient_checkpointing:
                # use_reentrant=False is always available: the RGBA processor requires torch >= 2.0.
                hidden_states, encoder_hidden_states = checkpoint(
                    _positional_call(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    use_reentrant=False,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

        if not model.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = model.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = model.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block.
        hidden_states = model.norm_out(hidden_states, temb=emb)
        hidden_states = model.proj_out(hidden_states)

        # 5. Unpatchify.
        p = model.config.patch_size
        p_t = model.config.patch_size_t
        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if USE_PEFT_BACKEND:
            # Remove `lora_scale` from each PEFT layer.
            unscale_lora_layers(model, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    return forward


def prepare_for_rgba_inference(
    model, rgba_weights_path: str, device: torch.device, dtype: torch.dtype,
    lora_rank: int = 128, lora_alpha: float = 1.0, text_length: int = 226, seq_length: int = 35100
):
    """Patch ``model`` in place for joint RGB + alpha inference.

    Steps:
        1. Load the safetensors checkpoint at ``rgba_weights_path``.
        2. Build a single attention mask with ``create_attention_mask(text_length,
           seq_length, device, dtype)``, shared by all processors.
        3. Replace every attention processor with an ``RGBALoRACogVideoXAttnProcessor``.
           Its LoRA layers are filled from the keys
           ``transformer.transformer_blocks.<i>.attn1.{to_q,to_k,to_v,to_out.0}.lora_{A,B}.weight``.
        4. Replace ``model.forward`` with a forward pass that adds the checkpoint's
           ``domain_emb`` to the alpha half of the video tokens after patch embedding.

    Args:
        model: Transformer exposing ``attn_processors``, ``set_attn_processor`` and the
            attributes used by the replacement forward.
        rgba_weights_path: Path to the RGBA safetensors checkpoint.
        device: Device for the LoRA layers, the attention mask and ``domain_emb``.
        dtype: dtype for the LoRA layers and the attention mask.
        lora_rank: LoRA rank.
        lora_alpha: LoRA alpha.
        text_length: Text sequence length the attention mask is built for.
        seq_length: Video sequence length the attention mask is built for.

    Raises:
        KeyError: If the checkpoint lacks ``domain_emb`` or any expected LoRA weight.
    """
    rgba_weights = load_file(rgba_weights_path)
    # Moved to the device once; forward still casts it to the activations' device/dtype.
    aux_emb = rgba_weights["domain_emb"].to(device=device)
    attention_mask = create_attention_mask(text_length, seq_length, device, dtype)

    attn_procs = {}
    for name in model.attn_processors:
        processor = RGBALoRACogVideoXAttnProcessor(
            device=device, dtype=dtype, attention_mask=attention_mask,
            lora_rank=lora_rank, lora_alpha=lora_alpha,
        )
        # Presumably names look like "transformer_blocks.<i>.attn1.processor"; the second
        # dotted component is used as the block index.
        block_index = name.split(".")[1]
        base_prefix = f"transformer.transformer_blocks.{block_index}.attn1"
        for lora_layer, projection in (
            (processor.to_q_lora, "to_q"),
            (processor.to_k_lora, "to_k"),
            (processor.to_v_lora, "to_v"),
            (processor.to_out_lora, "to_out.0"),
        ):
            _load_lora_sequential_weights(lora_layer, rgba_weights, f"{base_prefix}.{projection}")
        attn_procs[name] = processor
    model.set_attn_processor(attn_procs)

    model.forward = _build_rgba_forward(model, aux_emb)