2020from maxdiffusion .pipelines .wan .wan_pipeline2_2 import WanPipeline as WanPipeline2_2
2121from maxdiffusion import pyconfig , max_logging , max_utils
2222from absl import app
23- from absl import flags
2423from maxdiffusion .utils import export_to_video
2524from google .cloud import storage
2625import flax
2726
28- _MODEL_NAME = flags .DEFINE_enum (
29- "model_name" ,
30- default = "wan2.1" ,
31- enum_values = ["wan2.1" , "wan2.2" ],
32- help = "The model version to run (wan2.1 or wan2.2). This determines the base config file." ,
33- )
34-
35- CONFIG_BASE_DIR = "src/maxdiffusion/configs"
36- MODEL_CONFIG_MAP = {
37- "wan2.1" : "base_wan_14b.yml" ,
38- "wan2.2" : "base_wan_27b.yml" ,
39- }
4027
4128def upload_video_to_gcs (output_dir : str , video_path : str ):
4229 """
@@ -77,18 +64,26 @@ def delete_file(file_path: str):
7764
# Enable the Shardy partitioner so JAX uses the newer sharding-propagation system.
jax.config.update("jax_use_shardy_partitioner", True)
def get_pipeline(model_name: str):
  """Return the WAN pipeline module for the given model version.

  Args:
    model_name: model version key from the config ("wan2.1" or "wan2.2").

  Raises:
    ValueError: if *model_name* is not a supported model version.
  """
  module_by_model = {
      "wan2.1": "maxdiffusion.pipelines.wan.wan_pipeline",
      "wan2.2": "maxdiffusion.pipelines.wan.wan_pipeline2_2",
  }
  module_path = module_by_model.get(model_name)
  if module_path is None:
    raise ValueError(f"Unsupported model_name in config: {model_name}")
  return importlib.import_module(module_path)
8074
81- def inference_generate_video (config , pipeline , filename_prefix = "" ):
82- s0 = time .perf_counter ()
83- prompt = [config .prompt ] * config .global_batch_size_to_train_on
84- negative_prompt = [config .negative_prompt ] * config .global_batch_size_to_train_on
def get_checkpointer(model_name: str):
  """Return the WAN checkpointer module for the given model version.

  Args:
    model_name: model version key from the config ("wan2.1" or "wan2.2").

  Raises:
    ValueError: if *model_name* is not a supported model version.
  """
  module_by_model = {
      "wan2.1": "maxdiffusion.checkpointing.wan_checkpointer",
      "wan2.2": "maxdiffusion.checkpointing.wan_checkpointer2_2",
  }
  module_path = module_by_model.get(model_name)
  if module_path is None:
    raise ValueError(f"Unsupported model_name in config: {model_name}")
  return importlib.import_module(module_path)
8582
86- max_logging .log (
87- f"Num steps: { config .num_inference_steps } , height: { config .height } , width: { config .width } , frames: { config .num_frames } , video: { filename_prefix } "
88- )
89- model_key = _MODEL_NAME .value
def call_pipeline(config, pipeline, prompt, negative_prompt):
  """Invoke *pipeline* with the generation arguments for the configured model.

  wan2.1 takes a single ``guidance_scale``; wan2.2 takes low/high guidance
  scales plus a boundary timestep for its two-expert denoising schedule.

  Args:
    config: loaded pyconfig object; ``config.model_name`` selects the variant.
    pipeline: a callable WAN pipeline instance.
    prompt: list of prompt strings (one per batch element).
    negative_prompt: list of negative prompt strings.

  Returns:
    Whatever the pipeline call returns (the generated videos).

  Raises:
    ValueError: if ``config.model_name`` is not a supported model version.
  """
  model_key = config.model_name
  # Reject unknown models before touching any other config attributes.
  if model_key not in ("wan2.1", "wan2.2"):
    raise ValueError(f"Unsupported model_name in config: {model_key}")

  kwargs = dict(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
  )
  if model_key == "wan2.1":
    kwargs["guidance_scale"] = config.guidance_scale
  else:  # wan2.2
    kwargs.update(
        guidance_scale_low=config.guidance_scale_low,
        guidance_scale_high=config.guidance_scale_high,
        boundary=config.boundary_timestep,
    )
  return pipeline(**kwargs)
109+
110+
111+ def inference_generate_video (config , pipeline , filename_prefix = "" ):
112+ s0 = time .perf_counter ()
113+ prompt = [config .prompt ] * config .global_batch_size_to_train_on
114+ negative_prompt = [config .negative_prompt ] * config .global_batch_size_to_train_on
115+
116+ max_logging .log (
117+ f"Num steps: { config .num_inference_steps } , height: { config .height } , width: { config .width } , frames: { config .num_frames } , video: { filename_prefix } "
118+ )
119+
120+ videos = call_pipeline (config , pipeline , prompt , negative_prompt )
112121
113122 max_logging .log (f"video { filename_prefix } , compile time: { (time .perf_counter () - s0 )} " )
114123 for i in range (len (videos )):
@@ -123,20 +132,18 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
123132
124133def run (config , pipeline = None , filename_prefix = "" ):
125134 print ("seed: " , config .seed )
126- model_key = _MODEL_NAME . value
135+ model_key = config . model_name
127136
128- if model_key == "wan2.1" :
129- from maxdiffusion .checkpointing .wan_checkpointer import WanCheckpointer
130- elif model_key == "wan2.2" :
131- from maxdiffusion .checkpointing .wan_checkpointer2_2 import WanCheckpointer
137+ checkpointer_lib = get_checkpointer (model_key )
138+ WanCheckpointer = checkpointer_lib .WanCheckpointer
132139
133140 checkpoint_loader = WanCheckpointer (config , "WAN_CHECKPOINT" )
134- pipeline , opt_state , step = checkpoint_loader .load_checkpoint ()
141+ pipeline , _ , _ = checkpoint_loader .load_checkpoint ()
142+
135143 if pipeline is None :
136- if model_key == "wan2.1" :
137- pipeline = WanPipeline2_1 .from_pretrained (config )
138- elif model_key == "wan2.2" :
139- pipeline = WanPipeline2_2 .from_pretrained (config )
144+ pipeline_lib = get_pipeline (model_key )
145+ WanPipeline = pipeline_lib .WanPipeline
146+ pipeline = WanPipeline .from_pretrained (config )
140147 s0 = time .perf_counter ()
141148
142149 # Using global_batch_size_to_train_on so not to create more config variables
@@ -146,28 +153,8 @@ def run(config, pipeline=None, filename_prefix=""):
146153 max_logging .log (
147154 f"Num steps: { config .num_inference_steps } , height: { config .height } , width: { config .width } , frames: { config .num_frames } "
148155 )
149- if model_key == "wan2.1" :
150- videos = pipeline (
151- prompt = prompt ,
152- negative_prompt = negative_prompt ,
153- height = config .height ,
154- width = config .width ,
155- num_frames = config .num_frames ,
156- num_inference_steps = config .num_inference_steps ,
157- guidance_scale = config .guidance_scale ,
158- )
159- elif model_key == "wan2.2" :
160- videos = pipeline (
161- prompt = prompt ,
162- negative_prompt = negative_prompt ,
163- height = config .height ,
164- width = config .width ,
165- num_frames = config .num_frames ,
166- num_inference_steps = config .num_inference_steps ,
167- guidance_scale_low = config .guidance_scale_low ,
168- guidance_scale_high = config .guidance_scale_high ,
169- boundary = config .boundary_timestep ,
170- )
156+
157+ videos = call_pipeline (config , pipeline , prompt , negative_prompt )
171158
172159 print ("compile time: " , (time .perf_counter () - s0 ))
173160 saved_video_path = []
@@ -179,78 +166,23 @@ def run(config, pipeline=None, filename_prefix=""):
179166 upload_video_to_gcs (os .path .join (config .output_dir , config .run_name ), video_path )
180167
181168 s0 = time .perf_counter ()
182- if model_key == "wan2.1" :
183- videos = pipeline (
184- prompt = prompt ,
185- negative_prompt = negative_prompt ,
186- height = config .height ,
187- width = config .width ,
188- num_frames = config .num_frames ,
189- num_inference_steps = config .num_inference_steps ,
190- guidance_scale = config .guidance_scale ,
191- )
192- elif model_key == "wan2.2" :
193- videos = pipeline (
194- prompt = prompt ,
195- negative_prompt = negative_prompt ,
196- height = config .height ,
197- width = config .width ,
198- num_frames = config .num_frames ,
199- num_inference_steps = config .num_inference_steps ,
200- guidance_scale_low = config .guidance_scale_low ,
201- guidance_scale_high = config .guidance_scale_high ,
202- boundary = config .boundary_timestep ,
203- )
169+ videos = call_pipeline (config , pipeline , prompt , negative_prompt )
204170 print ("generation time: " , (time .perf_counter () - s0 ))
205171
206172 s0 = time .perf_counter ()
207173 if config .enable_profiler :
208174 max_utils .activate_profiler (config )
209- if model_key == "wan2.1" :
210- videos = pipeline (
211- prompt = prompt ,
212- negative_prompt = negative_prompt ,
213- height = config .height ,
214- width = config .width ,
215- num_frames = config .num_frames ,
216- num_inference_steps = config .num_inference_steps ,
217- guidance_scale = config .guidance_scale ,
218- )
219- elif model_key == "wan2.2" :
220- videos = pipeline (
221- prompt = prompt ,
222- negative_prompt = negative_prompt ,
223- height = config .height ,
224- width = config .width ,
225- num_frames = config .num_frames ,
226- num_inference_steps = config .num_inference_steps ,
227- guidance_scale_low = config .guidance_scale_low ,
228- guidance_scale_high = config .guidance_scale_high ,
229- boundary = config .boundary_timestep ,
230- )
175+ videos = call_pipeline (config , pipeline , prompt , negative_prompt )
231176 max_utils .deactivate_profiler (config )
232177 print ("generation time: " , (time .perf_counter () - s0 ))
233178 return saved_video_path
234179
235180
def main(argv: Sequence[str]) -> None:
  """Entry point: initialize the config from argv and run video generation.

  argv[1] is expected to be the YAML config path; remaining args override
  individual config values (handled by pyconfig).
  """
  pyconfig.initialize(argv)
  flax.config.update("flax_always_shard_variable", False)
  run(pyconfig.config)
253185
254186
if __name__ == "__main__":
  # absl handles flag parsing and passes the remaining argv to main().
  app.run(main)