diff --git a/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py b/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
index a728024..b3baf53 100644
--- a/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
+++ b/hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
@@ -55,7 +55,7 @@ class HunyuanImagePipelineConfig:
     version: str = ""
 
     @classmethod
-    def create_default(cls, version: str = "v2.1", use_distilled: bool = False, **kwargs):
+    def create_default(cls, version: str = "v2.1", use_distilled: bool = False, use_compile: bool = True, **kwargs):
         """
         Create a default configuration for specified HunyuanImage version.
 
@@ -71,7 +71,11 @@ def create_default(cls, version: str = "v2.1", use_distilled: bool = False, **kw
             HUNYUANIMAGE_V2_1_VAE_32x,
             HUNYUANIMAGE_V2_1_TEXT_ENCODER,
         )
-        dit_config = HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL() if use_distilled else HUNYUANIMAGE_V2_1_DIT()
+        dit_config = (
+            HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL(use_compile=use_compile)
+            if use_distilled else
+            HUNYUANIMAGE_V2_1_DIT(use_compile=use_compile)
+        )
         return cls(
             dit_config=dit_config,
             vae_config=HUNYUANIMAGE_V2_1_VAE_32x(),
@@ -820,6 +824,10 @@ def to(self, device: str | torch.device):
             self.text_encoder = self.text_encoder.to(device, non_blocking=True)
         if self.vae is not None:
             self.vae = self.vae.to(device, non_blocking=True)
+        if self.use_byt5 and self.byt5_kwargs is not None:
+            self.byt5_kwargs['byt5_model'] = self.byt5_kwargs['byt5_model'].to(
+                device, non_blocking=True
+            )
         return self
 
     def update_config(self, **kwargs):
diff --git a/hyimage/models/model_zoo.py b/hyimage/models/model_zoo.py
index 3a5eae2..0c88e3a 100644
--- a/hyimage/models/model_zoo.py
+++ b/hyimage/models/model_zoo.py
@@ -66,7 +66,7 @@ def HUNYUANIMAGE_V2_1_DIT(**kwargs):
         use_cpu_offload=False,
         gradient_checkpointing=True,
         load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1.safetensors",
-        use_compile=True,
+        use_compile=kwargs.get("use_compile", True),
     )
 
 
@@ -77,7 +77,7 @@ def HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL(**kwargs):
         use_cpu_offload=False,
         gradient_checkpointing=True,
         load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1-distilled.safetensors",
-        use_compile=True,
+        use_compile=kwargs.get("use_compile", True),
     )
 
 # =============================================================================
@@ -91,7 +91,7 @@ def HUNYUANIMAGE_REFINER_DIT(**kwargs):
         use_cpu_offload=False,
         gradient_checkpointing=True,
         load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage-refiner.safetensors",
-        use_compile=True,
+        use_compile=kwargs.get("use_compile", True),
     )
 
 def HUNYUANIMAGE_REFINER_VAE_16x(**kwargs):
diff --git a/hyimage/models/text_encoder/byT5/__init__.py b/hyimage/models/text_encoder/byT5/__init__.py
index 2372bd0..7a8385b 100644
--- a/hyimage/models/text_encoder/byT5/__init__.py
+++ b/hyimage/models/text_encoder/byT5/__init__.py
@@ -54,7 +54,7 @@ def create_byt5(args, device):
 
     # Load custom checkpoint if provided
    if args['byT5_ckpt_path'] is not None:
-        if "cuda" not in str(device):
+        if "cuda" not in str(device) and "cpu" not in str(device):
             byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=f"cuda:{device}")
         else:
             byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=device)
@@ -149,7 +149,7 @@ def load_byt5_and_byt5_tokenizer(
         cache_dir=huggingface_cache_dir,
     ).get_encoder()
 
-    if "cuda" not in str(device):
+    if "cuda" not in str(device) and "cpu" not in str(device):
         device = torch.device(f"cuda:{device}")
     else:
         device = torch.device(device)
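
Usage sketch (not part of the patch, assuming the module paths shown in the diff above): with this change a caller can disable torch.compile for the DiT at config-creation time instead of editing model_zoo.py, because create_default forwards use_compile down to HUNYUANIMAGE_V2_1_DIT / HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL.

    # Minimal sketch; paths and names taken from the patched files above.
    from hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipelineConfig

    # Build the default v2.1 config with compilation turned off, e.g. for debugging
    # or on hardware where torch.compile warm-up is too slow.
    config = HunyuanImagePipelineConfig.create_default(
        version="v2.1",
        use_distilled=False,
        use_compile=False,  # forwarded as HUNYUANIMAGE_V2_1_DIT(use_compile=False)
    )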