Skip to content

Commit 95586a6

Browse files
committed
new line added
1 parent 9b44b5e commit 95586a6

File tree

4 files changed

+4
-6
lines changed

4 files changed

+4
-6
lines changed

src/maxdiffusion/checkpointing/wan_checkpointer2_2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from ..pipelines.wan.wan_pipeline2_2 import WanPipeline
2525
from .. import max_logging, max_utils
2626
import orbax.checkpoint as ocp
27-
from etils import epath
2827

2928
WAN_CHECKPOINT = "WAN_CHECKPOINT"
3029

@@ -206,4 +205,4 @@ def config_to_json(model_or_config):
206205

207206
# Save the checkpoint
208207
self.checkpoint_manager.save(train_step, args=save_args)
209-
max_logging.log(f"Checkpoint for step {train_step} saved.")
208+
max_logging.log(f"Checkpoint for step {train_step} saved.")

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import os
1919
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline as WanPipeline2_1
2020
from maxdiffusion.pipelines.wan.wan_pipeline2_2 import WanPipeline as WanPipeline2_2
21-
from functools import partial
2221
from maxdiffusion import pyconfig, max_logging, max_utils
2322
from absl import app
2423
from absl import flags
@@ -143,7 +142,7 @@ def run(config, pipeline=None, filename_prefix=""):
143142
# Using global_batch_size_to_train_on so not to create more config variables
144143
prompt = [config.prompt] * config.global_batch_size_to_train_on
145144
negative_prompt = [config.negative_prompt] * config.global_batch_size_to_train_on
146-
145+
147146
max_logging.log(
148147
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
149148
)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_
387387
transformer = cls.load_transformer(
388388
devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer"
389389
)
390-
390+
391391
text_encoder = cls.load_text_encoder(config=config)
392392
tokenizer = cls.load_tokenizer(config=config)
393393

src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,4 +722,4 @@ def high_noise_branch(operands):
722722
)
723723

724724
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
725-
return latents
725+
return latents

0 commit comments

Comments (0)