diff --git a/README.md b/README.md index 7d26dd5d..20414533 100644 --- a/README.md +++ b/README.md @@ -483,6 +483,14 @@ To generate images, run the following command: HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 ``` + ## Wan2.2 + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + ``` ## Flux diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index 24c7b2ff..bbad3ad1 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -61,7 +61,7 @@ def create_orbax_checkpoint_manager( if checkpoint_type == FLUX_CHECKPOINT: item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config") elif checkpoint_type == WAN_CHECKPOINT: - item_names = ("wan_state", "wan_config") + item_names = ("low_noise_transformer_state", "high_noise_transformer_state", "wan_state", "wan_config") else: item_names = ( "unet_config", diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py new file mode 100644 index 00000000..de8bb35d --- /dev/null +++ b/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py @@ -0,0 +1,207 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in 
compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from abc import ABC +import json + +import jax +import numpy as np +from typing import Optional, Tuple +from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) +from ..pipelines.wan.wan_pipeline2_2 import WanPipeline +from .. import max_logging, max_utils +import orbax.checkpoint as ocp +from etils import epath + +WAN_CHECKPOINT = "WAN_CHECKPOINT" + + +class WanCheckpointer(ABC): + + def __init__(self, config, checkpoint_type): + self.config = config + self.checkpoint_type = checkpoint_type + self.opt_state = None + + self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, + ) + + def _create_optimizer(self, model, config, learning_rate): + learning_rate_scheduler = max_utils.create_learning_rate_schedule( + learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps + ) + tx = max_utils.create_optimizer(config, learning_rate_scheduler) + return tx, learning_rate_scheduler + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + + low_noise_transformer_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + low_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_low_params, + ) + ) + + high_noise_transformer_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + high_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_params, + ) + ) + + max_logging.log("Restoring WAN checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + low_noise_transformer_state=low_params_restore, + high_noise_transformer_state=high_params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in 
restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) + # Check for optimizer state in either transformer + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), + } + + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") + + +def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + """ + only save the config that is needed and can be serialized to JSON. + """ + if not hasattr(model_or_config, "config"): + return None + source_config = dict(model_or_config.config) + + # 1. configs that can be serialized to JSON + SAFE_KEYS = [ + "_class_name", + "_diffusers_version", + "model_type", + "patch_size", + "num_attention_heads", + "attention_head_dim", + "in_channels", + "out_channels", + "text_dim", + "freq_dim", + "ffn_dim", + "num_layers", + "cross_attn_norm", + "qk_norm", + "eps", + "image_dim", + "added_kv_proj_dim", + "rope_max_seq_len", + "pos_embed_seq_len", + "flash_min_seq_length", + "flash_block_sizes", + "attention", + "_use_default_values", + ] + + # 2. save the config that are in the SAFE_KEYS list + clean_config = {} + for key in SAFE_KEYS: + if key in source_config: + clean_config[key] = source_config[key] + + # 3. 
deal with special data type and precision + if "dtype" in source_config and hasattr(source_config["dtype"], "name"): + clean_config["dtype"] = source_config["dtype"].name # e.g. 'bfloat16' + + if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"): + clean_config["weights_dtype"] = source_config["weights_dtype"].name + + if "precision" in source_config and hasattr(source_config["precision"], "name"): + clean_config["precision"] = source_config["precision"].name # e.g. 'HIGHEST' + + return clean_config + + items_to_save = { + "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), + } + + items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states) + + # Create CompositeArgs for Orbax + save_args = ocp.args.Composite(**items_to_save) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=save_args) + max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 50e66964..8dea4e3a 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -28,6 +28,7 @@ save_config_to_gcs: False log_period: 100 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' +model_name: wan2.1 # Overrides the transformer from pretrained_model_name_or_path wan_transformer_pretrained_model_name_or_path: '' diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml new file mode 100644 index 00000000..6d005bdd --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -0,0 +1,332 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. +run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.2-T2V-A14B-Diffusers' +model_name: wan2.2 + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding.
+replicate_vae: False + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: True + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 +dropout: 0.1 + +flash_block_sizes: { + "block_q" : 1024, + "block_kv_compute" : 256, + "block_kv" : 1024, + "block_q_dkv" : 1024, + "block_kv_dkv" : 1024, + "block_kv_dkv_compute" : 256, + "block_q_dq" : 1024, + "block_kv_dq" : 1024 +} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048, +# "use_fused_bwd_kernel": False, +# } +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of 'later' will increase the frequency of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. + multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoint's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g.
my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', 'data'], + ['activation_batch', 'data'], + ['activation_length', 'fsdp'], + + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed','fsdp'], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data','fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'fsdp'], + ] +data_sharding: [['data', 'fsdp', 'tensor']] + +# One axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. +# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. 
+remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 1.0 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. + +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Generation parameters +prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 480 +width: 832 +num_frames: 81 +flow_shift: 3.0 + +guidance_scale_low: 5.0 +guidance_scale_high: 8.0 +boundary_timestep: 15 + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 24 +save_final_checkpoint: False + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +# Values are lists to support multiple LoRA loading during inference in the future. 
+lora_config: { + lora_model_name_or_path: [], + weight_name: [], + adapter_name: [], + scale: [], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). + +enable_ssim: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 501dbf32..2396dfcc 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -16,9 +16,9 @@ import jax import time import os -from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline from maxdiffusion import pyconfig, max_logging, max_utils from absl import app +import importlib from maxdiffusion.utils import export_to_video from google.cloud import storage import flax @@ -63,6 +63,49 @@ def delete_file(file_path: str): jax.config.update("jax_use_shardy_partitioner", True) +def get_pipeline(model_name: str): + if model_name == "wan2.1": + return importlib.import_module("maxdiffusion.pipelines.wan.wan_pipeline") + elif model_name == "wan2.2": + return importlib.import_module("maxdiffusion.pipelines.wan.wan_pipeline2_2") + else: + raise ValueError(f"Unsupported model_name in config: {model_name}") + +def get_checkpointer(model_name: str): + if model_name == "wan2.1": + return importlib.import_module("maxdiffusion.checkpointing.wan_checkpointer") + elif model_name == "wan2.2": + return importlib.import_module("maxdiffusion.checkpointing.wan_checkpointer2_2") + else: + raise ValueError(f"Unsupported model_name in config: {model_name}") + +def call_pipeline(config, pipeline, prompt, negative_prompt): + model_key = config.model_name + if model_key == "wan2.1": + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + ) + elif model_key == "wan2.2": + return pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=config.height, + width=config.width, + num_frames=config.num_frames, + num_inference_steps=config.num_inference_steps, + 
guidance_scale_low=config.guidance_scale_low, + guidance_scale_high=config.guidance_scale_high, + boundary=config.boundary_timestep, + ) + else: + raise ValueError(f"Unsupported model_name in config: {model_key}") + def inference_generate_video(config, pipeline, filename_prefix=""): s0 = time.perf_counter() @@ -73,15 +116,7 @@ def inference_generate_video(config, pipeline, filename_prefix=""): f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}" ) - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}") for i in range(len(videos)): @@ -96,11 +131,17 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) - from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + model_key = config.model_name + + checkpointer_lib = get_checkpointer(model_key) + WanCheckpointer = checkpointer_lib.WanCheckpointer checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") - pipeline = checkpoint_loader.load_checkpoint() + pipeline, _, _ = checkpoint_loader.load_checkpoint() + if pipeline is None: + pipeline_lib = get_pipeline(model_key) + WanPipeline = pipeline_lib.WanPipeline pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() @@ -112,15 +153,7 @@ def run(config, pipeline=None, filename_prefix=""): f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) print("compile time: ", (time.perf_counter() - s0)) saved_video_path = [] @@ -132,29 +165,13 @@ def run(config, pipeline=None, filename_prefix=""): upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) s0 = time.perf_counter() - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) print("generation time: ", (time.perf_counter() - s0)) s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) - videos = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - height=config.height, - width=config.width, - num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, - guidance_scale=config.guidance_scale, - ) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) max_utils.deactivate_profiler(config) print("generation time: ", (time.perf_counter() - s0)) return saved_video_path diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index ec97abd3..191d8b61 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ 
b/src/maxdiffusion/models/wan/wan_utils.py @@ -184,6 +184,7 @@ def load_wan_transformer( hf_download: bool = True, num_layers: int = 40, scan_layers: bool = True, + subfolder: str = "", ): if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: @@ -192,7 +193,7 @@ def load_wan_transformer( return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers) else: return load_base_wan_transformer( - pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers + pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder ) @@ -203,9 +204,9 @@ def load_base_wan_transformer( hf_download: bool = True, num_layers: int = 40, scan_layers: bool = True, + subfolder: str = "", ): device = jax.local_devices(backend=device)[0] - subfolder = "transformer" filename = "diffusion_pytorch_model.safetensors.index.json" local_files = False if os.path.isdir(pretrained_model_name_or_path): @@ -236,7 +237,7 @@ def load_base_wan_transformer( else: ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) # now get all the filenames for the model that need downloading - max_logging.log(f"Load and port Wan 2.1 transformer on {device}") + max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}") if ckpt_shard_path is not None: with safe_open(ckpt_shard_path, framework="pt") as f: @@ -281,7 +282,7 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: raise FileNotFoundError(f"File {ckpt_path} not found for local directory.") elif hf_download: ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename) - max_logging.log(f"Load and port Wan 2.1 VAE on {device}") + max_logging.log(f"Load and port {pretrained_model_name_or_path} VAE on {device}") with jax.default_device(device): if ckpt_path is not None: tensors = {} diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 55981be0..cccc7eff 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -89,7 +89,7 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. 
def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None + devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" ): def create_model(rngs: nnx.Rngs, wan_config: dict): @@ -142,6 +142,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): "cpu", num_layers=wan_config["num_layers"], scan_layers=config.scan_layers, + subfolder=subfolder, ) params = jax.tree_util.tree_map_with_path( @@ -353,11 +354,10 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline @classmethod def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None - ): + cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): with mesh: wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder ) return wan_transformer @@ -385,7 +385,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ if load_transformer: with mesh: transformer = cls.load_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer" ) text_encoder = cls.load_text_encoder(config=config) @@ -423,7 +423,7 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform if not vae_only: if load_transformer: with mesh: - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py new file mode 100644 index 00000000..0645aeeb --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py @@ -0,0 +1,725 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union, Optional +from functools import partial +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import flax +import flax.linen as nn +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from ...pyconfig import HyperParameters +from ... import max_logging +from ... 
import max_utils +from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated +from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae +from ...models.wan.transformers.transformer_wan import WanModel +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache +from maxdiffusion.video_processor import VideoProcessor +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState +from transformers import AutoTokenizer, UMT5EncoderModel +from maxdiffusion.utils.import_utils import is_ftfy_available +from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs +import html +import re +import torch +import qwix + + +def cast_with_exclusion(path, x, dtype_to_cast): + """ + Casts arrays to dtype_to_cast, but keeps params from any 'norm' layer in float32. + """ + + exclusion_keywords = [ + "norm", # For all LayerNorm/GroupNorm layers + "condition_embedder", # The entire time/text conditioning module + "scale_shift_table", # Catches both the final and the AdaLN tables + ] + + path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path) + + if any(keyword in path_str.lower() for keyword in exclusion_keywords): + print("is_norm_path: ", path) + # Keep LayerNorm/GroupNorm weights and biases in full precision + return x.astype(jnp.float32) + else: + # Cast everything else to dtype_to_cast + return x.astype(dtype_to_cast) + + +def basic_clean(text): + if is_ftfy_available(): + import ftfy + + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState: + vs.sharding_rules = logical_axis_rules + return vs + + +# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. +def create_sharded_logical_transformer( + devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" +): + + def create_model(rngs: nnx.Rngs, wan_config: dict): + wan_transformer = WanModel(**wan_config, rngs=rngs) + return wan_transformer + + # 1. Load config. + if restored_checkpoint: + wan_config = restored_checkpoint["wan_config"] + else: + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) + wan_config["mesh"] = mesh + wan_config["dtype"] = config.activations_dtype + wan_config["weights_dtype"] = config.weights_dtype + wan_config["attention"] = config.attention + wan_config["precision"] = get_precision(config) + wan_config["flash_block_sizes"] = get_flash_block_sizes(config) + wan_config["remat_policy"] = config.remat_policy + wan_config["names_which_can_be_saved"] = config.names_which_can_be_saved + wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded + wan_config["flash_min_seq_length"] = config.flash_min_seq_length + wan_config["dropout"] = config.dropout + wan_config["scan_layers"] = config.scan_layers + + # 2. eval_shape - will not use flops or create weights on device + # thus not using HBM memory. 
+ p_model_factory = partial(create_model, wan_config=wan_config) + wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) + graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) + + # 3. retrieve the state shardings, mapping logical names to mesh axis names. + logical_state_spec = nnx.get_partition_spec(state) + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) + params = state.to_pure_dict() + state = dict(nnx.to_flat_state(state)) + + # 4. Load pretrained weights and move them to device using the state shardings from (3) above. + # This helps with loading sharded weights directly into the accelerators without first copying them + # all to one device and then distributing them, thus using low HBM memory. + if restored_checkpoint: + if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer + params = restored_checkpoint["wan_state"]["params"] + else: # if not checkpointed with optimizer + params = restored_checkpoint["wan_state"] + else: + params = load_wan_transformer( + config.wan_transformer_pretrained_model_name_or_path, + params, + "cpu", + num_layers=wan_config["num_layers"], + scan_layers=config.scan_layers, + subfolder=subfolder, + ) + + params = jax.tree_util.tree_map_with_path( + lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params + ) + for path, val in flax.traverse_util.flatten_dict(params).items(): + if restored_checkpoint: + path = path[:-1] + sharding = logical_state_sharding[path].value + state[path].value = device_put_replicated(val, sharding) + state = nnx.from_flat_state(state) + + wan_transformer = nnx.merge(graphdef, state, rest_of_state) + return wan_transformer + + +@nnx.jit(static_argnums=(1,), donate_argnums=(0,)) +def create_sharded_logical_model(model, logical_axis_rules): + graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) + p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules) + state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + pspecs = nnx.get_partition_spec(state) + sharded_state = jax.lax.with_sharding_constraint(state, pspecs) + model = nnx.merge(graphdef, sharded_state, rest_of_state) + return model + + +class WanPipeline: + r""" + Pipeline for text-to-video generation using Wan. + + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + low_noise_transformer ([`WanModel`]): + Conditional Transformer to denoise the input latents at low-noise timesteps. + high_noise_transformer ([`WanModel`]): + Conditional Transformer to denoise the input latents at high-noise timesteps. + scheduler ([`FlaxUniPCMultistepScheduler`]): + A scheduler to be used in combination with the transformers to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """ + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + low_noise_transformer: WanModel, + high_noise_transformer: WanModel, + vae: AutoencoderKLWan, + vae_cache: AutoencoderKLWanCache, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state: UniPCMultistepSchedulerState, + devices_array: np.array, + mesh: Mesh, + config: HyperParameters, + ): + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer + self.vae = vae + self.vae_cache = vae_cache + self.scheduler = scheduler + self.scheduler_state = scheduler_state + self.devices_array = devices_array + self.mesh = mesh + self.config = config + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.p_run_inference = None + + @classmethod + def load_text_encoder(cls, config: HyperParameters): + text_encoder = UMT5EncoderModel.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="text_encoder", + ) + return text_encoder + + @classmethod + def load_tokenizer(cls, config: HyperParameters): + tokenizer = AutoTokenizer.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="tokenizer", + ) + return tokenizer + + @classmethod + def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): + + def create_model(rngs: nnx.Rngs, config: HyperParameters): + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh, + dtype=jnp.float32, + weights_dtype=jnp.float32, + ) + return wan_vae + + # 1. eval shape + p_model_factory = partial(create_model, config=config) + wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) + graphdef, state = nnx.split(wan_vae, nnx.Param) + + # 2. retrieve the state shardings, mapping logical names to mesh axis names. + logical_state_spec = nnx.get_partition_spec(state) + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) + params = state.to_pure_dict() + state = dict(nnx.to_flat_state(state)) + + # 4. Load pretrained weights and move them to device using the state shardings from (3) above. + # This helps with loading sharded weights directly into the accelerators without fist copying them + # all to one device and then distributing them, thus using low HBM memory. 
+ params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) + for path, val in flax.traverse_util.flatten_dict(params).items(): + sharding = logical_state_sharding[path].value + if config.replicate_vae: + sharding = NamedSharding(mesh, P()) + state[path].value = device_put_replicated(val, sharding) + state = nnx.from_flat_state(state) + + wan_vae = nnx.merge(graphdef, state) + vae_cache = AutoencoderKLWanCache(wan_vae) + return wan_vae, vae_cache + + @classmethod + def get_basic_config(cls, dtype, config: HyperParameters): + rules = [ + qwix.QtRule( + module_path=config.qwix_module_path, + weight_qtype=dtype, + act_qtype=dtype, + op_names=("dot_general", "einsum", "conv_general_dilated"), + ) + ] + return rules + + @classmethod + def get_fp8_config(cls, config: HyperParameters): + """ + fp8 config rules with per-tensor calibration. + FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api): + The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice. + """ + rules = [ + qwix.QtRule( + module_path=config.qwix_module_path, + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method=config.quantization_calibration_method, + act_calibration_method=config.quantization_calibration_method, + bwd_calibration_method=config.quantization_calibration_method, + op_names=("dot_general", "einsum"), + ), + qwix.QtRule( + module_path=config.qwix_module_path, + weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e4m3fn, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method=config.quantization_calibration_method, + act_calibration_method=config.quantization_calibration_method, + bwd_calibration_method=config.quantization_calibration_method, + op_names=("conv_general_dilated"), + ), + ] + return rules + + @classmethod + def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: + """Get quantization rules based on the config.""" + if not getattr(config, "use_qwix_quantization", False): + return None + + match config.quantization: + case "int8": + return qwix.QtProvider(cls.get_basic_config(jnp.int8, config)) + case "fp8": + return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config)) + case "fp8_full": + return qwix.QtProvider(cls.get_fp8_config(config)) + return None + + @classmethod + def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh): + """Quantizes the transformer model.""" + q_rules = cls.get_qt_provider(config) + if not q_rules: + return model + max_logging.log("Quantizing transformer with Qwix.") + + batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32) + latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) + model_inputs = (latents, timesteps, prompt_embeds) + with mesh: + quantized_model = qwix.quantize_model(model, q_rules, *model_inputs) + max_logging.log("Qwix Quantization complete.") + return quantized_model + + @classmethod + def load_transformer( + cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, 
subfolder="transformer"): + with mesh: + wan_transformer = create_sharded_logical_transformer( + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder + ) + return wan_transformer + + @classmethod + def load_scheduler(cls, config): + scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( + config.pretrained_model_name_or_path, + subfolder="scheduler", + flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p + ) + return scheduler, scheduler_state + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + low_noise_transformer = None + high_noise_transformer = None + tokenizer = None + scheduler = None + scheduler_state = None + text_encoder = None + if not vae_only: + if load_transformer: + with mesh: + low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") + high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") + + text_encoder = cls.load_text_encoder(config=config) + tokenizer = cls.load_tokenizer(config=config) + + scheduler, scheduler_state = cls.load_scheduler(config=config) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + return WanPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=wan_vae, + vae_cache=vae_cache, + scheduler=scheduler, + scheduler_state=scheduler_state, + devices_array=devices_array, + mesh=mesh, + config=config, + ) + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + low_noise_transformer = None + high_noise_transformer = None + tokenizer = None + scheduler = None + scheduler_state = None + text_encoder = None + if not vae_only: + if load_transformer: + with mesh: + low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") + high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") + text_encoder = cls.load_text_encoder(config=config) + tokenizer = cls.load_tokenizer(config=config) + + scheduler, scheduler_state = cls.load_scheduler(config=config) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + pipeline = WanPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=wan_vae, + vae_cache=vae_cache, + scheduler=scheduler, + scheduler_state=scheduler_state, + devices_array=devices_array, + mesh=mesh, + config=config, + ) + + pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) + pipeline.high_noise_transformer 
= cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) + return pipeline + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32) + + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size: int, + vae_scale_factor_temporal: int, + vae_scale_factor_spatial: int, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_channels_latents: int = 16, + ): + rng = jax.random.key(self.config.seed) + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // vae_scale_factor_spatial, + int(width) // vae_scale_factor_spatial, + ) + latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) + + return latents + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): + if not 
vae_only: + if num_frames % self.vae_scale_factor_temporal != 1: + max_logging.log( + f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + prompt = [prompt] + + batch_size = len(prompt) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + max_sequence_length=max_sequence_length, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + num_channel_latents = self.low_noise_transformer.config.in_channels + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + vae_scale_factor_temporal=self.vae_scale_factor_temporal, + vae_scale_factor_spatial=self.vae_scale_factor_spatial, + height=height, + width=width, + num_frames=num_frames, + num_channels_latents=num_channel_latents, + ) + + data_sharding = NamedSharding(self.mesh, P()) + # Using global_batch_size_to_train_on so not to create more config variables + if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + + latents = jax.device_put(latents, data_sharding) + prompt_embeds = jax.device_put(prompt_embeds, data_sharding) + negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) + + scheduler_state = self.scheduler.set_timesteps( + self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + ) + + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) 
+ + p_run_inference = partial( + run_inference, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(jnp.float32) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] + + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + video = self.video_processor.postprocess_video(video, output_type="np") + return video + + +@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) +def transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance, + guidance_scale, +): + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) + if do_classifier_free_guidance: + bsz = latents.shape[0] // 2 + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] + + return noise_pred, latents + +def run_inference( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + def low_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + low_noise_graphdef, low_noise_state, low_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_low + ) + + def high_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + high_noise_graphdef, high_noise_state, high_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_high + ) + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, latents.shape[0]) + use_high_noise = jnp.greater_equal(t, boundary) + + noise_pred, 
latents = jax.lax.cond( + use_high_noise, + high_noise_branch, + low_noise_branch, + (latents, timestep, prompt_embeds) + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/tests/wan_checkpointer2_2_test.py b/src/maxdiffusion/tests/wan_checkpointer2_2_test.py new file mode 100644 index 00000000..8e1fa0be --- /dev/null +++ b/src/maxdiffusion/tests/wan_checkpointer2_2_test.py @@ -0,0 +1,113 @@ +""" + Copyright 2025 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + +import unittest +from unittest.mock import patch, MagicMock + +from maxdiffusion.checkpointing.wan_checkpointer2_2 import WanCheckpointer, WAN_CHECKPOINT + + +class WanCheckpointerTest(unittest.TestCase): + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_test" + self.config.dataset_type = "test_dataset" + + @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") + def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = None + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertIsNone(step) + + @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") + def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + 
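+    # The restored Composite is expected to be handed to WanPipeline.from_checkpoint unchanged;
+    # since neither transformer state in this test carries an "opt_state" entry, load_checkpoint
+    # should report opt_state as None.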
mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertEqual(step, 1) + + @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") + def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + self.assertEqual(step, 1) + + +if __name__ == "__main__": + unittest.main()
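+
+# Note: the module can also be run directly (assuming the maxdiffusion package is importable
+# in the current environment): python src/maxdiffusion/tests/wan_checkpointer2_2_test.py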