18 | 18 | import grpc |
19 | 19 |
20 | 20 | from diffusers import SanaPipeline, StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \ |
21 | | - EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline |
| 21 | + EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, QwenImageEditPipeline, AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline |
22 | 22 | from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline, Lumina2Text2ImgPipeline |
23 | 23 | from diffusers.pipelines.stable_diffusion import safety_checker |
24 | 24 | from diffusers.utils import load_image, export_to_video |
@@ -72,13 +72,6 @@ def is_float(s): |
72 | 72 | except ValueError: |
73 | 73 | return False |
74 | 74 |
75 | | -def is_int(s): |
76 | | - try: |
77 | | - int(s) |
78 | | - return True |
79 | | - except ValueError: |
80 | | - return False |
81 | | - |
82 | 75 | # The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39 |
83 | 76 | # Credits to https://github.com/neggles |
84 | 77 | # See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111 |
@@ -184,9 +177,10 @@ def LoadModel(self, request, context): |
184 | 177 | key, value = opt.split(":") |
185 | 178 | # if value is a number, convert it to the appropriate type |
186 | 179 | if is_float(value): |
187 | | - value = float(value) |
188 | | - elif is_int(value): |
189 | | - value = int(value) |
| 180 | + if float(value).is_integer(): |
| 181 | + value = int(float(value)) |
| 182 | + else: |
| 183 | + value = float(value) |
190 | 184 | self.options[key] = value |
191 | 185 |
192 | 186 | # From options, extract if present "torch_dtype" and set it to the appropriate type |
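For reference, the coercion that the new lines above apply to option values (parse the number first, then decide between int and float) can be sketched as a small standalone helper. This is only an illustration; "coerce" is a hypothetical name and not part of the backend:

    def coerce(value: str):
        # Numeric strings become int or float; anything else stays a string.
        try:
            number = float(value)
        except ValueError:
            return value
        return int(number) if number.is_integer() else number

    assert coerce("81") == 81        # whole-number string -> int
    assert coerce("4.0") == 4        # float string with an integral value -> int
    assert coerce("3.5") == 3.5      # -> float
    assert coerce("bf16") == "bf16"  # non-numeric -> left unchanged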
@@ -334,6 +328,32 @@ def LoadModel(self, request, context): |
334 | 328 | torch_dtype=torch.bfloat16) |
335 | 329 | self.pipe.vae.to(torch.bfloat16) |
336 | 330 | self.pipe.text_encoder.to(torch.bfloat16) |
| 331 | + elif request.PipelineType == "WanPipeline": |
| 332 | + # WAN2.2 pipeline requires special VAE handling |
| 333 | + vae = AutoencoderKLWan.from_pretrained( |
| 334 | + request.Model, |
| 335 | + subfolder="vae", |
| 336 | + torch_dtype=torch.float32 |
| 337 | + ) |
| 338 | + self.pipe = WanPipeline.from_pretrained( |
| 339 | + request.Model, |
| 340 | + vae=vae, |
| 341 | + torch_dtype=torchType |
| 342 | + ) |
| 343 | + self.txt2vid = True # WAN2.2 is a text-to-video pipeline |
| 344 | + elif request.PipelineType == "WanImageToVideoPipeline": |
| 345 | + # WAN2.2 image-to-video pipeline |
| 346 | + vae = AutoencoderKLWan.from_pretrained( |
| 347 | + request.Model, |
| 348 | + subfolder="vae", |
| 349 | + torch_dtype=torch.float32 |
| 350 | + ) |
| 351 | + self.pipe = WanImageToVideoPipeline.from_pretrained( |
| 352 | + request.Model, |
| 353 | + vae=vae, |
| 354 | + torch_dtype=torchType |
| 355 | + ) |
| 356 | + self.img2vid = True # WAN2.2 image-to-video pipeline |
337 | 357 |
338 | 358 | if CLIPSKIP and request.CLIPSkip != 0: |
339 | 359 | self.clip_skip = request.CLIPSkip |
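Outside the gRPC wrapper, the new WanPipeline branch corresponds roughly to the standard diffusers text-to-video flow for Wan 2.2: load the VAE separately in float32, build the pipeline around it, then export the frames. A minimal sketch under those assumptions; the model id, prompt, and output path are placeholders, not values used by the backend:

    import torch
    from diffusers import AutoencoderKLWan, WanPipeline
    from diffusers.utils import export_to_video

    model_id = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"  # example Diffusers-format checkpoint

    # As in the LoadModel branch above: VAE in float32, the rest of the pipeline in bfloat16.
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    frames = pipe(
        prompt="A corgi running along a beach at sunset",
        height=720,
        width=1280,
        num_frames=81,
        guidance_scale=4.0,
        num_inference_steps=40,
    ).frames[0]
    export_to_video(frames, "output.mp4", fps=16)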
@@ -575,6 +595,96 @@ def GenerateImage(self, request, context): |
575 | 595 |
576 | 596 | return backend_pb2.Result(message="Media generated", success=True) |
577 | 597 |
| 598 | + def GenerateVideo(self, request, context): |
| 599 | + try: |
| 600 | + prompt = request.prompt |
| 601 | + if not prompt: |
| 602 | + return backend_pb2.Result(success=False, message="No prompt provided for video generation") |
| 603 | + |
| 604 | + # Take generation parameters from the request, falling back to defaults |
| 605 | + num_frames = request.num_frames if request.num_frames > 0 else 81 |
| 606 | + fps = request.fps if request.fps > 0 else 16 |
| 607 | + cfg_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0 |
| 608 | + num_inference_steps = request.step if request.step > 0 else 40 |
| 609 | + |
| 610 | + # Prepare generation parameters |
| 611 | + kwargs = { |
| 612 | + "prompt": prompt, |
| 613 | + "negative_prompt": request.negative_prompt if request.negative_prompt else "", |
| 614 | + "height": request.height if request.height > 0 else 720, |
| 615 | + "width": request.width if request.width > 0 else 1280, |
| 616 | + "num_frames": num_frames, |
| 617 | + "guidance_scale": cfg_scale, |
| 618 | + "num_inference_steps": num_inference_steps, |
| 619 | + } |
| 620 | + |
| 621 | + # Add custom options from self.options (including guidance_scale_2 if specified) |
| 622 | + kwargs.update(self.options) |
| 623 | + |
| 624 | + # Set seed if provided |
| 625 | + if request.seed > 0: |
| 626 | + kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed) |
| 627 | + |
| 628 | + # Start and end images are handled inside the pipeline-specific branches |
| 629 | + # below, which load request.start_image themselves. They are deliberately |
| 630 | + # not added to kwargs: the underlying pipelines expect their own parameter |
| 631 | + # names (e.g. "image") and reject unknown keyword arguments such as |
| 632 | + # "start_image" or "end_image". |
| 633 | + |
| 634 | + print(f"Generating video with {kwargs=}", file=sys.stderr) |
| 635 | + |
| 636 | + # Generate video frames based on pipeline type |
| 637 | + if self.PipelineType == "WanPipeline": |
| 638 | + # WAN2.2 text-to-video generation |
| 639 | + output = self.pipe(**kwargs) |
| 640 | + frames = output.frames[0] # WAN2.2 returns frames in this format |
| 641 | + elif self.PipelineType == "WanImageToVideoPipeline": |
| 642 | + # WAN2.2 image-to-video generation |
| 643 | + if request.start_image: |
| 644 | + # Load and resize the input image according to WAN2.2 requirements |
| 645 | + image = load_image(request.start_image) |
| 646 | + # Use request dimensions or defaults, but respect WAN2.2 constraints |
| 647 | + request_height = request.height if request.height > 0 else 480 |
| 648 | + request_width = request.width if request.width > 0 else 832 |
| 649 | + max_area = request_height * request_width |
| 650 | + aspect_ratio = image.height / image.width |
| 651 | + mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1] |
| 652 | + height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value |
| 653 | + width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value |
| 654 | + image = image.resize((width, height)) |
| 655 | + kwargs["image"] = image |
| 656 | + kwargs["height"] = height |
| 657 | + kwargs["width"] = width |
| 658 | + |
| 659 | + output = self.pipe(**kwargs) |
| 660 | + frames = output.frames[0] |
| 661 | + elif self.img2vid: |
| 662 | + # Generic image-to-video generation |
| 663 | + if request.start_image: |
| 664 | + image = load_image(request.start_image) |
| 665 | + image = image.resize((request.width if request.width > 0 else 1024, |
| 666 | + request.height if request.height > 0 else 576)) |
| 667 | + kwargs["image"] = image |
| 668 | + |
| 669 | + output = self.pipe(**kwargs) |
| 670 | + frames = output.frames[0] |
| 671 | + elif self.txt2vid: |
| 672 | + # Generic text-to-video generation |
| 673 | + output = self.pipe(**kwargs) |
| 674 | + frames = output.frames[0] |
| 675 | + else: |
| 676 | + return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation") |
| 677 | + |
| 678 | + # Export video |
| 679 | + export_to_video(frames, request.dst, fps=fps) |
| 680 | + |
| 681 | + return backend_pb2.Result(message="Video generated successfully", success=True) |
| 682 | + |
| 683 | + except Exception as err: |
| 684 | + print(f"Error generating video: {err}", file=sys.stderr) |
| 685 | + traceback.print_exc() |
| 686 | + return backend_pb2.Result(success=False, message=f"Error generating video: {err}") |
| 687 | + |
578 | 688 |
579 | 689 | def serve(address): |
580 | 690 | server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), |
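A note on the resize logic in the WanImageToVideoPipeline branch of GenerateVideo: it keeps roughly the requested pixel area and the input image's aspect ratio while snapping both dimensions to the model's spatial granularity. A worked example, assuming the typical Wan values vae_scale_factor_spatial = 8 and patch_size[1] = 2 (so mod_value = 16):

    # Worked example of the dimension rounding in the image-to-video branch.
    mod_value = 8 * 2                          # vae_scale_factor_spatial * patch_size[1] (assumed values)
    request_height, request_width = 480, 832   # backend defaults
    max_area = request_height * request_width  # 399360 target pixels

    image_height, image_width = 720, 1280      # a 16:9 input image
    aspect_ratio = image_height / image_width  # 0.5625

    height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value
    width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value
    print(height, width)  # 480 848: multiples of 16, same aspect ratio, close to max_area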