
Commit 9621edb

feat(diffusers): add support for wan2.2 (#6153)
* feat(diffusers): add support for wan2.2
  Signed-off-by: Ettore Di Giacinto <[email protected]>
* chore(ci): use ttl.sh for PRs
  Signed-off-by: Ettore Di Giacinto <[email protected]>
* Add ftfy deps
  Signed-off-by: Ettore Di Giacinto <[email protected]>
* Revert "chore(ci): use ttl.sh for PRs"
  This reverts commit c9fc3ec.
* Simplify
  Signed-off-by: Ettore Di Giacinto <[email protected]>
* chore: do not pin torch/torchvision on cuda12
  Signed-off-by: Ettore Di Giacinto <[email protected]>

---------

Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 7ce92f0 commit 9621edb

15 files changed: +195, -73 lines


backend/backend.proto

Lines changed: 11 additions & 9 deletions
@@ -312,15 +312,17 @@ message GenerateImageRequest {
 
 message GenerateVideoRequest {
   string prompt = 1;
-  string start_image = 2; // Path or base64 encoded image for the start frame
-  string end_image = 3; // Path or base64 encoded image for the end frame
-  int32 width = 4;
-  int32 height = 5;
-  int32 num_frames = 6; // Number of frames to generate
-  int32 fps = 7; // Frames per second
-  int32 seed = 8;
-  float cfg_scale = 9; // Classifier-free guidance scale
-  string dst = 10; // Output path for the generated video
+  string negative_prompt = 2; // Negative prompt for video generation
+  string start_image = 3; // Path or base64 encoded image for the start frame
+  string end_image = 4; // Path or base64 encoded image for the end frame
+  int32 width = 5;
+  int32 height = 6;
+  int32 num_frames = 7; // Number of frames to generate
+  int32 fps = 8; // Frames per second
+  int32 seed = 9;
+  float cfg_scale = 10; // Classifier-free guidance scale
+  int32 step = 11; // Number of inference steps
+  string dst = 12; // Output path for the generated video
 }
 
 message TTSRequest {
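
For orientation, here is a hedged sketch of how a Python client could populate the renumbered GenerateVideoRequest above. The backend_pb2 module matches the generated stubs used in backend.py below; the backend_pb2_grpc stub module, the Backend service name, the GenerateVideo rpc, the channel address, and all field values are illustrative assumptions, not something this commit defines.

import grpc

import backend_pb2
import backend_pb2_grpc

# Hypothetical request using the new field numbering (negative_prompt = 2, step = 11, dst = 12)
req = backend_pb2.GenerateVideoRequest(
    prompt="a red fox running through fresh snow",
    negative_prompt="blurry, low quality",
    width=1280,
    height=720,
    num_frames=81,
    fps=16,
    seed=42,
    cfg_scale=4.0,
    step=40,
    dst="/tmp/out.mp4",
)

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)  # assumed service/stub name
    res = stub.GenerateVideo(req)
    print(res.success, res.message)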

backend/python/diffusers/backend.py

Lines changed: 121 additions & 11 deletions
@@ -18,7 +18,7 @@
 import grpc
 
 from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video
@@ -72,13 +72,6 @@ def is_float(s):
     except ValueError:
         return False
 
-def is_int(s):
-    try:
-        int(s)
-        return True
-    except ValueError:
-        return False
-
 # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
 # Credits to https://github.com/neggles
 # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
@@ -184,9 +177,10 @@ def LoadModel(self, request, context):
                 key, value = opt.split(":")
                 # if value is a number, convert it to the appropriate type
                 if is_float(value):
-                    value = float(value)
-                elif is_int(value):
-                    value = int(value)
+                    if value.is_integer():
+                        value = int(value)
+                    else:
+                        value = float(value)
                 self.options[key] = value
 
         # From options, extract if present "torch_dtype" and set it to the appropriate type
@@ -334,6 +328,32 @@ def LoadModel(self, request, context):
                 torch_dtype=torch.bfloat16)
             self.pipe.vae.to(torch.bfloat16)
             self.pipe.text_encoder.to(torch.bfloat16)
+        elif request.PipelineType == "WanPipeline":
+            # WAN2.2 pipeline requires special VAE handling
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.txt2vid = True # WAN2.2 is a text-to-video pipeline
+        elif request.PipelineType == "WanImageToVideoPipeline":
+            # WAN2.2 image-to-video pipeline
+            vae = AutoencoderKLWan.from_pretrained(
+                request.Model,
+                subfolder="vae",
+                torch_dtype=torch.float32
+            )
+            self.pipe = WanImageToVideoPipeline.from_pretrained(
+                request.Model,
+                vae=vae,
+                torch_dtype=torchType
+            )
+            self.img2vid = True # WAN2.2 image-to-video pipeline
 
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip
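
The same loading pattern can be sketched directly against diffusers, outside of the LocalAI backend. This is a minimal, hedged example: the model id, device, and bfloat16 dtype are assumptions, while the float32 VAE and the 81-frame / guidance 4.0 / 40-step / 16-fps values mirror what this commit wires into the backend.

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"  # illustrative model id

# Keep the WAN VAE in float32, load the rest of the pipeline in a lower precision
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

frames = pipe(
    prompt="a paper boat drifting down a rain-soaked street",
    num_frames=81,
    guidance_scale=4.0,
    num_inference_steps=40,
).frames[0]
export_to_video(frames, "wan22.mp4", fps=16)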
@@ -575,6 +595,96 @@ def GenerateImage(self, request, context):
 
         return backend_pb2.Result(message="Media generated", success=True)
 
+    def GenerateVideo(self, request, context):
+        try:
+            prompt = request.prompt
+            if not prompt:
+                return backend_pb2.Result(success=False, message="No prompt provided for video generation")
+
+            # Set default values from request or use defaults
+            num_frames = request.num_frames if request.num_frames > 0 else 81
+            fps = request.fps if request.fps > 0 else 16
+            cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0
+            num_inference_steps = request.step if request.step > 0 else 40
+
+            # Prepare generation parameters
+            kwargs = {
+                "prompt": prompt,
+                "negative_prompt": request.negative_prompt if request.negative_prompt else "",
+                "height": request.height if request.height > 0 else 720,
+                "width": request.width if request.width > 0 else 1280,
+                "num_frames": num_frames,
+                "guidance_scale": cfg_scale,
+                "num_inference_steps": num_inference_steps,
+            }
+
+            # Add custom options from self.options (including guidance_scale_2 if specified)
+            kwargs.update(self.options)
+
+            # Set seed if provided
+            if request.seed > 0:
+                kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed)
+
+            # Handle start and end images for video generation
+            if request.start_image:
+                kwargs["start_image"] = load_image(request.start_image)
+            if request.end_image:
+                kwargs["end_image"] = load_image(request.end_image)
+
+            print(f"Generating video with {kwargs=}", file=sys.stderr)
+
+            # Generate video frames based on pipeline type
+            if self.PipelineType == "WanPipeline":
+                # WAN2.2 text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0] # WAN2.2 returns frames in this format
+            elif self.PipelineType == "WanImageToVideoPipeline":
+                # WAN2.2 image-to-video generation
+                if request.start_image:
+                    # Load and resize the input image according to WAN2.2 requirements
+                    image = load_image(request.start_image)
+                    # Use request dimensions or defaults, but respect WAN2.2 constraints
+                    request_height = request.height if request.height > 0 else 480
+                    request_width = request.width if request.width > 0 else 832
+                    max_area = request_height * request_width
+                    aspect_ratio = image.height / image.width
+                    mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
+                    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
+                    image = image.resize((width, height))
+                    kwargs["image"] = image
+                    kwargs["height"] = height
+                    kwargs["width"] = width
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.img2vid:
+                # Generic image-to-video generation
+                if request.start_image:
+                    image = load_image(request.start_image)
+                    image = image.resize((request.width if request.width > 0 else 1024,
+                                          request.height if request.height > 0 else 576))
+                    kwargs["image"] = image
+
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            elif self.txt2vid:
+                # Generic text-to-video generation
+                output = self.pipe(**kwargs)
+                frames = output.frames[0]
+            else:
+                return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation")
+
+            # Export video
+            export_to_video(frames, request.dst, fps=fps)
+
+            return backend_pb2.Result(message="Video generated successfully", success=True)
+
+        except Exception as err:
+            print(f"Error generating video: {err}", file=sys.stderr)
+            traceback.print_exc()
+            return backend_pb2.Result(success=False, message=f"Error generating video: {err}")
+
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
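
The WanImageToVideoPipeline branch above snaps the output resolution to a multiple of the model's spatial stride while roughly preserving the requested pixel budget. A worked example of that arithmetic, assuming mod_value = 16 (vae_scale_factor_spatial 8 times patch_size[1] 2, typical for WAN models) and the 480x832 defaults with a 1280x720 start image:

# Assumed inputs: request defaults 480x832, start image 1280x720, mod_value 16
max_area = 480 * 832               # 399360 target pixels
aspect_ratio = 720 / 1280          # 0.5625
mod_value = 16                     # assumed 8 * 2

height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value  # round(473.96 / 16) * 16 = 480
width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value   # round(842.60 / 16) * 16 = 848
print(height, width)  # 480 848 -> both divisible by 16, close to the requested area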

backend/python/diffusers/requirements-cpu.txt

Lines changed: 2 additions & 1 deletion
@@ -8,4 +8,5 @@ compel
 peft
 sentencepiece
 torch==2.7.1
-optimum-quanto
+optimum-quanto
+ftfy
Lines changed: 4 additions & 3 deletions
@@ -1,11 +1,12 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.7.1+cu118
-torchvision==0.22.1+cu118
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
+torchvision==0.22.1
 accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+torch==2.7.1
+optimum-quanto
+ftfy
Lines changed: 4 additions & 3 deletions
@@ -1,10 +1,11 @@
-torch==2.7.1
-torchvision==0.22.1
+--extra-index-url https://download.pytorch.org/whl/cu121
 git+https://github.com/huggingface/diffusers
 opencv-python
 transformers
+torchvision
 accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+torch
+ftfy

backend/python/diffusers/requirements-hipblas.txt

Lines changed: 2 additions & 1 deletion
@@ -8,4 +8,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+optimum-quanto
+ftfy

backend/python/diffusers/requirements-intel.txt

Lines changed: 2 additions & 1 deletion
@@ -12,4 +12,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+optimum-quanto
+ftfy

backend/python/diffusers/requirements-l4t.txt

Lines changed: 2 additions & 1 deletion
@@ -8,4 +8,5 @@ peft
 optimum-quanto
 numpy<2
 sentencepiece
-torchvision
+torchvision
+ftfy

backend/python/diffusers/requirements-mps.txt

Lines changed: 2 additions & 1 deletion
@@ -7,4 +7,5 @@ accelerate
 compel
 peft
 sentencepiece
-optimum-quanto
+optimum-quanto
+ftfy

backend/python/mlx-audio/backend.py

Lines changed: 4 additions & 11 deletions
@@ -40,14 +40,6 @@ def _is_float(self, s):
         except ValueError:
             return False
 
-    def _is_int(self, s):
-        """Check if a string can be converted to int."""
-        try:
-            int(s)
-            return True
-        except ValueError:
-            return False
-
     def Health(self, request, context):
         """
         Returns a health check message.
@@ -89,9 +81,10 @@ async def LoadModel(self, request, context):
 
             # Convert numeric values to appropriate types
             if self._is_float(value):
-                value = float(value)
-            elif self._is_int(value):
-                value = int(value)
+                if float(value).is_integer():
+                    value = int(value)
+                else:
+                    value = float(value)
             elif value.lower() in ["true", "false"]:
                 value = value.lower() == "true"
 
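
A small self-contained sketch of what the simplified conversion above does with typical option values (the helper function and sample inputs are assumptions for illustration):

def convert(value: str):
    # Mirrors the branch above: numeric strings become int or float, booleans become bool
    try:
        num = float(value)
        return int(num) if num.is_integer() else num
    except ValueError:
        pass
    if value.lower() in ["true", "false"]:
        return value.lower() == "true"
    return value

print(convert("4"), convert("0.85"), convert("true"))  # 4 0.85 True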
