Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class HunyuanImagePipelineConfig:
version: str = ""

@classmethod
def create_default(cls, version: str = "v2.1", use_distilled: bool = False, **kwargs):
def create_default(cls, version: str = "v2.1", use_distilled: bool = False, use_compile: bool = True, **kwargs):
"""
Create a default configuration for specified HunyuanImage version.

Expand All @@ -71,7 +71,11 @@ def create_default(cls, version: str = "v2.1", use_distilled: bool = False, **kw
HUNYUANIMAGE_V2_1_VAE_32x,
HUNYUANIMAGE_V2_1_TEXT_ENCODER,
)
dit_config = HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL() if use_distilled else HUNYUANIMAGE_V2_1_DIT()
dit_config = (
HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL(use_compile=use_compile)
if use_distilled else
HUNYUANIMAGE_V2_1_DIT(use_compile=use_compile)
)
return cls(
dit_config=dit_config,
vae_config=HUNYUANIMAGE_V2_1_VAE_32x(),
Expand Down Expand Up @@ -820,6 +824,10 @@ def to(self, device: str | torch.device):
self.text_encoder = self.text_encoder.to(device, non_blocking=True)
if self.vae is not None:
self.vae = self.vae.to(device, non_blocking=True)
if self.use_byt5 and self.byt5_kwargs is not None:
self.byt5_kwargs['byt5_model'] = self.byt5_kwargs['byt5_model'].to(
device, non_blocking=True
)
return self

def update_config(self, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions hyimage/models/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def HUNYUANIMAGE_V2_1_DIT(**kwargs):
use_cpu_offload=False,
gradient_checkpointing=True,
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1.safetensors",
use_compile=True,
use_compile=kwargs.get("use_compile", True),
)


Expand All @@ -77,7 +77,7 @@ def HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL(**kwargs):
use_cpu_offload=False,
gradient_checkpointing=True,
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1-distilled.safetensors",
use_compile=True,
use_compile=kwargs.get("use_compile", True),
)

# =============================================================================
Expand All @@ -91,7 +91,7 @@ def HUNYUANIMAGE_REFINER_DIT(**kwargs):
use_cpu_offload=False,
gradient_checkpointing=True,
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage-refiner.safetensors",
use_compile=True,
use_compile=kwargs.get("use_compile", True),
)

def HUNYUANIMAGE_REFINER_VAE_16x(**kwargs):
Expand Down
4 changes: 2 additions & 2 deletions hyimage/models/text_encoder/byT5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def create_byt5(args, device):

# Load custom checkpoint if provided
if args['byT5_ckpt_path'] is not None:
if "cuda" not in str(device):
if "cuda" not in str(device) and "cpu" not in str(device):
byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=f"cuda:{device}")
else:
byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=device)
Expand Down Expand Up @@ -149,7 +149,7 @@ def load_byt5_and_byt5_tokenizer(
cache_dir=huggingface_cache_dir,
).get_encoder()

if "cuda" not in str(device):
if "cuda" not in str(device) and "cpu" not in str(device):
device = torch.device(f"cuda:{device}")
else:
device = torch.device(device)
Expand Down