diff --git a/examples/wan2.2/predict_s2v.py b/examples/wan2.2/predict_s2v.py index 826ef41..2c526f2 100644 --- a/examples/wan2.2/predict_s2v.py +++ b/examples/wan2.2/predict_s2v.py @@ -345,7 +345,8 @@ if lora_path is not None: pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) - pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") + if transformer_2 is not None: + pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") def save_results(): if not os.path.exists(save_path): diff --git a/scripts/cogvideox_fun/train.py b/scripts/cogvideox_fun/train.py index 7401b61..1ef9969 100755 --- a/scripts/cogvideox_fun/train.py +++ b/scripts/cogvideox_fun/train.py @@ -1229,9 +1229,9 @@ def collate_fn(examples): ema_transformer3d.to(accelerator.device) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/cogvideox_fun/train_control.py b/scripts/cogvideox_fun/train_control.py index a446138..3ebdb18 100755 --- a/scripts/cogvideox_fun/train_control.py +++ b/scripts/cogvideox_fun/train_control.py @@ -1164,9 +1164,9 @@ def collate_fn(examples): ema_transformer3d.to(accelerator.device) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/cogvideox_fun/train_lora.py b/scripts/cogvideox_fun/train_lora.py index 7a43b22..e00cffc 100755 --- a/scripts/cogvideox_fun/train_lora.py +++ b/scripts/cogvideox_fun/train_lora.py @@ -1164,10 +1164,10 @@ def collate_fn(examples): ) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/fantasytalking/train.py b/scripts/fantasytalking/train.py index 728a655..b73effe 100644 --- a/scripts/fantasytalking/train.py +++ b/scripts/fantasytalking/train.py @@ -1357,7 +1357,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) audio_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=torch.float32) diff --git a/scripts/flux/train.py b/scripts/flux/train.py index 3b0be2c..02bac96 100644 --- a/scripts/flux/train.py +++ b/scripts/flux/train.py @@ -1348,8 +1348,8 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") - text_encoder_2.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + text_encoder_2.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/flux/train_lora.py b/scripts/flux/train_lora.py index babf2d4..d447c37 100644 --- a/scripts/flux/train_lora.py +++ b/scripts/flux/train_lora.py @@ -1280,11 +1280,11 @@ def _create_special_list(length): # text_encoder_2 = shard_fn(text_encoder_2) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) - text_encoder_2.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + text_encoder_2.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/qwenimage/train.py b/scripts/qwenimage/train.py index d9e97f5..b98bbac 100644 --- a/scripts/qwenimage/train.py +++ b/scripts/qwenimage/train.py @@ -1215,7 +1215,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/qwenimage/train_edit.py b/scripts/qwenimage/train_edit.py index b1d30ec..4f72e58 100644 --- a/scripts/qwenimage/train_edit.py +++ b/scripts/qwenimage/train_edit.py @@ -1260,7 +1260,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/qwenimage/train_edit_lora.py b/scripts/qwenimage/train_edit_lora.py index 100e0ca..065128b 100644 --- a/scripts/qwenimage/train_edit_lora.py +++ b/scripts/qwenimage/train_edit_lora.py @@ -1209,10 +1209,10 @@ def _create_special_list(length): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/qwenimage/train_lora.py b/scripts/qwenimage/train_lora.py index b3dfc5f..d9e58e2 100644 --- a/scripts/qwenimage/train_lora.py +++ b/scripts/qwenimage/train_lora.py @@ -1157,10 +1157,10 @@ def _create_special_list(length): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.1/train.py b/scripts/wan2.1/train.py index 2d40717..7316fe4 100755 --- a/scripts/wan2.1/train.py +++ b/scripts/wan2.1/train.py @@ -1405,7 +1405,7 @@ def collate_fn(examples): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if args.train_mode != "normal": clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) diff --git a/scripts/wan2.1/train_lora.py b/scripts/wan2.1/train_lora.py index fa4092a..b97771d 100755 --- a/scripts/wan2.1/train_lora.py +++ b/scripts/wan2.1/train_lora.py @@ -1337,12 +1337,12 @@ def collate_fn(examples): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - 
vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if args.train_mode != "normal": - clip_image_encoder.to(accelerator.device, dtype=weight_dtype) + clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.1_fun/train.py b/scripts/wan2.1_fun/train.py index 8fd8e4a..b6936b3 100755 --- a/scripts/wan2.1_fun/train.py +++ b/scripts/wan2.1_fun/train.py @@ -1402,7 +1402,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if args.train_mode != "normal": clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) diff --git a/scripts/wan2.1_fun/train_control.py b/scripts/wan2.1_fun/train_control.py index 8d627fc..688a48d 100755 --- a/scripts/wan2.1_fun/train_control.py +++ b/scripts/wan2.1_fun/train_control.py @@ -1408,7 +1408,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + 
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. diff --git a/scripts/wan2.1_fun/train_control_lora.py b/scripts/wan2.1_fun/train_control_lora.py index 62b5ee4..ae5387b 100755 --- a/scripts/wan2.1_fun/train_control_lora.py +++ b/scripts/wan2.1_fun/train_control_lora.py @@ -1350,11 +1350,11 @@ def _create_special_list(length): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) - clip_image_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.1_fun/train_lora.py b/scripts/wan2.1_fun/train_lora.py index ddf8bbc..8f400de 100755 --- a/scripts/wan2.1_fun/train_lora.py +++ b/scripts/wan2.1_fun/train_lora.py @@ -1338,12 +1338,12 @@ def _create_special_list(length): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if args.train_mode != "normal": - clip_image_encoder.to(accelerator.device, dtype=weight_dtype) + clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.1_fun/train_reward_lora.py b/scripts/wan2.1_fun/train_reward_lora.py index 3c0b2dd..db66cde 100755 --- a/scripts/wan2.1_fun/train_reward_lora.py +++ b/scripts/wan2.1_fun/train_reward_lora.py @@ -1054,7 +1054,7 @@ def save_model_hook(models, weights, output_dir): vae.to(accelerator.device, dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device) - clip_image_encoder.to(accelerator.device, dtype=weight_dtype) + clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(prompt_list) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.1_vace/train.py b/scripts/wan2.1_vace/train.py index 8a2c2c2..6c45fe8 100644 --- a/scripts/wan2.1_vace/train.py +++ b/scripts/wan2.1_vace/train.py @@ -1404,7 +1404,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.2/train.py b/scripts/wan2.2/train.py index 932aeeb..2e84d71 100644 --- a/scripts/wan2.2/train.py +++ b/scripts/wan2.2/train.py @@ -1439,7 +1439,7 @@ def collate_fn(examples): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.2/train_animate.py b/scripts/wan2.2/train_animate.py index e6d3b29..6b988f7 100644 --- a/scripts/wan2.2/train_animate.py +++ b/scripts/wan2.2/train_animate.py @@ -679,38 +679,6 @@ def parse_args(): 'The initial gradient is relative to the multiple of the max_grad_norm. 
' ), ) - parser.add_argument( - "--train_mode", - type=str, - default="control", - help=( - 'The format of training data. Support `"control"`' - ' (default), `"control_ref"`, `"control_camera_ref"`.' - ), - ) - parser.add_argument( - "--control_ref_image", - type=str, - default="first_frame", - help=( - 'The format of training data. Support `"first_frame"`' - ' (default), `"random"`.' - ), - ) - parser.add_argument( - "--add_full_ref_image_in_self_attention", - action="store_true", - help=( - 'Whether enable add full ref image in self attention.' - ), - ) - parser.add_argument( - "--add_inpaint_info", - action="store_true", - help=( - 'Whether enable add inpaint info in self attention.' - ), - ) parser.add_argument( "--weighting_scheme", type=str, @@ -1464,7 +1432,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
diff --git a/scripts/wan2.2/train_animate.sh b/scripts/wan2.2/train_animate.sh index 153dbf8..ac27520 100644 --- a/scripts/wan2.2/train_animate.sh +++ b/scripts/wan2.2/train_animate.sh @@ -11,9 +11,8 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.2/train_animate.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ --train_data_meta=$DATASET_META_NAME \ - --image_sample_size=1024 \ - --video_sample_size=256 \ - --token_sample_size=512 \ + --video_sample_size=640 \ + --token_sample_size=640 \ --video_sample_stride=2 \ --video_sample_n_frames=81 \ --train_batch_size=1 \ diff --git a/scripts/wan2.2/train_animate_lora.py b/scripts/wan2.2/train_animate_lora.py new file mode 100644 index 0000000..cbd7976 --- /dev/null +++ b/scripts/wan2.2/train_animate_lora.py @@ -0,0 +1,1941 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import gc +import logging +import math +import os +import pickle +import random +import shutil +import sys + +import accelerate +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +from accelerate import Accelerator, FullyShardedDataParallelPlugin +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import (EMAModel, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3) +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from torch.utils.data import RandomSampler +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.utils import ContextManagers + +import datasets + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.data import (ImageVideoDataset, ImageVideoSampler, + VideoAnimateDataset, get_random_mask, + process_pose_file, process_pose_params) +from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, + ASPECT_RATIO_RANDOM_CROP_512, + 
ASPECT_RATIO_RANDOM_CROP_PROB, + AspectRatioBatchImageVideoSampler, + RandomSampler, get_closest_ratio) +from videox_fun.data.dataset_image_video import (ImageVideoDataset, + ImageVideoSampler, + get_random_mask) +from videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8, + CLIPModel, Wan2_2Transformer3DModel, + Wan2_2Transformer3DModel_Animate, + WanT5EncoderModel) +from videox_fun.pipeline import (Wan2_2FunControlPipeline, Wan2_2I2VPipeline, + Wan2_2Pipeline) +from videox_fun.utils.discrete_sampler import DiscreteSampling +from videox_fun.utils.lora_utils import (create_network, merge_lora, + unmerge_lora) +from videox_fun.utils.utils import (get_image_to_video_latent, + get_video_to_video_latent, + save_videos_grid) + +if is_wandb_available(): + import wandb + + +def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def linear_decay(initial_value, final_value, total_steps, current_step): + if current_step >= total_steps: + return final_value + current_step = max(0, current_step) + step_size = (final_value - initial_value) / total_steps + current_value = initial_value + step_size * current_step + return current_value + +def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None): + u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator) + t = 1 / (1 + torch.exp(-u)) * (high - low) + low + return torch.clip(t.to(torch.int32), low, high - 1) + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + 
align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + +def get_i2v_mask(lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2) + return msk + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, text_encoder, tokenizer, transformer3d, network, args, config, accelerator, weight_dtype, global_step): + try: + logger.info("Running validation... 
") + + if args.boundary_type == "full": + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + + transformer3d_val = Wan2_2Transformer3DModel_Animate.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + + transformer3d_2_val = None + else: + if args.boundary_type == "low": + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + + transformer3d_val = Wan2_2Transformer3DModel_Animate.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + + sub_path = config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer') + transformer3d_2_val = Wan2_2Transformer3DModel_Animate.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + else: + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + + transformer3d_val = Wan2_2Transformer3DModel_Animate.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + + sub_path = config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer') + transformer3d_2_val = Wan2_2Transformer3DModel_Animate.from_pretrained( + 
os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + transformer3d_2_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + + scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) + ) + pipeline = Wan2_2FunControlPipeline( + vae=accelerator.unwrap_model(vae).to(weight_dtype), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + transformer=transformer3d_val, + transformer_2=transformer3d_2_val, + scheduler=scheduler, + ) + pipeline = pipeline.to(accelerator.device) + + pipeline = merge_lora( + pipeline, None, 1, accelerator.device, state_dict=accelerator.unwrap_model(network).state_dict(), transformer_only=True + ) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.no_grad(): + with torch.autocast("cuda", dtype=weight_dtype): + video_length = int(args.video_sample_n_frames // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1 + input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(args.validation_paths[i], video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size]) + sample = pipeline( + args.validation_prompts[i], + num_frames = video_length, + negative_prompt = "bad detailed", + height = args.video_sample_size, + width = args.video_sample_size, + generator = generator, + + control_video = input_video, + ).videos + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif")) + + del pipeline + del 
def parse_args():
    """Parse command-line arguments for the Wan2.2 Animate LoRA training script.

    Returns:
        argparse.Namespace: parsed arguments. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when set (distributed launchers export
        it), and ``non_ema_revision`` falls back to ``revision`` when not given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    # Dataset location: a folder of videos plus a csv of metadata.
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. "
        ),
    )
    parser.add_argument(
        "--train_data_meta",
        type=str,
        default=None,
        help=(
            "A csv containing the training data. "
        ),
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--validation_prompts",
        type=str,
        default=None,
        nargs="+",
        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
    )
    parser.add_argument(
        "--validation_paths",
        type=str,
        default=None,
        nargs="+",
        help=("A set of control videos evaluated every `--validation_epochs` and logged to `--report_to`."),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--random_flip",
        action="store_true",
        help="whether to randomly flip images horizontally",
    )
    parser.add_argument(
        "--use_came",
        action="store_true",
        help="whether to use came",
    )
    parser.add_argument(
        "--multi_stream",
        action="store_true",
        help="whether to use cuda multi-stream",
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--vae_mini_batch", type=int, default=32, help="mini batch size for vae."
    )
    parser.add_argument(
        "--num_train_epochs", type=int, default=100, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--non_ema_revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
            " remote repository specified with --pretrained_model_name_or_path."
        ),
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    # NOTE: fixed the upstream typo "prediciton_type" in the help text below.
    parser.add_argument(
        "--prediction_type",
        type=str,
        default=None,
        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=5,
        help="Run validation every X epochs.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=2000,
        help="Run validation every X steps.",
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="text2image-fine-tune",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    # LoRA hyper-parameters.
    parser.add_argument(
        "--rank",
        type=int,
        default=128,
        help=("The dimension of the LoRA update matrices."),
    )
    # NOTE: help text fixed — it used to be a copy-paste of --rank's description.
    parser.add_argument(
        "--network_alpha",
        type=int,
        default=64,
        help=("The alpha value used to scale the LoRA update matrices."),
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
    )
    parser.add_argument(
        "--snr_loss", action="store_true", help="Whether or not to use snr_loss."
    )
    parser.add_argument(
        "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling."
    )
    parser.add_argument(
        "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader."
    )
    parser.add_argument(
        "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets."
    )
    parser.add_argument(
        "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets."
    )
    parser.add_argument(
        "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets."
    )
    parser.add_argument(
        "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets."
    )
    parser.add_argument(
        "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.",
    )
    parser.add_argument(
        "--auto_tile_batch_size", action="store_true", help="Whether to auto tile batch size.",
    )
    parser.add_argument(
        "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss."
    )
    parser.add_argument(
        "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss."
    )
    parser.add_argument(
        "--train_sampling_steps",
        type=int,
        default=1000,
        help="Run train_sampling_steps.",
    )
    parser.add_argument(
        "--keep_all_node_same_token_length",
        action="store_true",
        help="Reference of the length token.",
    )
    # Sampling geometry for videos / images fed to the model.
    parser.add_argument(
        "--token_sample_size",
        type=int,
        default=512,
        help="Sample size of the token.",
    )
    parser.add_argument(
        "--video_sample_size",
        type=int,
        default=512,
        help="Sample size of the video.",
    )
    parser.add_argument(
        "--image_sample_size",
        type=int,
        default=512,
        help="Sample size of the image.",
    )
    parser.add_argument(
        "--fix_sample_size",
        nargs=2, type=int, default=None,
        help="Fix Sample size [height, width] when using bucket and collate_fn."
    )
    parser.add_argument(
        "--video_sample_stride",
        type=int,
        default=4,
        help="Sample stride of the video.",
    )
    parser.add_argument(
        "--video_sample_n_frames",
        type=int,
        default=17,
        help="Num frame of video.",
    )
    parser.add_argument(
        "--video_repeat",
        type=int,
        default=0,
        help="Num of repeat video.",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        default=None,
        help=(
            "The config of the model in training."
        ),
    )
    parser.add_argument(
        "--transformer_path",
        type=str,
        default=None,
        help=("If you want to load the weight from other transformers, input its path."),
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        help=("If you want to load the weight from other vaes, input its path."),
    )
    parser.add_argument("--save_state", action="store_true", help="Whether or not to save state.")

    parser.add_argument(
        '--tokenizer_max_length',
        type=int,
        default=512,
        help='Max length of tokenizer'
    )
    parser.add_argument(
        "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed."
    )
    parser.add_argument(
        "--use_fsdp", action="store_true", help="Whether or not to use fsdp."
    )
    parser.add_argument(
        "--low_vram", action="store_true", help="Whether enable low_vram mode."
    )
    parser.add_argument(
        "--boundary_type",
        type=str,
        default="low",
        help=(
            'The format of training data. Support `"low"` and `"high"`'
        ),
    )
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="none",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
        help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--lora_skip_name",
        type=str,
        default=None,
        help=("The module is not trained in loras. "),
    )

    args = parser.parse_args()
    # Distributed launchers export LOCAL_RANK; it wins over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # default to using the same revision for the non-ema model if not specified
    if args.non_ema_revision is None:
        args.non_ema_revision = args.revision

    return args
fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. 
Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. + noise_scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) + ) + + # Get Tokenizer + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. 
+ # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + ) + text_encoder = text_encoder.eval() + # Get Vae + Chosen_AutoencoderKL = { + "AutoencoderKLWan": AutoencoderKLWan, + "AutoencoderKLWan3_8": AutoencoderKLWan3_8 + }[config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] + vae = Chosen_AutoencoderKL.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), + ) + vae.eval() + # Get Clip Image Encoder + clip_image_encoder = CLIPModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), + ) + clip_image_encoder = clip_image_encoder.eval() + + # Get Transformer + if args.boundary_type == "low" or args.boundary_type == "full": + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + else: + sub_path = config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer') + transformer3d = Wan2_2Transformer3DModel_Animate.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + 
transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + clip_image_encoder.requires_grad_(False) + + # Lora will work with this... + network = create_network( + 1.0, + args.rank, + args.network_alpha, + text_encoder, + transformer3d, + neuron_dropout=None, + skip_name=args.lora_skip_name, + ) + network = network.to(weight_dtype) + network.apply_to(text_encoder, transformer3d, args.train_text_encoder and not args.training_with_video_token_length, True) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + if fsdp_stage != 0: + 
def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + + safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors") + network_state_dict = {} + for key in accelerate_state_dict: + if "network" in key: + network_state_dict[key.replace("network.", "")] = accelerate_state_dict[key].to(weight_dtype) + + save_file(network_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + elif zero_stage == 3: + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. 
Get loaded_number = {loaded_number}.") + else: + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(models[-1])) + if not args.use_deepspeed: + for _ in range(len(weights)): + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except: + raise ImportError( + "Please install came_pytorch to use CAME. 
You can do so by running `pip install came_pytorch`" + ) + + optimizer_cls = CAME + else: + optimizer_cls = torch.optim.AdamW + + logging.info("Add network parameters") + trainable_params = list(filter(lambda p: p.requires_grad, network.parameters())) + trainable_params_optim = network.prepare_optimizer_params(args.learning_rate / 2, args.learning_rate, args.learning_rate) + + if args.use_came: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + else: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the training dataset + sample_n_frames_bucket_interval = vae.config.temporal_compression_ratio + spatial_compression_ratio = vae.config.spatial_compression_ratio + + if args.fix_sample_size is not None and args.enable_bucket: + args.video_sample_size = max(max(args.fix_sample_size), args.video_sample_size) + args.image_sample_size = max(max(args.fix_sample_size), args.image_sample_size) + args.training_with_video_token_length = False + args.random_hw_adapt = False + + # Get the dataset + train_dataset = VideoAnimateDataset( + args.train_data_meta, args.train_data_dir, + video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, + video_repeat=args.video_repeat, + enable_bucket=args.enable_bucket, + ) + + def worker_init_fn(_seed): + _seed = _seed * 256 + def _worker_init_fn(worker_id): + print(f"worker_init_fn with {_seed + worker_id}") + np.random.seed(_seed + worker_id) + random.seed(_seed + worker_id) + return _worker_init_fn + + if args.enable_bucket: + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + batch_sampler_generator = 
torch.Generator().manual_seed(args.seed) + batch_sampler = AspectRatioBatchImageVideoSampler( + sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, + batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True, + aspect_ratios=aspect_ratio_sample_size, + ) + + def collate_fn(examples): + def get_length_to_frame_num(token_length): + if args.video_sample_size > 256: + sample_sizes = list(range(256, args.video_sample_size + 1, 128)) + + if sample_sizes[-1] != args.video_sample_size: + sample_sizes.append(args.video_sample_size) + else: + sample_sizes = [args.video_sample_size] + + length_to_frame_num = { + sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval for sample_size in sample_sizes + } + + return length_to_frame_num + + def get_random_downsample_ratio(sample_size, image_ratio=[], + all_choices=False, rng=None): + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + first_element = 0.90 + remaining_sum = 1.0 - first_element + other_elements_value = remaining_sum / (length - 1) + special_list = [first_element] + [other_elements_value] * (length - 1) + return special_list + + if sample_size >= 1536: + number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio + elif sample_size >= 1024: + number_list = [1, 1.25, 1.5, 2] + image_ratio + elif sample_size >= 768: + number_list = [1, 1.25, 1.5] + image_ratio + elif sample_size >= 512: + number_list = [1] + image_ratio + else: + number_list = [1] + + if all_choices: + return number_list + + number_list_prob = np.array(_create_special_list(len(number_list))) + if rng is None: + return np.random.choice(number_list, p = number_list_prob) + else: + return rng.choice(number_list, p = number_list_prob) + + # Get token length + target_token_length = args.video_sample_n_frames * args.token_sample_size * 
args.token_sample_size + length_to_frame_num = get_length_to_frame_num(target_token_length) + + # Create new output + new_examples = {} + new_examples["target_token_length"] = target_token_length + new_examples["pixel_values"] = [] + new_examples["text"] = [] + # Used in Control pixel values + new_examples["control_pixel_values"] = [] + # Used in Face pixel values + new_examples["face_pixel_values"] = [] + # Used in Ref pixel values + new_examples["ref_pixel_values"] = [] + # Used in Background pixel values + new_examples["background_pixel_values"] = [] + # Used in Mask pixel values + new_examples["mask"] = [] + new_examples["clip_pixel_values"] = [] + + # Get downsample ratio in image and videos + pixel_value = examples[0]["pixel_values"] + data_type = examples[0]["data_type"] + f, h, w, c = np.shape(pixel_value) + + if args.random_hw_adapt: + if args.training_with_video_token_length: + local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples])) + # The video will be resized to a lower resolution than its own. 
+ choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25] + if len(choice_list) == 0: + choice_list = list(length_to_frame_num.keys()) + local_video_sample_size = np.random.choice(choice_list) + batch_video_length = length_to_frame_num[local_video_sample_size] + random_downsample_ratio = args.video_sample_size / local_video_sample_size + else: + random_downsample_ratio = get_random_downsample_ratio(args.video_sample_size) + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + else: + random_downsample_ratio = 1 + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + if args.fix_sample_size is not None: + fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size] + elif args.random_ratio_crop: + if rng is None: + random_sample_size = aspect_ratio_random_crop_sample_size[ + np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + else: + random_sample_size = aspect_ratio_random_crop_sample_size[ + rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + random_sample_size = [int(x / 16) * 16 for x in random_sample_size] + else: + closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size) + closest_size = [int(x / 16) * 16 for x in closest_size] + + min_example_length = min( + [example["pixel_values"].shape[0] for example in examples] + ) + batch_video_length = int(min(batch_video_length, min_example_length)) + + # Magvae needs the number of frames to be 4n + 1. 
+ batch_video_length = (batch_video_length - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + + if batch_video_length <= 0: + batch_video_length = 1 + + for example in examples: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + control_pixel_values = torch.from_numpy(example["control_pixel_values"]).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + + face_pixel_values = torch.from_numpy(example["face_pixel_values"]).permute(0, 3, 1, 2).contiguous() + face_pixel_values = face_pixel_values / 255. + + ref_pixel_values = torch.from_numpy(example["ref_pixel_values"]).unsqueeze(0).permute(0, 3, 1, 2).contiguous() + ref_pixel_values = ref_pixel_values / 255. + + background_pixel_values = torch.from_numpy(example["background_pixel_values"]).permute(0, 3, 1, 2).contiguous() + background_pixel_values = background_pixel_values / 255. + + mask = torch.from_numpy(example["mask"]).permute(0, 3, 1, 2).contiguous() + mask = mask / 255. 
+ + if args.fix_sample_size is not None: + # Get adapt hw for resize + fix_sample_size = list(map(lambda x: int(x), fix_sample_size)) + transform = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + transform_no_normalize = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + ]) + elif args.random_ratio_crop: + # Get adapt hw for resize + b, c, h, w = pixel_values.size() + th, tw = random_sample_size + if th / tw > h / w: + nh = int(th) + nw = int(w / h * nh) + else: + nw = int(tw) + nh = int(h / w * nw) + + transform = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + transform_no_normalize = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + ]) + else: + # Get adapt hw for resize + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / h > closest_size[1] / w: + resize_size = closest_size[0], int(w * closest_size[0] / h) + else: + resize_size = int(h * closest_size[1] / w), closest_size[1] + + transform = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + transform_no_normalize = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + ]) + + transform_512 = transforms.Compose([ + 
transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop((512, 512)), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + new_examples["pixel_values"].append(transform(pixel_values)[:batch_video_length]) + new_examples["control_pixel_values"].append(transform(control_pixel_values)) + new_examples["face_pixel_values"].append(transform_512(face_pixel_values)[:batch_video_length]) + new_examples["ref_pixel_values"].append(transform(ref_pixel_values)) + new_examples["background_pixel_values"].append(transform(background_pixel_values)[:batch_video_length]) + new_examples["mask"].append(transform_no_normalize(mask)[:batch_video_length]) + new_examples["clip_pixel_values"].append(torch.from_numpy(example["clip_pixel_values"])) + new_examples["text"].append(example["text"]) + + # Limit the number of frames to the same + new_examples["pixel_values"] = torch.stack([example for example in new_examples["pixel_values"]]) + new_examples["control_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["control_pixel_values"]]) + new_examples["face_pixel_values"] = torch.stack([example for example in new_examples["face_pixel_values"]]) + new_examples["ref_pixel_values"] = torch.stack([example for example in new_examples["ref_pixel_values"]]) + new_examples["background_pixel_values"] = torch.stack([example for example in new_examples["background_pixel_values"]]) + new_examples["clip_pixel_values"] = torch.stack([example for example in new_examples["clip_pixel_values"]]) + new_examples["mask"] = torch.stack([example for example in new_examples["mask"]]) + + # Encode prompts when enable_text_encoder_in_dataloader=True + if args.enable_text_encoder_in_dataloader: + prompt_ids = tokenizer( + new_examples['text'], + max_length=args.tokenizer_max_length, + padding="max_length", + add_special_tokens=True, + truncation=True, + return_tensors="pt" + ) + 
encoder_hidden_states = text_encoder( + prompt_ids.input_ids + )[0] + new_examples['encoder_attention_mask'] = prompt_ids.attention_mask + new_examples['encoder_hidden_states'] = encoder_hidden_states + + return new_examples + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + else: + # DataLoaders creation: + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. 
+ if fsdp_stage != 0: + transformer3d.network = network + transformer3d = transformer3d.to(weight_dtype) + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) + + if zero_stage == 3: + from functools import partial + + from videox_fun.dist import set_multi_gpus_devices, shard_model + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype) + transformer3d = shard_fn(transformer3d) + + if fsdp_stage != 0: + from functools import partial + + from videox_fun.dist import set_multi_gpus_devices, shard_model + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype) + text_encoder = shard_fn(text_encoder) + + # Move text_encode and vae to gpu and cast to weight_dtype + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + transformer3d.to(accelerator.device, dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. 
+ if accelerator.is_main_process: + tracker_config = dict(vars(args)) + keys_to_pop = [k for k, v in tracker_config.items() if isinstance(v, list)] + for k in keys_to_pop: + tracker_config.pop(k) + print(f"Removed tracker_config['{k}']") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.") + + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + else: + initial_global_step = 0 + + # function for saving/removing + def save_model(ckpt_file, unwrapped_nw): + os.makedirs(args.output_dir, exist_ok=True) + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + unwrapped_nw.save_weights(ckpt_file, weight_dtype, None) + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup inpaint vae computation + vae_stream_1 = torch.cuda.Stream() + vae_stream_2 = torch.cuda.Stream() + else: + vae_stream_1 = None + vae_stream_2 = None + + # Calculate the index we need + boundary = config['transformer_additional_kwargs'].get('boundary', 0.900) + split_timesteps = args.train_sampling_steps * boundary + differences = torch.abs(noise_scheduler.timesteps - split_timesteps) + closest_index = torch.argmin(differences).item() + print(f"The boundary is {boundary} and the boundary_type is {args.boundary_type}. 
The closest_index we calculate is {closest_index}") + if args.boundary_type == "high": + start_num_idx = 0 + train_sampling_steps = closest_index + elif args.boundary_type == "low": + start_num_idx = closest_index + train_sampling_steps = args.train_sampling_steps - closest_index + else: + start_num_idx = 0 + train_sampling_steps = args.train_sampling_steps + idx_sampling = DiscreteSampling(train_sampling_steps, start_num_idx=start_num_idx, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + # Data batch sanity check + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + control_pixel_values = batch["control_pixel_values"].cpu() + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + control_pixel_values = rearrange(control_pixel_values, "b f c h w -> b c f h w") + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, control_pixel_value, text) in enumerate(zip(pixel_values, control_pixel_values, texts)): + pixel_value = pixel_value[None, ...] + control_pixel_value = control_pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True) + save_videos_grid(control_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_control.gif", rescale=True) + + face_pixel_values = batch["face_pixel_values"].cpu() + face_pixel_values = rearrange(face_pixel_values, "b f c h w -> b c f h w") + for idx, (face_pixel_value, text) in enumerate(zip(face_pixel_values, texts)): + face_pixel_value = face_pixel_value[None, ...] 
+ gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(face_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_face.gif", rescale=True) + + ref_pixel_values = batch["ref_pixel_values"].cpu() + ref_pixel_values = rearrange(ref_pixel_values, "b f c h w -> b c f h w") + for idx, (ref_pixel_value, text) in enumerate(zip(ref_pixel_values, texts)): + ref_pixel_value = ref_pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(ref_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_ref.gif", rescale=True) + + background_pixel_values = batch["background_pixel_values"].cpu() + background_pixel_values = rearrange(background_pixel_values, "b f c h w -> b c f h w") + for idx, (bg_pixel_value, text) in enumerate(zip(background_pixel_values, texts)): + bg_pixel_value = bg_pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(bg_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_bg.gif", rescale=True) + + clip_pixel_values, mask, texts = batch['clip_pixel_values'].cpu(), batch['mask'].cpu(), batch['text'] + mask = rearrange(mask, "b f c h w -> b c f h w") + for idx, (clip_pixel_value, pixel_value, text) in enumerate(zip(clip_pixel_values, mask, texts)): + pixel_value = pixel_value[None, ...] 
+ Image.fromarray(np.uint8(clip_pixel_value)).save(f"{args.output_dir}/sanity_check/clip_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.png") + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/mask_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.gif", rescale=True) + + with accelerator.accumulate(transformer3d): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + control_pixel_values = batch["control_pixel_values"].to(weight_dtype) + face_pixel_values = batch["face_pixel_values"].to(weight_dtype) + ref_pixel_values = batch["ref_pixel_values"].to(weight_dtype) + background_pixel_values = batch["background_pixel_values"].to(weight_dtype) + mask = batch["mask"].to(weight_dtype) + clip_pixel_values = batch["clip_pixel_values"].to(weight_dtype) + + # Increase the batch size when the length of the latent sequence of the current sample is small + if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1)) + control_pixel_values = torch.tile(control_pixel_values, (4, 1, 1, 1, 1)) + face_pixel_values = torch.tile(face_pixel_values, (4, 1, 1, 1, 1)) + ref_pixel_values = torch.tile(ref_pixel_values, (4, 1, 1, 1, 1)) + background_pixel_values = torch.tile(background_pixel_values, (4, 1, 1, 1, 1)) + mask = torch.tile(mask, (4, 1, 1, 1, 1)) + + clip_pixel_values = torch.tile(clip_pixel_values, (4, 1, 1, 1)) + + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1)) + else: + batch['text'] = batch['text'] * 4 + + elif args.video_sample_n_frames * args.token_sample_size * 
args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1)) + control_pixel_values = torch.tile(control_pixel_values, (2, 1, 1, 1, 1)) + face_pixel_values = torch.tile(face_pixel_values, (2, 1, 1, 1, 1)) + ref_pixel_values = torch.tile(ref_pixel_values, (2, 1, 1, 1, 1)) + background_pixel_values = torch.tile(background_pixel_values, (2, 1, 1, 1, 1)) + mask = torch.tile(mask, (2, 1, 1, 1, 1)) + + clip_pixel_values = torch.tile(clip_pixel_values, (2, 1, 1, 1)) + + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1)) + else: + batch['text'] = batch['text'] * 2 + + if args.random_frame_crop: + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + last_element = 0.90 + remaining_sum = 1.0 - last_element + other_elements_value = remaining_sum / (length - 1) + special_list = [other_elements_value] * (length - 1) + [last_element] + return special_list + select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))] + select_frames_prob = np.array(_create_special_list(len(select_frames))) + + if len(select_frames) != 0: + if rng is None: + temp_n_frames = np.random.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = rng.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = 1 + + # Magvae needs the number of frames to be 4n + 1. 
+ temp_n_frames = (temp_n_frames - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :temp_n_frames, :, :] + control_pixel_values = control_pixel_values[:, :temp_n_frames, :, :] + face_pixel_values = face_pixel_values[:, :temp_n_frames, :, :] + background_pixel_values = background_pixel_values[:, :temp_n_frames, :, :] + mask = mask[:, :temp_n_frames, :, :] + + # Keep all nodes at the same token length to accelerate the training when resolution grows. + if args.keep_all_node_same_token_length: + if args.token_sample_size > 256: + numbers_list = list(range(256, args.token_sample_size + 1, 128)) + + if numbers_list[-1] != args.token_sample_size: + numbers_list.append(args.token_sample_size) + else: + numbers_list = [256] + numbers_list = [_number * _number * args.video_sample_n_frames for _number in numbers_list] + + actual_token_length = index_rng.choice(numbers_list) + actual_video_length = (min( + actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames + ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + actual_video_length = int(max(actual_video_length, 1)) + + # Magvae needs the number of frames to be 4n + 1.
+ actual_video_length = (actual_video_length - 1) // sample_n_frames_bucket_interval + 1 + + pixel_values = pixel_values[:, :actual_video_length, :, :] + control_pixel_values = control_pixel_values[:, :actual_video_length, :, :] + face_pixel_values = face_pixel_values[:, :actual_video_length, :, :] + background_pixel_values = background_pixel_values[:, :actual_video_length, :, :] + mask = mask[:, :actual_video_length, :, :] + + if args.low_vram: + torch.cuda.empty_cache() + vae.to(accelerator.device) + clip_image_encoder.to(accelerator.device) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to("cpu") + + with torch.no_grad(): + # This way is quicker when batch grows up + def _batch_encode_vae(pixel_values): + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + bs = args.vae_mini_batch + new_pixel_values = [] + for i in range(0, pixel_values.shape[0], bs): + pixel_values_bs = pixel_values[i : i + bs] + pixel_values_bs = vae.encode(pixel_values_bs)[0] + pixel_values_bs = pixel_values_bs.sample() + new_pixel_values.append(pixel_values_bs) + return torch.cat(new_pixel_values, dim = 0) + if vae_stream_1 is not None: + vae_stream_1.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(vae_stream_1): + latents = _batch_encode_vae(pixel_values) + else: + latents = _batch_encode_vae(pixel_values) + + control_latents = _batch_encode_vae(control_pixel_values) + ref_latents = _batch_encode_vae(ref_pixel_values) + + if rng is None: + refert_num = np.random.choice([0, 1], p = [0.75, 0.25]) + else: + refert_num = rng.choice([0, 1], p = [0.75, 0.25]) + + background_pixel_values[:, :refert_num, :] = pixel_values[:, :refert_num, :] + mask[:, :refert_num, :] = torch.zeros_like(mask[:, :refert_num, :]) + for bs_index in range(background_pixel_values.size()[0]): + if rng is None: + zero_init_background_pixel_values_conv_in = np.random.choice([0, 1], p = [0.90, 0.10]) + else: + zero_init_background_pixel_values_conv_in = rng.choice([0, 1], p = 
[0.90, 0.10]) + + if zero_init_background_pixel_values_conv_in: + background_pixel_values[bs_index] = background_pixel_values[bs_index] * 0 + mask[bs_index] = torch.ones_like(mask[bs_index]) + background_latents = _batch_encode_vae(background_pixel_values) + + mask = rearrange(mask, "b f c h w -> b c f h w") + mask = torch.concat( + [ + torch.repeat_interleave(mask[:, :, 0:1], repeats=4, dim=2), + mask[:, :, 1:] + ], dim=2 + ) + mask = mask.view(mask.shape[0], mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]) + mask = mask.transpose(1, 2) + mask = resize_mask(1 - mask, latents) + y_reft = torch.concat([mask, background_latents], dim=1).to(device=accelerator.device, dtype=weight_dtype) + + mask_ref = get_i2v_mask(1, latents.size(-2), latents.size(-1), 1, device=accelerator.device) + y_ref = torch.concat([mask_ref, ref_latents], dim=1).to(device=accelerator.device, dtype=weight_dtype) + y = torch.concat([y_ref, y_reft], dim=2) + + face_pixel_values = rearrange(face_pixel_values, "b c f h w -> b f c h w") + + clip_context = [] + for clip_pixel_value in clip_pixel_values: + clip_image = Image.fromarray(np.uint8(clip_pixel_value.float().cpu().numpy())) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(clip_image_encoder.device, weight_dtype) + _clip_context = clip_image_encoder([clip_image[:, None, :, :]]) + clip_context.append(_clip_context) + clip_context = torch.cat(clip_context) + + latents = torch.cat([ref_latents, latents], dim=2) + + # wait for latents = vae.encode(pixel_values) to complete + if vae_stream_1 is not None: + torch.cuda.current_stream().wait_stream(vae_stream_1) + + if args.low_vram: + vae.to('cpu') + clip_image_encoder.to('cpu') + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device) + else: + with torch.no_grad(): + prompt_ids = tokenizer( + 
batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids + prompt_attention_mask = prompt_ids.attention_mask + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + if args.low_vram and not args.enable_text_encoder_in_dataloader: + text_encoder.to('cpu') + torch.cuda.empty_cache() + + bsz, channel, num_frames, height, width = latents.size() + noise = torch.randn( + latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype + ) + + if not args.uniform_sampling: + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + else: + # Sample a random timestep for each image + # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + # timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + indices = idx_sampling(bsz, generator=torch_rng, device=latents.device) + indices = indices.long().cpu() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = 
sigma.unsqueeze(-1) + return sigma + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + # Add noise + target = noise - latents + + target_shape = (vae.latent_channels, num_frames, width, height) + seq_len = math.ceil( + (target_shape[2] * target_shape[3]) / + (accelerator.unwrap_model(transformer3d).config.patch_size[1] * accelerator.unwrap_model(transformer3d).config.patch_size[2]) * + target_shape[1] + ) + + # Predict the noise residual + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred = transformer3d( + x=noisy_latents, + context=prompt_embeds, + t=timesteps, + seq_len=seq_len, + y=y, + clip_fea=clip_context, + pose_latents=control_latents, + face_pixel_values=face_pixel_values, + ) + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + loss = custom_mse_loss(noise_pred.float(), target.float(), weighting.float()) + loss = loss.mean() + + if args.motion_sub_loss and noise_pred.size()[2] > 2: + gt_sub_noise = noise_pred[:, :, 1:].float() - noise_pred[:, :, :-1].float() + pre_sub_noise = target[:, :, 1:].float() - target[:, :, :-1].float() + sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean") + loss = loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio + + # Gather the losses across all processes for logging (if we use distributed 
training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(network)) + logger.info(f"Saved 
safetensor to {safetensor_save_path}") + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + if accelerator.is_main_process: + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + network, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + network, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + + # Create the pipeline using the trained modules and save it. 
+ accelerator.wait_for_everyone() + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(network)) + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + accelerator.end_training() + +if __name__ == "__main__": + main() diff --git a/scripts/wan2.2/train_animate_lora.sh b/scripts/wan2.2/train_animate_lora.sh new file mode 100644 index 0000000..e4ed2fd --- /dev/null +++ b/scripts/wan2.2/train_animate_lora.sh @@ -0,0 +1,38 @@ +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Animate-14B/" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata_control.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/wan2.2/train_animate_lora.py \ + --config_path="config/wan2.2/wan_civitai_animate.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --boundary_type="full" \ + --low_vram \ No newline at end of file diff --git a/scripts/wan2.2/train_lora.py b/scripts/wan2.2/train_lora.py index 3753094..5ca71c3 100755 --- a/scripts/wan2.2/train_lora.py +++ b/scripts/wan2.2/train_lora.py @@ -1383,10 +1383,10 @@ def collate_fn(examples): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.2/train_s2v.py b/scripts/wan2.2/train_s2v.py index da66628..ff5b410 100644 --- a/scripts/wan2.2/train_s2v.py +++ b/scripts/wan2.2/train_s2v.py @@ -1460,7 +1460,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) audio_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=torch.float32) # We need to recalculate our total training steps as the size of the training dataloader may have changed. diff --git a/scripts/wan2.2/train_s2v_lora.py b/scripts/wan2.2/train_s2v_lora.py new file mode 100644 index 0000000..04531d3 --- /dev/null +++ b/scripts/wan2.2/train_s2v_lora.py @@ -0,0 +1,2009 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import argparse +import gc +import logging +import math +import os +import pickle +import random +import shutil +import sys + +import accelerate +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +from accelerate import Accelerator, FullyShardedDataParallelPlugin +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import (EMAModel, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3) +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, FullStateDictConfig, ShardedOptimStateDictConfig, + ShardedStateDictConfig) +from torch.utils.data import RandomSampler +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.utils import ContextManagers + +import datasets + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + + +from videox_fun.data import (ImageVideoDataset, ImageVideoSampler, + VideoAnimateDataset, 
get_random_mask, + process_pose_file, process_pose_params) +from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, + ASPECT_RATIO_RANDOM_CROP_512, + ASPECT_RATIO_RANDOM_CROP_PROB, + AspectRatioBatchImageVideoSampler, + RandomSampler, get_closest_ratio) +from videox_fun.data.dataset_image_video import (ImageVideoDataset, + ImageVideoSampler, + get_random_mask) +from videox_fun.data.dataset_video import VideoSpeechControlDataset +from videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8, + CLIPModel, Wan2_2Transformer3DModel, + Wan2_2Transformer3DModel_Animate, + Wan2_2Transformer3DModel_S2V, WanAudioEncoder, + WanT5EncoderModel) +from videox_fun.pipeline import (Wan2_2FunControlPipeline, Wan2_2I2VPipeline, + Wan2_2Pipeline, Wan2_2S2VPipeline) +from videox_fun.utils.discrete_sampler import DiscreteSampling +from videox_fun.utils.lora_utils import (create_network, merge_lora, + unmerge_lora) +from videox_fun.utils.utils import (get_image_to_video_latent, + get_video_to_video_latent, + save_videos_grid) + +if is_wandb_available(): + import wandb + + +def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, 
remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, text_encoder, tokenizer, transformer3d, args, config, accelerator, weight_dtype, global_step): + try: + logger.info("Running validation... ") + + if args.boundary_type == "full": + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + + transformer3d_val = Wan2_2Transformer3DModel_S2V.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + + transformer3d_2_val = None + else: + if args.boundary_type == "low": + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + + transformer3d_val = Wan2_2Transformer3DModel_S2V.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + + sub_path = config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer') + transformer3d_2_val = Wan2_2Transformer3DModel_S2V.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + else: + sub_path = 
config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + + transformer3d_val = Wan2_2Transformer3DModel_S2V.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + + sub_path = config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer') + transformer3d_2_val = Wan2_2Transformer3DModel_S2V.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + ).to(weight_dtype) + transformer3d_2_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + + scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) + ) + + pipeline = Wan2_2S2VPipeline( + vae=accelerator.unwrap_model(vae).to(weight_dtype), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + transformer=transformer3d_val, + transformer_2=transformer3d_2_val, + scheduler=scheduler, + ) + pipeline = pipeline.to(accelerator.device) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.no_grad(): + if args.train_mode != "normal": + with torch.autocast("cuda", dtype=weight_dtype): + video_length = int((args.video_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1 + input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size]) + sample = pipeline( + args.validation_prompts[i], + num_frames = video_length, 
+ negative_prompt = "bad detailed", + height = args.video_sample_size, + width = args.video_sample_size, + guidance_scale = 6.0, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + ).videos + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif")) + + video_length = 1 + input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size]) + sample = pipeline( + args.validation_prompts[i], + num_frames = video_length, + negative_prompt = "bad detailed", + height = args.video_sample_size, + width = args.video_sample_size, + guidance_scale = 6.0, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + ).videos + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif")) + else: + with torch.autocast("cuda", dtype=weight_dtype): + sample = pipeline( + args.validation_prompts[i], + num_frames = args.video_sample_n_frames, + negative_prompt = "bad detailed", + height = args.video_sample_size, + width = args.video_sample_size, + generator = generator + ).videos + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif")) + + sample = pipeline( + args.validation_prompts[i], + num_frames = 1, + negative_prompt = "bad detailed", + height = args.video_sample_size, + width = args.video_sample_size, + generator = generator + ).videos + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif")) + + del pipeline + del transformer3d_val + gc.collect() + torch.cuda.empty_cache() + 
torch.cuda.ipc_collect() + + return images + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + print(f"Eval error with info {e}") + return None + +def linear_decay(initial_value, final_value, total_steps, current_step): + if current_step >= total_steps: + return final_value + current_step = max(0, current_step) + step_size = (final_value - initial_value) / total_steps + current_value = initial_value + step_size * current_step + return current_value + +def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None): + u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator) + t = 1 / (1 + torch.exp(-u)) * (high - low) + low + return torch.clip(t.to(torch.int32), low, high - 1) + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. " + ), + ) + parser.add_argument( + "--train_data_meta", + type=str, + default=None, + help=( + "A csv containing the training data. 
" + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." 
+ ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. 
Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--network_alpha", + type=int, + default=64, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." 
+ ) + parser.add_argument( + "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.", + ) + parser.add_argument( + "--auto_tile_batch_size", action="store_true", help="Whether to auto tile batch size.", + ) + parser.add_argument( + "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss." + ) + parser.add_argument( + "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss." + ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--keep_all_node_same_token_length", + action="store_true", + help="Reference of the length token.", + ) + parser.add_argument( + "--token_sample_size", + type=int, + default=512, + help="Sample size of the token.", + ) + parser.add_argument( + "--video_sample_size", + type=int, + default=512, + help="Sample size of the video.", + ) + parser.add_argument( + "--motion_frames", + type=int, + default=73, + help="Motion frames of s2v.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." + ) + parser.add_argument( + "--video_sample_stride", + type=int, + default=4, + help="Sample stride of the video.", + ) + parser.add_argument( + "--video_sample_n_frames", + type=int, + default=17, + help="Num frame of video.", + ) + parser.add_argument( + "--video_repeat", + type=int, + default=0, + help="Num of repeat video.", + ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help=( + "The config of the model in training." 
+ ), + ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + + parser.add_argument( + '--trainable_modules', + nargs='+', + help='Enter a list of trainable modules' + ) + parser.add_argument( + '--trainable_modules_low_learning_rate', + nargs='+', + default=[], + help='Enter a list of trainable modules with lower learning rate' + ) + parser.add_argument( + '--tokenizer_max_length', + type=int, + default=512, + help='Max length of tokenizer' + ) + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--boundary_type", + type=str, + default="low", + help=( + 'The format of training data. Support `"low"` and `"high"`' + ), + ) + parser.add_argument( + "--control_ref_image", + type=str, + default="first_frame", + help=( + 'The format of training data. Support `"first_frame"`' + ' (default), `"random"`.' + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--lora_skip_name", + type=str, + default=None, + help=("The module is not trained in loras. "), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." 
+ ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + config = OmegaConf.load(args.config_path) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. 
+ noise_scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) + ) + + # Get Tokenizer + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + ) + text_encoder = text_encoder.eval() + # Get Vae + Chosen_AutoencoderKL = { + "AutoencoderKLWan": AutoencoderKLWan, + "AutoencoderKLWan3_8": AutoencoderKLWan3_8 + }[config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] + vae = Chosen_AutoencoderKL.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), + ) + vae.eval() + # Get Audio Encoder + audio_encoder = WanAudioEncoder( + os.path.join(args.pretrained_model_name_or_path, config['audio_encoder_kwargs'].get('audio_encoder_subpath', 'audio_encoder')), + "cpu" + ) + audio_encoder.eval() + + # Get Transformer + if args.boundary_type == "low" or args.boundary_type == "full": + sub_path = config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer') + else: + sub_path = config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer') + transformer3d = Wan2_2Transformer3DModel_S2V.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, sub_path), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + ).to(weight_dtype) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + + # Lora will work with this... 
# Create the LoRA network over the (frozen) text encoder and transformer.
network = create_network(
    1.0,
    args.rank,
    args.network_alpha,
    text_encoder,
    transformer3d,
    neuron_dropout=None,
    skip_name=args.lora_skip_name
)
# Text-encoder LoRA is only trained when requested AND token-length bucketing
# is off; the transformer LoRA is always applied (last positional True).
network.apply_to(text_encoder, transformer3d, args.train_text_encoder and not args.training_with_video_token_length, True)

# Optionally load transformer weights from an explicit checkpoint file.
if args.transformer_path is not None:
    print(f"From checkpoint: {args.transformer_path}")
    if args.transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(args.transformer_path)
    else:
        state_dict = torch.load(args.transformer_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    # strict=False tolerates missing keys, but any unexpected key is an error.
    m, u = transformer3d.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    assert len(u) == 0

# Optionally load VAE weights from an explicit checkpoint file (same scheme).
if args.vae_path is not None:
    print(f"From checkpoint: {args.vae_path}")
    if args.vae_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(args.vae_path)
    else:
        state_dict = torch.load(args.vae_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    assert len(u) == 0

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format.
    # The hooks close over `batch_sampler` and `first_epoch`, which are
    # assigned later in this function — safe because the hooks only run
    # during save/load_state, after those names exist.
    if fsdp_stage != 0:
        def save_model_hook(models, weights, output_dir):
            # Under FSDP the full state dict must be gathered on every rank.
            accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True)
            if accelerator.is_main_process:
                from safetensors.torch import save_file

                safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors")
                # Keep only LoRA ("network") weights, stripping the prefix.
                network_state_dict = {}
                for key in accelerate_state_dict:
                    if "network" in key:
                        network_state_dict[key.replace("network.", "")] = accelerate_state_dict[key].to(weight_dtype)

                save_file(network_state_dict, safetensor_save_path, metadata={"format": "pt"})

                # Persist the sampler position so resuming skips seen samples.
                with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file:
                    pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file)

        def load_model_hook(models, input_dir):
            pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl")
            if os.path.exists(pkl_path):
                with open(pkl_path, 'rb') as file:
                    loaded_number, _ = pickle.load(file)
                    # Rewind a little to account for prefetched batches.
                    batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0)
                print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.")

    elif zero_stage == 3:
        # Under ZeRO-3 the model weights are saved by DeepSpeed itself; only
        # the sampler position is persisted here.
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file:
                    pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file)

        def load_model_hook(models, input_dir):
            pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl")
            if os.path.exists(pkl_path):
                with open(pkl_path, 'rb') as file:
                    loaded_number, _ = pickle.load(file)
                    batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0)
                print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.")
    else:
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors")
                save_model(safetensor_save_path, accelerator.unwrap_model(models[-1]))
                # Pop weights so accelerate does not additionally serialize
                # the full model (deepspeed manages its own state).
                # NOTE(review): nesting of this pop/pickle under the
                # main-process guard reconstructed from the flattened source;
                # confirm against the original file.
                if not args.use_deepspeed:
                    for _ in range(len(weights)):
                        weights.pop()

                with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file:
                    pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file)

        def load_model_hook(models, input_dir):
            pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl")
            if os.path.exists(pkl_path):
                with open(pkl_path, 'rb') as file:
                    loaded_number, _ = pickle.load(file)
                    batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0)
                print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.")

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

if args.gradient_checkpointing:
    transformer3d.enable_gradient_checkpointing()

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

# Linear-scale the learning rate with the effective batch size if requested.
if args.scale_lr:
    args.learning_rate = (
        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

# Initialize the optimizer
if args.use_8bit_adam:
    try:
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError(
            "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
        )

    optimizer_cls = bnb.optim.AdamW8bit
elif args.use_came:
    try:
        from came_pytorch import CAME
    except:
        # NOTE(review): bare `except:` swallows everything (incl. KeyboardInterrupt);
        # should be `except ImportError:` like the bitsandbytes branch above.
        raise ImportError(
            "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`"
        )

    optimizer_cls = CAME
else:
    optimizer_cls = torch.optim.AdamW

logging.info("Add network parameters")
# Only the LoRA network parameters are trainable; prepare_optimizer_params
# builds per-group lrs (text-encoder groups at half the base lr).
trainable_params = list(filter(lambda p: p.requires_grad, network.parameters()))
trainable_params_optim = network.prepare_optimizer_params(args.learning_rate / 2, args.learning_rate, args.learning_rate)

if args.use_came:
    optimizer = optimizer_cls(
        trainable_params_optim,
        lr=args.learning_rate,
        # weight_decay=args.adam_weight_decay,
        betas=(0.9, 0.999, 0.9999),
        eps=(1e-30, 1e-16)
    )
else:
    optimizer = optimizer_cls(
        trainable_params_optim,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

# Get the training dataset.
# The VAE's compression ratios drive frame-count rounding and latent sizing.
sample_n_frames_bucket_interval = vae.config.temporal_compression_ratio
spatial_compression_ratio = vae.config.spatial_compression_ratio

# A fixed sample size overrides dynamic bucketing options.
if args.fix_sample_size is not None and args.enable_bucket:
    args.video_sample_size = max(max(args.fix_sample_size), args.video_sample_size)
    args.training_with_video_token_length = False
    args.random_hw_adapt = False

# Get the dataset
train_dataset = VideoSpeechControlDataset(
    args.train_data_meta, args.train_data_dir,
    video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames,
    enable_bucket=args.enable_bucket, enable_inpaint=True, enable_motion_info=True, motion_frames=args.motion_frames
)

def worker_init_fn(_seed):
    """Return a per-worker init fn that seeds numpy/random deterministically."""
    _seed = _seed * 256
    def _worker_init_fn(worker_id):
        print(f"worker_init_fn with {_seed + worker_id}")
        np.random.seed(_seed + worker_id)
        random.seed(_seed + worker_id)
    return _worker_init_fn

if args.enable_bucket:
    # Aspect-ratio buckets scaled from the 512 reference table to the
    # configured sample size.
    aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
    batch_sampler_generator = torch.Generator().manual_seed(args.seed)
    batch_sampler = AspectRatioBatchImageVideoSampler(
        sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset,
        batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True,
        aspect_ratios=aspect_ratio_sample_size,
    )

    def collate_fn(examples):
        """Bucket-aware collate: resizes/crops each example to a common
        resolution, trims all modalities to a shared frame count, and
        (optionally) pre-encodes the text prompts."""
        def get_length_to_frame_num(token_length):
            # Map candidate spatial sizes -> max frame count that fits the
            # token budget, rounded down to a multiple of the VAE temporal
            # compression interval.
            if args.video_sample_size > 256:
                sample_sizes = list(range(256, args.video_sample_size + 1, 128))

                if sample_sizes[-1] != args.video_sample_size:
                    sample_sizes.append(args.video_sample_size)
            else:
                sample_sizes = [args.video_sample_size]

            length_to_frame_num = {
                sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval for sample_size in sample_sizes
            }

            return length_to_frame_num

        def get_random_downsample_ratio(sample_size, image_ratio=[],
                                        all_choices=False, rng=None):
            # NOTE(review): mutable default `image_ratio=[]` — harmless here
            # because it is never mutated, but fragile.
            def _create_special_list(length):
                # 90% probability mass on the first (no-downsample) option.
                if length == 1:
                    return [1.0]
                if length >= 2:
                    first_element = 0.90
                    remaining_sum = 1.0 - first_element
                    other_elements_value = remaining_sum / (length - 1)
                    special_list = [first_element] + [other_elements_value] * (length - 1)
                    return special_list

            if sample_size >= 1536:
                number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio
            elif sample_size >= 1024:
                number_list = [1, 1.25, 1.5, 2] + image_ratio
            elif sample_size >= 768:
                number_list = [1, 1.25, 1.5] + image_ratio
            elif sample_size >= 512:
                number_list = [1] + image_ratio
            else:
                number_list = [1]

            if all_choices:
                return number_list

            number_list_prob = np.array(_create_special_list(len(number_list)))
            if rng is None:
                return np.random.choice(number_list, p = number_list_prob)
            else:
                return rng.choice(number_list, p = number_list_prob)

        # Get token length
        target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size
        length_to_frame_num = get_length_to_frame_num(target_token_length)

        # Create new output
        new_examples = {}
        new_examples["target_token_length"] = target_token_length
        new_examples["pixel_values"] = []
        new_examples["motion_pixel_values"] = []
        new_examples["text"] = []
        # Used in Control Mode
        new_examples["control_pixel_values"] = []
        new_examples["ref_pixel_values"] = []
        new_examples["clip_idx"] = []

        new_examples["audio"] = []
        new_examples["sample_rate"] = []
        new_examples["fps"] = []

        # Used in Inpaint mode
        new_examples["clip_pixel_values"] = []

        # Get downsample ratio in image and videos
        # (assumes per-example pixel_values are (frames, h, w, c) uint8 arrays
        # — consistent with the /255 normalization below.)
        pixel_value = examples[0]["pixel_values"]
        f, h, w, c = np.shape(pixel_value)

        if args.random_hw_adapt:
            if args.training_with_video_token_length:
                local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples]))
                # The video will be resized to a lower resolution than its own.
                choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25]
                if len(choice_list) == 0:
                    choice_list = list(length_to_frame_num.keys())
                local_video_sample_size = np.random.choice(choice_list)
                batch_video_length = length_to_frame_num[local_video_sample_size]
                random_downsample_ratio = args.video_sample_size / local_video_sample_size
            else:
                random_downsample_ratio = get_random_downsample_ratio(args.video_sample_size)
                batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
        else:
            random_downsample_ratio = 1
            batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval

        aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
        aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}

        # Choose the target (h, w): fixed size > random-ratio crop > closest
        # aspect-ratio bucket. All sizes are floored to multiples of 16.
        if args.fix_sample_size is not None:
            fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size]
        elif args.random_ratio_crop:
            if rng is None:
                random_sample_size = aspect_ratio_random_crop_sample_size[
                    np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
                ]
            else:
                random_sample_size = aspect_ratio_random_crop_sample_size[
                    rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
                ]
            random_sample_size = [int(x / 16) * 16 for x in random_sample_size]
        else:
            closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size)
            closest_size = [int(x / 16) * 16 for x in closest_size]

        # Clamp the batch frame count to the shortest example.
        min_example_length = min(
            [example["pixel_values"].shape[0] for example in examples]
        )
        batch_video_length = int(min(batch_video_length, min_example_length))

        # Magvae needs the number of frames to be 4n.
        batch_video_length = (batch_video_length) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval

        if batch_video_length <= 0:
            batch_video_length = 1

        for example in examples:
            # To 0~1
            pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
            pixel_values = pixel_values / 255.

            motion_pixel_values = torch.from_numpy(example["motion_pixel_values"]).permute(0, 3, 1, 2).contiguous()
            motion_pixel_values = motion_pixel_values / 255.

            control_pixel_values = torch.from_numpy(example["control_pixel_values"]).permute(0, 3, 1, 2).contiguous()
            control_pixel_values = control_pixel_values / 255.

            if args.fix_sample_size is not None:
                # Get adapt hw for resize
                fix_sample_size = list(map(lambda x: int(x), fix_sample_size))
                transform = transforms.Compose([
                    transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR),  # Image.BICUBIC
                    transforms.CenterCrop(fix_sample_size),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
                ])
            elif args.random_ratio_crop:
                # Get adapt hw for resize: scale so the crop fits inside.
                b, c, h, w = pixel_values.size()
                th, tw = random_sample_size
                if th / tw > h / w:
                    nh = int(th)
                    nw = int(w / h * nh)
                else:
                    nw = int(tw)
                    nh = int(h / w * nw)

                transform = transforms.Compose([
                    transforms.Resize([nh, nw]),
                    transforms.CenterCrop([int(x) for x in random_sample_size]),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
                ])
            else:
                # Get adapt hw for resize: cover the bucket size, then crop.
                closest_size = list(map(lambda x: int(x), closest_size))
                if closest_size[0] / h > closest_size[1] / w:
                    resize_size = closest_size[0], int(w * closest_size[0] / h)
                else:
                    resize_size = int(h * closest_size[1] / w), closest_size[1]

                transform = transforms.Compose([
                    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR),  # Image.BICUBIC
                    transforms.CenterCrop(closest_size),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
                ])

            new_examples["pixel_values"].append(transform(pixel_values)[:batch_video_length])
            new_examples["motion_pixel_values"].append(transform(motion_pixel_values))
            new_examples["control_pixel_values"].append(transform(control_pixel_values))
            new_examples["text"].append(example["text"])

            # Pick the reference frame: first frame, or a random frame with
            # 40% mass on frame 0.
            if args.control_ref_image == "first_frame":
                clip_index = 0
            else:
                def _create_special_list(length):
                    if length == 1:
                        return [1.0]
                    if length >= 2:
                        first_element = 0.40
                        remaining_sum = 1.0 - first_element
                        other_elements_value = remaining_sum / (length - 1)
                        special_list = [first_element] + [other_elements_value] * (length - 1)
                        return special_list
                number_list_prob = np.array(_create_special_list(len(new_examples["pixel_values"][-1])))
                clip_index = np.random.choice(list(range(len(new_examples["pixel_values"][-1]))), p = number_list_prob)

            ref_pixel_values = new_examples["pixel_values"][-1][clip_index].unsqueeze(0)
            new_examples["ref_pixel_values"].append(ref_pixel_values)

            # Un-normalize the chosen frame back to 0~255 HWC for the CLIP input.
            clip_pixel_values = new_examples["pixel_values"][-1][clip_index].permute(1, 2, 0).contiguous()
            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
            new_examples["clip_pixel_values"].append(clip_pixel_values)
            new_examples["clip_idx"].append(clip_index)

            # Trim the audio proportionally to the trimmed video length.
            audio_length = np.shape(example["audio"])[0]
            batch_audio_length = int(audio_length / pixel_values.size()[0] * batch_video_length)
            new_examples["audio"].append(example["audio"][:batch_audio_length])
            new_examples["sample_rate"].append(example["sample_rate"])
            new_examples["fps"].append(example["fps"])

        # Limit the number of frames to the same
        new_examples["pixel_values"] = torch.stack([example for example in new_examples["pixel_values"]])
        new_examples["motion_pixel_values"] = torch.stack([example for example in new_examples["motion_pixel_values"]])
        new_examples["control_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["control_pixel_values"]])
        new_examples["ref_pixel_values"] = torch.stack([example for example in new_examples["ref_pixel_values"]])
        new_examples["clip_pixel_values"] = torch.stack([example for example in new_examples["clip_pixel_values"]])
        new_examples["clip_idx"] = torch.tensor(new_examples["clip_idx"])
        new_examples["audio"] = torch.stack([example for example in new_examples["audio"]])
        new_examples["sample_rate"] = new_examples["sample_rate"]
        new_examples["fps"] = new_examples["fps"]

        # Encode prompts when enable_text_encoder_in_dataloader=True
        if args.enable_text_encoder_in_dataloader:
            prompt_ids = tokenizer(
                new_examples['text'],
                max_length=args.tokenizer_max_length,
                padding="max_length",
                add_special_tokens=True,
                truncation=True,
                return_tensors="pt"
            )
            encoder_hidden_states = text_encoder(
                prompt_ids.input_ids
            )[0]
            new_examples['encoder_attention_mask'] = prompt_ids.attention_mask
            new_examples['encoder_hidden_states'] = encoder_hidden_states

        return new_examples

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,
        persistent_workers=True if args.dataloader_num_workers != 0 else False,
        num_workers=args.dataloader_num_workers,
        worker_init_fn=worker_init_fn(args.seed + accelerator.process_index)
    )
else:
    # DataLoaders creation:
    batch_sampler_generator = torch.Generator().manual_seed(args.seed)
    batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        persistent_workers=True if args.dataloader_num_workers != 0 else False,
        num_workers=args.dataloader_num_workers,
        worker_init_fn=worker_init_fn(args.seed + accelerator.process_index)
    )

# Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + if fsdp_stage != 0: + transformer3d.network = network + transformer3d = transformer3d.to(weight_dtype) + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) + + if zero_stage == 3: + from functools import partial + + from videox_fun.dist import set_multi_gpus_devices, shard_model + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype) + transformer3d = shard_fn(transformer3d) + + if fsdp_stage != 0: + from functools import partial + + from videox_fun.dist import set_multi_gpus_devices, shard_model + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype) + text_encoder = shard_fn(text_encoder) + + if args.use_ema: + ema_transformer3d.to(accelerator.device) + + # Move text_encode and vae to gpu and cast to weight_dtype + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + transformer3d.to(accelerator.device, dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + audio_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=torch.float32) + + # We need to recalculate our total training 
steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + keys_to_pop = [k for k, v in tracker_config.items() if isinstance(v, list)] + for k in keys_to_pop: + tracker_config.pop(k) + print(f"Removed tracker_config['{k}']") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.") + + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup inpaint vae computation + vae_stream_1 = torch.cuda.Stream() + vae_stream_2 = torch.cuda.Stream() + else: + vae_stream_1 = None + vae_stream_2 = None + + # Calculate the index we need + boundary = config['transformer_additional_kwargs'].get('boundary', 0.900) + split_timesteps = args.train_sampling_steps * boundary + differences = torch.abs(noise_scheduler.timesteps - split_timesteps) + closest_index = torch.argmin(differences).item() + if args.boundary_type == "high" or args.boundary_type == "low": + print(f"The boundary is {boundary} and the boundary_type is {args.boundary_type}. The closest_index we calculate is {closest_index}") + if args.boundary_type == "high": + start_num_idx = 0 + train_sampling_steps = closest_index + elif args.boundary_type == "low": + start_num_idx = closest_index + train_sampling_steps = args.train_sampling_steps - closest_index + else: + start_num_idx = 0 + train_sampling_steps = args.train_sampling_steps + + idx_sampling = DiscreteSampling(train_sampling_steps, start_num_idx=start_num_idx, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + # Data batch sanity check + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): + pixel_value = pixel_value[None, ...] 
+ gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True) + ref_pixel_values = batch["ref_pixel_values"].cpu() + ref_pixel_values = rearrange(ref_pixel_values, "b f c h w -> b c f h w") + for idx, (ref_pixel_value, text) in enumerate(zip(ref_pixel_values, texts)): + ref_pixel_value = ref_pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(ref_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_ref.gif", rescale=True) + + motion_pixel_values = batch["motion_pixel_values"].cpu() + motion_pixel_values = rearrange(motion_pixel_values, "b f c h w -> b c f h w") + for idx, (motion_pixel_value, text) in enumerate(zip(motion_pixel_values, texts)): + motion_pixel_value = motion_pixel_value[None, ...] + gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(motion_pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}_motion.gif", rescale=True) + + clip_pixel_values, texts = batch['clip_pixel_values'].cpu(), batch['text'] + for idx, (clip_pixel_value, text) in enumerate(zip(clip_pixel_values, texts)): + pixel_value = pixel_value[None, ...] 
+ Image.fromarray(np.uint8(clip_pixel_value)).save(f"{args.output_dir}/sanity_check/clip_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.png") + + with accelerator.accumulate(transformer3d): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + motion_pixel_values = batch["motion_pixel_values"].to(weight_dtype) + control_pixel_values = batch["control_pixel_values"].to(weight_dtype) + ref_pixel_values = batch["ref_pixel_values"].to(weight_dtype) + clip_idx = batch["clip_idx"] + audio = batch["audio"] + sample_rate = batch["sample_rate"] + fps = batch["fps"] + + # Increase the batch size when the length of the latent sequence of the current sample is small + if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1)) + motion_pixel_values = torch.tile(motion_pixel_values, (4, 1, 1, 1, 1)) + control_pixel_values = torch.tile(control_pixel_values, (4, 1, 1, 1, 1)) + ref_pixel_values = torch.tile(ref_pixel_values, (4, 1, 1, 1, 1)) + clip_idx = torch.tile(clip_idx, (4,)) + audio = torch.tile(audio, (4, 1)) + sample_rate = sample_rate * 4 + fps = fps * 4 + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1)) + else: + batch['text'] = batch['text'] * 4 + elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1)) + motion_pixel_values = torch.tile(motion_pixel_values, (2, 1, 1, 1, 1)) + control_pixel_values = torch.tile(control_pixel_values, (2, 
1, 1, 1, 1)) + ref_pixel_values = torch.tile(ref_pixel_values, (2, 1, 1, 1, 1)) + clip_idx = torch.tile(clip_idx, (2,)) + audio = torch.tile(audio, (2, 1)) + sample_rate = sample_rate * 2 + fps = fps * 2 + if args.enable_text_encoder_in_dataloader: + batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1)) + batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1)) + else: + batch['text'] = batch['text'] * 2 + + # clip_pixel_values = batch["clip_pixel_values"].to(weight_dtype) + # # Increase the batch size when the length of the latent sequence of the current sample is small + # if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + # if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + # clip_pixel_values = torch.tile(clip_pixel_values, (4, 1, 1, 1)) + # elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + # clip_pixel_values = torch.tile(clip_pixel_values, (2, 1, 1, 1)) + + if args.random_frame_crop: + def _create_special_list(length): + if length == 1: + return [1.0] + if length >= 2: + last_element = 0.90 + remaining_sum = 1.0 - last_element + other_elements_value = remaining_sum / (length - 1) + special_list = [other_elements_value] * (length - 1) + [last_element] + return special_list + select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))] + select_frames_prob = np.array(_create_special_list(len(select_frames))) + + if len(select_frames) != 0: + if rng is None: + temp_n_frames = np.random.choice(select_frames, p = select_frames_prob) + else: + temp_n_frames = rng.choice(select_frames, p = select_frames_prob) + else: + 
temp_n_frames = 1 + + # Magvae needs the number of frames to be 4n. + temp_n_frames = (temp_n_frames) // sample_n_frames_bucket_interval + + pixel_values = pixel_values[:, :temp_n_frames, :, :] + control_pixel_values = control_pixel_values[:, :temp_n_frames, :, :] + + # Keep all node same token length to accelerate the traning when resolution grows. + if args.keep_all_node_same_token_length: + if args.token_sample_size > 256: + numbers_list = list(range(256, args.token_sample_size + 1, 128)) + + if numbers_list[-1] != args.token_sample_size: + numbers_list.append(args.token_sample_size) + else: + numbers_list = [256] + numbers_list = [_number * _number * args.video_sample_n_frames for _number in numbers_list] + + actual_token_length = index_rng.choice(numbers_list) + actual_video_length = (min( + actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames + )) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + actual_video_length = int(max(actual_video_length, 1)) + + # Magvae needs the number of frames to be 4n. 
+ actual_video_length = (actual_video_length) // sample_n_frames_bucket_interval + + pixel_values = pixel_values[:, :actual_video_length, :, :] + control_pixel_values = control_pixel_values[:, :actual_video_length, :, :] + + if args.low_vram: + torch.cuda.empty_cache() + vae.to(accelerator.device) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to("cpu") + audio_encoder.to("cpu") + + with torch.no_grad(): + # This way is quicker when batch grows up + def _batch_encode_vae(pixel_values): + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + bs = args.vae_mini_batch + new_pixel_values = [] + for i in range(0, pixel_values.shape[0], bs): + pixel_values_bs = pixel_values[i : i + bs] + pixel_values_bs = vae.encode(pixel_values_bs)[0] + pixel_values_bs = pixel_values_bs.sample() + new_pixel_values.append(pixel_values_bs) + return torch.cat(new_pixel_values, dim = 0) + + if rng is None: + zero_tail_frames = np.random.choice([0, 1], p = [0.90, 0.10]) + else: + zero_tail_frames = rng.choice([0, 1], p = [0.90, 0.10]) + if zero_tail_frames: + if rng is None: + zero_frames_num = np.random.randint(1, control_pixel_values.size()[1]) + else: + zero_frames_num = rng.integers(1, control_pixel_values.size()[1]) + control_pixel_values[:, zero_frames_num:] = torch.ones_like(control_pixel_values[:, zero_frames_num:]) * -1 + + # Make control pixel values to zero + for bs_index in range(control_pixel_values.size()[0]): + if rng is None: + zero_init_control_latents_conv_in = np.random.choice([0, 1], p = [0.90, 0.10]) + else: + zero_init_control_latents_conv_in = rng.choice([0, 1], p = [0.90, 0.10]) + + if zero_init_control_latents_conv_in: + control_pixel_values[bs_index] = torch.ones_like(control_pixel_values[bs_index]) * -1 + # Encode control latents + pad_control_pixel_values = torch.cat([control_pixel_values[:, 0:1, :].repeat(1, 1, 1, 1, 1), control_pixel_values], dim=1) + control_latents = _batch_encode_vae(pad_control_pixel_values)[:, :, 1:] + + # 
Encode Reference latents + ref_latents = _batch_encode_vae(ref_pixel_values) + + # Encode Motion latents + if rng is None: + zero_motion_pixel_values = np.random.choice([0, 1], p = [0.90, 0.10]) + else: + zero_motion_pixel_values = rng.choice([0, 1], p = [0.90, 0.10]) + if zero_motion_pixel_values: + height, width = control_pixel_values.size()[-2], control_pixel_values.size()[-1] + motion_pixel_values = torch.zeros([1, args.motion_frames, 3, height, width], dtype=control_latents.dtype, device=control_latents.device) + + has_motion_pixel_values = torch.sum(motion_pixel_values) == 0 + if torch.sum(clip_idx) != 0: + init_first_frame = False + else: + if rng is None: + init_first_frame = np.random.choice([0, 1], p = [0.50, 0.50]) + else: + init_first_frame = rng.choice([0, 1], p = [0.50, 0.50]) + if init_first_frame or has_motion_pixel_values: + if not has_motion_pixel_values: + motion_pixel_values[:, -6:, :] = ref_pixel_values + + motion_frames_latents_length = int((args.motion_frames - 1) / sample_n_frames_bucket_interval + 1) + local_pixel_values = torch.cat([motion_pixel_values, pixel_values], dim = 1) + local_latents = _batch_encode_vae(local_pixel_values) + latents = local_latents[:, :, motion_frames_latents_length:] + motion_latents = local_latents[:, :, :motion_frames_latents_length] + drop_motion_frames = False + else: + local_pixel_values = torch.cat([ref_pixel_values, pixel_values], dim = 1) + latents = _batch_encode_vae(local_pixel_values) + latents = latents[:, :, 1:] + motion_latents = _batch_encode_vae(motion_pixel_values) + drop_motion_frames = True + + if args.low_vram: + vae.to('cpu') + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + audio_encoder.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device) + else: + with torch.no_grad(): + prompt_ids = tokenizer( + batch['text'], + padding="max_length", 
+ max_length=args.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids + prompt_attention_mask = prompt_ids.attention_mask + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + with torch.no_grad(): + # Extract audio emb + new_audio_wav2vec_fea = [] + for index in range(len(audio)): + _audio_wav2vec_fea = audio_encoder.extract_audio_feat_without_file_load( + audio[index], sample_rate[index], return_all_layers=True + ) + new_audio_wav2vec_fea.append(_audio_wav2vec_fea) + audio_wav2vec_fea = torch.stack(new_audio_wav2vec_fea).to(device=accelerator.device, dtype=weight_dtype) + + new_audio_wav2vec_fea = [] + for index in range(len(audio)): + _audio_wav2vec_fea, num_repeat = audio_encoder.get_audio_embed_bucket_fps( + audio_wav2vec_fea[index], fps=fps[index], batch_frames=control_pixel_values.size()[1], m=0) + new_audio_wav2vec_fea.append(_audio_wav2vec_fea) + audio_wav2vec_fea = torch.stack(new_audio_wav2vec_fea).to(device=accelerator.device, dtype=weight_dtype) + + if len(audio_wav2vec_fea.shape) == 3: + audio_wav2vec_fea = audio_wav2vec_fea.permute(0, 2, 1) + elif len(audio_wav2vec_fea.shape) == 4: + audio_wav2vec_fea = audio_wav2vec_fea.permute(0, 2, 3, 1) + + for bs_index in range(audio_wav2vec_fea.size()[0]): + if rng is None: + zero_init_control_latents_conv_in = np.random.choice([0, 1], p = [0.90, 0.10]) + else: + zero_init_control_latents_conv_in = rng.choice([0, 1], p = [0.90, 0.10]) + + if zero_init_control_latents_conv_in: + audio_wav2vec_fea[bs_index] = torch.ones_like(audio_wav2vec_fea[bs_index]) * 0 + + if zero_tail_frames: + audio_wav2vec_fea[..., zero_frames_num:] = torch.zeros_like(audio_wav2vec_fea[..., zero_frames_num:]) + # audio_wav2vec_fea = audio_wav2vec_fea[..., 
:control_pixel_values.size()[1]] + + if args.low_vram and not args.enable_text_encoder_in_dataloader: + text_encoder.to('cpu') + audio_encoder.to("cpu") + torch.cuda.empty_cache() + + bsz, channel, num_frames, height, width = latents.size() + noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype) + + if not args.uniform_sampling: + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + else: + # Sample a random timestep for each image + # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + # timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + indices = idx_sampling(bsz, generator=torch_rng, device=latents.device) + indices = indices.long().cpu() + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Add noise according to flow matching. 
+ # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + # Add noise + target = noise - latents + + target_shape = (vae.latent_channels, num_frames, width, height) + seq_len = math.ceil( + (target_shape[2] * target_shape[3]) / + (accelerator.unwrap_model(transformer3d).config.patch_size[1] * accelerator.unwrap_model(transformer3d).config.patch_size[2]) * + target_shape[1] + ) + + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred = transformer3d( + x=noisy_latents, + context=prompt_embeds, + t=timesteps, + seq_len=seq_len, + cond_states=control_latents, + motion_latents=motion_latents, + ref_latents=ref_latents, + audio_input=audio_wav2vec_fea, + motion_frames=[[args.motion_frames, (args.motion_frames + 3) // 4]] * bsz, + drop_motion_frames=drop_motion_frames, + ) + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + loss = custom_mse_loss(noise_pred.float(), target.float(), weighting.float()) + loss = loss.mean() + + if args.motion_sub_loss and noise_pred.size()[2] > 2: + gt_sub_noise = noise_pred[:, :, 1:].float() - noise_pred[:, :, :-1].float() + pre_sub_noise = target[:, :, 1:].float() - target[:, :, :-1].float() + sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean") + loss = loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio + + # Gather the losses across all processes for logging 
(if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + 
save_model(safetensor_save_path, accelerator.unwrap_model(network)) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + if accelerator.is_main_process: + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + + # Create the pipeline using the trained modules and save it. 
+ accelerator.wait_for_everyone() + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(network)) + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/wan2.2/train_s2v_lora.sh b/scripts/wan2.2/train_s2v_lora.sh new file mode 100644 index 0000000..e98454d --- /dev/null +++ b/scripts/wan2.2/train_s2v_lora.sh @@ -0,0 +1,39 @@ +export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-S2V-14B" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata_control.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/wan2.2/train_s2v_lora.py \ + --config_path="config/wan2.2/wan_civitai_s2v.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --boundary_type="full" \ + --control_ref_image="random" \ + --low_vram \ No newline at end of file diff --git a/scripts/wan2.2_fun/train.py b/scripts/wan2.2_fun/train.py index 67eedcf..8ab36f2 100644 --- a/scripts/wan2.2_fun/train.py +++ b/scripts/wan2.2_fun/train.py @@ -1456,7 +1456,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.2_fun/train_control.py b/scripts/wan2.2_fun/train_control.py index d951000..8b19edb 100644 --- a/scripts/wan2.2_fun/train_control.py +++ b/scripts/wan2.2_fun/train_control.py @@ -1513,7 +1513,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.2_fun/train_control_lora.py b/scripts/wan2.2_fun/train_control_lora.py index 3a97da3..56d77de 100644 --- a/scripts/wan2.2_fun/train_control_lora.py +++ b/scripts/wan2.2_fun/train_control_lora.py @@ -1458,10 +1458,10 @@ def _create_special_list(length): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.2_fun/train_lora.py b/scripts/wan2.2_fun/train_lora.py index d5a0e01..31b8b23 100644 --- a/scripts/wan2.2_fun/train_lora.py +++ b/scripts/wan2.2_fun/train_lora.py @@ -1384,10 +1384,10 @@ def _create_special_list(length): text_encoder = shard_fn(text_encoder) # Move text_encode and vae to gpu and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/scripts/wan2.2_vace_fun/train.py b/scripts/wan2.2_vace_fun/train.py index f59c067..a14a08c 100644 --- a/scripts/wan2.2_vace_fun/train.py +++ b/scripts/wan2.2_vace_fun/train.py @@ -1461,7 +1461,7 @@ def _create_special_list(length): # Move text_encode and vae to gpu and cast to weight_dtype vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device if not args.low_vram else "cpu") + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) diff --git a/videox_fun/data/__init__.py b/videox_fun/data/__init__.py index 32aff1a..babf155 100644 --- a/videox_fun/data/__init__.py +++ b/videox_fun/data/__init__.py @@ -1,5 +1,5 @@ from .dataset_image import CC15M, ImageEditDataset -from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, +from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset, ImageVideoSampler) from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, diff --git a/videox_fun/data/dataset_image_video.py b/videox_fun/data/dataset_image_video.py index 1239fcb..449a2f7 100755 --- a/videox_fun/data/dataset_image_video.py +++ b/videox_fun/data/dataset_image_video.py @@ -530,7 +530,7 @@ def get_batch(self, idx): shuffle(subject_id) subject_images = [] for i in range(min(len(subject_id), 4)): - subject_image = Image.open(subject_id[i]) + subject_image = Image.open(subject_id[i]).convert('RGB') width, height = subject_image.size total_pixels = width * height @@ -621,4 +621,37 @@ def __getitem__(self, idx): else: path = os.path.join(self.data_root, self.dataset[idx]["file_path"]) state_dict = load_file(path) - return state_dict \ No newline at end of file + return state_dict + + +class TextDataset(Dataset): + def __init__(self, ann_path, text_drop_ratio=0.0): + print(f"loading annotations from {ann_path} ...") + with open(ann_path, 'r') as f: + self.dataset = json.load(f) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + self.text_drop_ratio = text_drop_ratio + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + item = self.dataset[idx] + text = item['text'] + + # Randomly drop text (for classifier-free guidance) + if random.random() < self.text_drop_ratio: + text = '' + + 
sample = { + "text": text, + "idx": idx + } + return sample + + except Exception as e: + print(f"Error at index {idx}: {e}, retrying with random index...") + idx = np.random.randint(0, self.length - 1) \ No newline at end of file diff --git a/videox_fun/models/fantasytalking_transformer3d.py b/videox_fun/models/fantasytalking_transformer3d.py index e494dc2..6f7bb79 100644 --- a/videox_fun/models/fantasytalking_transformer3d.py +++ b/videox_fun/models/fantasytalking_transformer3d.py @@ -4,6 +4,7 @@ import os from typing import Any, Dict +import numpy as np import torch import torch.cuda.amp as amp import torch.nn as nn @@ -45,6 +46,10 @@ def __init__(self, context_dim, hidden_dim): nn.init.zeros_(self.k_proj.weight) nn.init.zeros_(self.v_proj.weight) + self.sp_world_size = 1 + self.sp_world_rank = 0 + self.all_gather = None + def __call__( self, attn: nn.Module, @@ -80,7 +85,14 @@ def __call__( img_x = img_x.flatten(2) if len(audio_proj.shape) == 4: - q = sequence_parallel_all_gather(q, dim=1) + if self.sp_world_size > 1: + q = self.all_gather(q, dim=1) + + length = int(np.floor(q.size()[1] / latents_num_frames) * latents_num_frames) + origin_length = q.size()[1] + if origin_length > length: + q_pad = q[:, length:] + q = q[:, :length] audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d] ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d) ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d) @@ -88,8 +100,11 @@ def __call__( audio_q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL" ) audio_x = audio_x.view(b, q.size(1), n, d) + if self.sp_world_size > 1: + if origin_length > length: + audio_x = torch.cat([audio_x, q_pad], dim=1) + audio_x = torch.chunk(audio_x, self.sp_world_size, dim=1)[self.sp_world_rank] audio_x = audio_x.flatten(2) - audio_x = sequence_parallel_chunk(audio_x, dim=1) elif len(audio_proj.shape) == 3: ip_key = self.k_proj(audio_proj).view(b, -1, n, d) ip_value = 
self.v_proj(audio_proj).view(b, -1, n, d) @@ -361,6 +376,14 @@ def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length= k_lens_list, dtype=torch.long ) + def enable_multi_gpus_inference(self,): + super().enable_multi_gpus_inference() + for name, module in self.named_modules(): + if module.__class__.__name__ == 'AudioCrossAttentionProcessor': + module.sp_world_size = self.sp_world_size + module.sp_world_rank = self.sp_world_rank + module.all_gather = self.all_gather + @cfg_skip() def forward( self, diff --git a/videox_fun/models/qwenimage_transformer2d.py b/videox_fun/models/qwenimage_transformer2d.py index 41eddf0..0f99b90 100644 --- a/videox_fun/models/qwenimage_transformer2d.py +++ b/videox_fun/models/qwenimage_transformer2d.py @@ -994,30 +994,72 @@ def from_pretrained( for key in _state_dict: state_dict[key] = _state_dict[key] + filtered_state_dict = {} + for key in state_dict: + if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): + filtered_state_dict[key] = state_dict[key] + else: + print(f"Skipping key '{key}' due to size mismatch or absence in model.") + + model_keys = set(model.state_dict().keys()) + loaded_keys = set(filtered_state_dict.keys()) + missing_keys = model_keys - loaded_keys + + def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): + initialized_dict = {} + + with torch.no_grad(): + for key in missing_keys: + param_shape = model_state_dict[key].shape + param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype + if 'weight' in key: + if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'embedding' in key or 'embed' in key: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'head' in key or 'output' in key or 'proj_out' in key: + initialized_dict[key] = 
torch.zeros(param_shape, dtype=param_dtype) + elif len(param_shape) >= 2: + initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) + nn.init.xavier_uniform_(initialized_dict[key]) + else: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'bias' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_mean' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_var' in key: + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'num_batches_tracked' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) + else: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + + return initialized_dict + + if missing_keys: + print(f"Missing keys will be initialized: {sorted(missing_keys)}") + initialized_params = initialize_missing_parameters( + missing_keys, + model.state_dict(), + torch_dtype + ) + filtered_state_dict.update(initialized_params) + if diffusers_version >= "0.33.0": # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. load_model_dict_into_meta( model, - state_dict, + filtered_state_dict, dtype=torch_dtype, model_name_or_path=pretrained_model_path, ) else: - model._convert_deprecated_attention_blocks(state_dict) - # move the params from meta device to cpu - missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_path} because the following keys are" - f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" - " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" - " those weights or else make sure your checkpoint file is correct." 
- ) - + model._convert_deprecated_attention_blocks(filtered_state_dict) unexpected_keys = load_model_dict_into_meta( model, - state_dict, + filtered_state_dict, device=param_device, dtype=torch_dtype, model_name_or_path=pretrained_model_path, diff --git a/videox_fun/models/wan_animate_adapter.py b/videox_fun/models/wan_animate_adapter.py index f2d596e..1b3dd87 100644 --- a/videox_fun/models/wan_animate_adapter.py +++ b/videox_fun/models/wan_animate_adapter.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F +import numpy as np from einops import rearrange from torch import nn @@ -337,8 +338,11 @@ def forward( motion_vec: torch.Tensor, motion_mask: Optional[torch.Tensor] = None, use_context_parallel=False, + all_gather=None, + sp_world_size=1, + sp_world_rank=0, ) -> torch.Tensor: - + dtype = x.dtype B, T, N, C = motion_vec.shape T_comp = T @@ -358,10 +362,17 @@ def forward( k = rearrange(k, "B L N H D -> (B L) N H D") v = rearrange(v, "B L N H D -> (B L) N H D") - # if use_context_parallel: - # q = gather_forward(q, dim=1) + if use_context_parallel: + q = all_gather(q, dim=1) + length = int(np.floor(q.size()[1] / T_comp) * T_comp) + origin_length = q.size()[1] + if origin_length > length: + q_pad = q[:, length:] + q = q[:, :length] + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) # Compute attention. 
attn = attention( q, @@ -372,8 +383,11 @@ ) attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) - # if use_context_parallel: - # attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] + if use_context_parallel: + if origin_length > length: + q_pad = rearrange(q_pad, "B L H D -> B L (H D)") + attn = torch.cat([attn, q_pad], dim=1) + attn = torch.chunk(attn, sp_world_size, dim=1)[sp_world_rank] output = self.linear2(attn) diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py index 4812625..36d906b 100755 --- a/videox_fun/models/wan_transformer3d.py +++ b/videox_fun/models/wan_transformer3d.py @@ -669,6 +669,7 @@ def __init__( self.current_steps = 0 self.num_inference_steps = None self.gradient_checkpointing = False + self.all_gather = None self.sp_world_size = 1 self.sp_world_rank = 0 self.init_weights() @@ -1159,30 +1160,77 @@ def from_pretrained( for key in _state_dict: state_dict[key] = _state_dict[key] + if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size(): + model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :] + model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0 + state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight'] + + filtered_state_dict = {} + for key in state_dict: + if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size(): + filtered_state_dict[key] = state_dict[key] + else: + print(f"Skipping key '{key}' due to size mismatch or absence in model.") + + model_keys = set(model.state_dict().keys()) + loaded_keys = set(filtered_state_dict.keys()) + missing_keys = model_keys - loaded_keys + + def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None): + 
initialized_dict = {} + + with torch.no_grad(): + for key in missing_keys: + param_shape = model_state_dict[key].shape + param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype + if 'weight' in key: + if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']): + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'embedding' in key or 'embed' in key: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'head' in key or 'output' in key or 'proj_out' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif len(param_shape) >= 2: + initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype) + nn.init.xavier_uniform_(initialized_dict[key]) + else: + initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02 + elif 'bias' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_mean' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + elif 'running_var' in key: + initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype) + elif 'num_batches_tracked' in key: + initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long) + else: + initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype) + + return initialized_dict + + if missing_keys: + print(f"Missing keys will be initialized: {sorted(missing_keys)}") + initialized_params = initialize_missing_parameters( + missing_keys, + model.state_dict(), + torch_dtype + ) + filtered_state_dict.update(initialized_params) + if diffusers_version >= "0.33.0": # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit: # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785. 
load_model_dict_into_meta( model, - state_dict, + filtered_state_dict, dtype=torch_dtype, model_name_or_path=pretrained_model_path, ) else: - model._convert_deprecated_attention_blocks(state_dict) - # move the params from meta device to cpu - missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_path} because the following keys are" - f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" - " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" - " those weights or else make sure your checkpoint file is correct." - ) - + model._convert_deprecated_attention_blocks(filtered_state_dict) unexpected_keys = load_model_dict_into_meta( model, - state_dict, + filtered_state_dict, device=param_device, dtype=torch_dtype, model_name_or_path=pretrained_model_path, diff --git a/videox_fun/models/wan_transformer3d_animate.py b/videox_fun/models/wan_transformer3d_animate.py index 33682b7..227a3bd 100644 --- a/videox_fun/models/wan_transformer3d_animate.py +++ b/videox_fun/models/wan_transformer3d_animate.py @@ -46,7 +46,6 @@ def __init__( cross_attn_norm=True, eps=1e-6, motion_encoder_dim=512, - use_context_parallel=False, use_img_emb=True ): model_type = "i2v" # TODO: Hard code for both preview and official versions. 
@@ -54,7 +53,6 @@ def __init__( num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) self.motion_encoder_dim = motion_encoder_dim - self.use_context_parallel = use_context_parallel self.use_img_emb = use_img_emb self.pose_patch_embedding = nn.Conv3d( @@ -100,15 +98,14 @@ def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_ motion_vec = torch.cat([pad_face, motion_vec], dim=1) return x, motion_vec - def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None): if block_idx % 5 == 0: - adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel] + use_context_parallel = self.sp_world_size > 1 + adapter_args = [x, motion_vec, motion_masks, use_context_parallel, self.all_gather, self.sp_world_size, self.sp_world_rank] residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args) x = residual_out + x return x - @cfg_skip() def forward( self, @@ -139,6 +136,8 @@ def forward( [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) x = [u.flatten(2).transpose(1, 2) for u in x] seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size assert seq_lens.max() <= seq_len x = torch.cat([ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], @@ -227,6 +226,7 @@ def custom_forward(*inputs): t, **ckpt_kwargs, ) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) x = self.after_transformer_block(idx, x, motion_vec) else: # arguments @@ -241,6 +241,7 @@ def custom_forward(*inputs): t=t ) x = block(x, **kwargs) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) x = self.after_transformer_block(idx, x, motion_vec) if cond_flag: @@ -270,6 +271,7 @@ def custom_forward(*inputs): t, **ckpt_kwargs, ) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) x = self.after_transformer_block(idx, x, motion_vec) else: # arguments @@ -284,6 +286,7 @@ def custom_forward(*inputs): t=t ) x = 
block(x, **kwargs) + x, motion_vec = x.to(dtype), motion_vec.to(dtype) x = self.after_transformer_block(idx, x, motion_vec) # head diff --git a/videox_fun/utils/lora_utils.py b/videox_fun/utils/lora_utils.py index 9b683c1..7ff11d4 100755 --- a/videox_fun/utils/lora_utils.py +++ b/videox_fun/utils/lora_utils.py @@ -158,7 +158,8 @@ def precalculate_safetensors_hashes(tensors, metadata): class LoRANetwork(torch.nn.Module): TRANSFORMER_TARGET_REPLACE_MODULE = [ "CogVideoXTransformer3DModel", "WanTransformer3DModel", \ - "Wan2_2Transformer3DModel", "FluxTransformer2DModel", "QwenImageTransformer2DModel" + "Wan2_2Transformer3DModel", "FluxTransformer2DModel", "QwenImageTransformer2DModel", \ + "Wan2_2Transformer3DModel_Animate", "Wan2_2Transformer3DModel_S2V", "FantasyTalkingTransformer3DModel", ] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"] LORA_PREFIX_TRANSFORMER = "lora_unet" @@ -173,6 +174,7 @@ def __init__( dropout: Optional[float] = None, module_class: Type[object] = LoRAModule, skip_name: str = None, + target_name: str = None, varbose: Optional[bool] = False, ) -> None: super().__init__() @@ -207,6 +209,15 @@ def create_modules( if skip_name is not None and skip_name in child_name: continue + + if target_name is not None: + target_name_in = False + if isinstance(target_name, str): + target_name_in = target_name in child_name + elif isinstance(target_name, list): + target_name_in = any([_target_name in child_name for _target_name in target_name]) + if not target_name_in: + continue if is_linear or is_conv2d: lora_name = prefix + "." + name + "." 
+ child_name @@ -349,6 +360,7 @@ def create_network( transformer, neuron_dropout: Optional[float] = None, skip_name: str = None, + target_name = None, **kwargs, ): if network_dim is None: @@ -364,6 +376,7 @@ def create_network( alpha=network_alpha, dropout=neuron_dropout, skip_name=skip_name, + target_name=target_name, varbose=True, ) return network