Skip to content

Commit cd16df2

Browse files
committed
removed model_name flag
1 parent f414753 commit cd16df2

File tree

1 file changed

+46
-114
lines changed

1 file changed

+46
-114
lines changed

src/maxdiffusion/generate_wan.py

Lines changed: 46 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,10 @@
2020
from maxdiffusion.pipelines.wan.wan_pipeline2_2 import WanPipeline as WanPipeline2_2
2121
from maxdiffusion import pyconfig, max_logging, max_utils
2222
from absl import app
23-
from absl import flags
2423
from maxdiffusion.utils import export_to_video
2524
from google.cloud import storage
2625
import flax
2726

28-
_MODEL_NAME = flags.DEFINE_enum(
29-
"model_name",
30-
default="wan2.1",
31-
enum_values=["wan2.1", "wan2.2"],
32-
help="The model version to run (wan2.1 or wan2.2). This determines the base config file.",
33-
)
34-
35-
CONFIG_BASE_DIR = "src/maxdiffusion/configs"
36-
MODEL_CONFIG_MAP = {
37-
"wan2.1": "base_wan_14b.yml",
38-
"wan2.2": "base_wan_27b.yml",
39-
}
4027

4128
def upload_video_to_gcs(output_dir: str, video_path: str):
4229
"""
@@ -77,18 +64,26 @@ def delete_file(file_path: str):
7764

7865
jax.config.update("jax_use_shardy_partitioner", True)
7966

67+
def get_pipeline(model_name: str):
  """Return the pipeline module for the given WAN model version.

  Args:
    model_name: The model version string from the config ("wan2.1" or "wan2.2").

  Returns:
    The imported pipeline module, which exposes a ``WanPipeline`` class.

  Raises:
    ValueError: If ``model_name`` is not a supported WAN version.
  """
  # Import lazily so only the requested version's module is loaded.
  if model_name == "wan2.1":
    return importlib.import_module("maxdiffusion.pipelines.wan.wan_pipeline")
  elif model_name == "wan2.2":
    return importlib.import_module("maxdiffusion.pipelines.wan.wan_pipeline2_2")
  else:
    raise ValueError(f"Unsupported model_name in config: {model_name}")
8074

81-
def inference_generate_video(config, pipeline, filename_prefix=""):
82-
s0 = time.perf_counter()
83-
prompt = [config.prompt] * config.global_batch_size_to_train_on
84-
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
75+
def get_checkpointer(model_name: str):
  """Return the checkpointer module for the given WAN model version.

  Args:
    model_name: The model version string from the config ("wan2.1" or "wan2.2").

  Returns:
    The imported checkpointing module, which exposes a ``WanCheckpointer`` class.

  Raises:
    ValueError: If ``model_name`` is not a supported WAN version.
  """
  # Import lazily so only the requested version's module is loaded.
  if model_name == "wan2.1":
    return importlib.import_module("maxdiffusion.checkpointing.wan_checkpointer")
  elif model_name == "wan2.2":
    return importlib.import_module("maxdiffusion.checkpointing.wan_checkpointer2_2")
  else:
    raise ValueError(f"Unsupported model_name in config: {model_name}")
8582

86-
max_logging.log(
87-
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}"
88-
)
89-
model_key = _MODEL_NAME.value
83+
def call_pipeline(config, pipeline, prompt, negative_prompt):
  """Invoke the WAN pipeline with version-appropriate guidance arguments.

  wan2.1 takes a single ``guidance_scale``; wan2.2 takes low/high guidance
  scales plus a boundary timestep. All other generation parameters are
  shared between the two versions.

  Args:
    config: Config object providing model_name, height, width, num_frames,
      num_inference_steps and the version-specific guidance settings.
    pipeline: The loaded WAN pipeline callable.
    prompt: Batched list of prompts.
    negative_prompt: Batched list of negative prompts.

  Returns:
    Whatever the pipeline returns (the generated videos).

  Raises:
    ValueError: If ``config.model_name`` is not a supported WAN version.
  """
  # Parameters common to both model versions; version branches below only
  # add their guidance-specific kwargs.
  common_kwargs = dict(
      prompt=prompt,
      negative_prompt=negative_prompt,
      height=config.height,
      width=config.width,
      num_frames=config.num_frames,
      num_inference_steps=config.num_inference_steps,
  )
  model_key = config.model_name
  if model_key == "wan2.1":
    return pipeline(
        **common_kwargs,
        guidance_scale=config.guidance_scale,
    )
  elif model_key == "wan2.2":
    return pipeline(
        **common_kwargs,
        guidance_scale_low=config.guidance_scale_low,
        guidance_scale_high=config.guidance_scale_high,
        boundary=config.boundary_timestep,
    )
  else:
    raise ValueError(f"Unsupported model_name in config: {model_key}")
109+
110+
111+
def inference_generate_video(config, pipeline, filename_prefix=""):
112+
s0 = time.perf_counter()
113+
prompt = [config.prompt] * config.global_batch_size_to_train_on
114+
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
115+
116+
max_logging.log(
117+
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}, video: {filename_prefix}"
118+
)
119+
120+
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
112121

113122
max_logging.log(f"video {filename_prefix}, compile time: {(time.perf_counter() - s0)}")
114123
for i in range(len(videos)):
@@ -123,20 +132,18 @@ def inference_generate_video(config, pipeline, filename_prefix=""):
123132

124133
def run(config, pipeline=None, filename_prefix=""):
125134
print("seed: ", config.seed)
126-
model_key = _MODEL_NAME.value
135+
model_key = config.model_name
127136

128-
if model_key == "wan2.1":
129-
from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
130-
elif model_key == "wan2.2":
131-
from maxdiffusion.checkpointing.wan_checkpointer2_2 import WanCheckpointer
137+
checkpointer_lib = get_checkpointer(model_key)
138+
WanCheckpointer = checkpointer_lib.WanCheckpointer
132139

133140
checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT")
134-
pipeline, opt_state, step = checkpoint_loader.load_checkpoint()
141+
pipeline, _, _ = checkpoint_loader.load_checkpoint()
142+
135143
if pipeline is None:
136-
if model_key == "wan2.1":
137-
pipeline = WanPipeline2_1.from_pretrained(config)
138-
elif model_key == "wan2.2":
139-
pipeline = WanPipeline2_2.from_pretrained(config)
144+
pipeline_lib = get_pipeline(model_key)
145+
WanPipeline = pipeline_lib.WanPipeline
146+
pipeline = WanPipeline.from_pretrained(config)
140147
s0 = time.perf_counter()
141148

142149
# Using global_batch_size_to_train_on so not to create more config variables
@@ -146,28 +153,8 @@ def run(config, pipeline=None, filename_prefix=""):
146153
max_logging.log(
147154
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
148155
)
149-
if model_key == "wan2.1":
150-
videos = pipeline(
151-
prompt=prompt,
152-
negative_prompt=negative_prompt,
153-
height=config.height,
154-
width=config.width,
155-
num_frames=config.num_frames,
156-
num_inference_steps=config.num_inference_steps,
157-
guidance_scale=config.guidance_scale,
158-
)
159-
elif model_key == "wan2.2":
160-
videos = pipeline(
161-
prompt=prompt,
162-
negative_prompt=negative_prompt,
163-
height=config.height,
164-
width=config.width,
165-
num_frames=config.num_frames,
166-
num_inference_steps=config.num_inference_steps,
167-
guidance_scale_low=config.guidance_scale_low,
168-
guidance_scale_high=config.guidance_scale_high,
169-
boundary=config.boundary_timestep,
170-
)
156+
157+
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
171158

172159
print("compile time: ", (time.perf_counter() - s0))
173160
saved_video_path = []
@@ -179,78 +166,23 @@ def run(config, pipeline=None, filename_prefix=""):
179166
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
180167

181168
s0 = time.perf_counter()
182-
if model_key == "wan2.1":
183-
videos = pipeline(
184-
prompt=prompt,
185-
negative_prompt=negative_prompt,
186-
height=config.height,
187-
width=config.width,
188-
num_frames=config.num_frames,
189-
num_inference_steps=config.num_inference_steps,
190-
guidance_scale=config.guidance_scale,
191-
)
192-
elif model_key == "wan2.2":
193-
videos = pipeline(
194-
prompt=prompt,
195-
negative_prompt=negative_prompt,
196-
height=config.height,
197-
width=config.width,
198-
num_frames=config.num_frames,
199-
num_inference_steps=config.num_inference_steps,
200-
guidance_scale_low=config.guidance_scale_low,
201-
guidance_scale_high=config.guidance_scale_high,
202-
boundary=config.boundary_timestep,
203-
)
169+
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
204170
print("generation time: ", (time.perf_counter() - s0))
205171

206172
s0 = time.perf_counter()
207173
if config.enable_profiler:
208174
max_utils.activate_profiler(config)
209-
if model_key == "wan2.1":
210-
videos = pipeline(
211-
prompt=prompt,
212-
negative_prompt=negative_prompt,
213-
height=config.height,
214-
width=config.width,
215-
num_frames=config.num_frames,
216-
num_inference_steps=config.num_inference_steps,
217-
guidance_scale=config.guidance_scale,
218-
)
219-
elif model_key == "wan2.2":
220-
videos = pipeline(
221-
prompt=prompt,
222-
negative_prompt=negative_prompt,
223-
height=config.height,
224-
width=config.width,
225-
num_frames=config.num_frames,
226-
num_inference_steps=config.num_inference_steps,
227-
guidance_scale_low=config.guidance_scale_low,
228-
guidance_scale_high=config.guidance_scale_high,
229-
boundary=config.boundary_timestep,
230-
)
175+
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
231176
max_utils.deactivate_profiler(config)
232177
print("generation time: ", (time.perf_counter() - s0))
233178
return saved_video_path
234179

235180

236181
def main(argv: Sequence[str]) -> None:
  """CLI entry point: initialize the config from argv and run generation.

  The config YAML path and any overrides are passed straight through argv;
  the model version is selected via ``model_name`` inside the config file
  rather than a separate command-line flag.
  """
  pyconfig.initialize(argv)
  # Disable flax's always-shard-variable behavior before building the
  # pipeline (sharding is handled by the pipeline/checkpointer setup).
  flax.config.update("flax_always_shard_variable", False)
  run(pyconfig.config)
253185

254186

255187
if __name__ == "__main__":
  # absl's app.run handles flag parsing and forwards remaining argv to main.
  app.run(main)

0 commit comments

Comments
 (0)