From 5a3e13c104e7511a2e0a97ba71a011f33dd02e72 Mon Sep 17 00:00:00 2001
From: Plat
Date: Fri, 21 Jul 2023 21:49:26 +0900
Subject: [PATCH 1/4] fix: #15

---
 prompt_util.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/prompt_util.py b/prompt_util.py
index 0da68e6..afe545d 100644
--- a/prompt_util.py
+++ b/prompt_util.py
@@ -18,9 +18,10 @@ class PromptEmbedsXL:
     text_embeds: torch.FloatTensor
     pooled_embeds: torch.FloatTensor
 
-    def __init__(self, *args) -> None:
-        self.text_embeds = args[0]
-        self.pooled_embeds = args[1]
+    def __init__(
+        self, text_embeds_pair: tuple[torch.FloatTensor, torch.FloatTensor]
+    ) -> None:
+        self.text_embeds, self.pooled_embeds = text_embeds_pair
 
 
 # SDv1.x and SDv2.x use FloatTensor; XL uses PromptEmbedsXL

From 4b19877cadf058825163576c7d6c0cf9bd8a8ffd Mon Sep 17 00:00:00 2001
From: Plat
Date: Sat, 22 Jul 2023 21:24:38 +0900
Subject: [PATCH 2/4] fix: rescale_noise_cfg takes wrong arguments

---
 train_util.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/train_util.py b/train_util.py
index 9db7b84..58bb53a 100644
--- a/train_util.py
+++ b/train_util.py
@@ -222,8 +222,8 @@ def predict_noise_xl(
     text_embeddings: torch.FloatTensor,  # the uncond and cond text embeds, concatenated
     add_text_embeddings: torch.FloatTensor,  # the pooled text embeds
     add_time_ids: torch.FloatTensor,
-    guidance_scale=7.5,
-    guidance_rescale=0.7,
+    guidance_scale: float = 7.5,
+    guidance_rescale: float = 0.0,
 ) -> torch.FloatTensor:
     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
     latent_model_input = torch.cat([latents] * 2)
@@ -250,9 +250,10 @@ def predict_noise_xl(
     )
 
     # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
-    noise_pred = rescale_noise_cfg(
-        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
-    )
+    if guidance_rescale > 0.0:
+        guided_target = rescale_noise_cfg(
+            guided_target, noise_pred_text, guidance_rescale=guidance_rescale
+        )
 
     return guided_target
 
@@ -281,7 +282,6 @@ def diffusion_xl(
             add_text_embeddings,
             add_time_ids,
             guidance_scale=guidance_scale,
-            guidance_rescale=0.7,
         )
 
         # compute the previous noisy sample x_t -> x_t-1
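Note on PATCH 2/4: guarding the call with `if guidance_rescale > 0.0` is behavior-preserving, because `rescale_noise_cfg` reduces to the identity at `guidance_rescale == 0.0`. A minimal sketch of what the helper computes, paraphrased from the diffusers file linked in the hunk above (see "Common Diffusion Noise Schedules and Sample Steps are Flawed", Section 3.4), not a verbatim copy:

    import torch

    def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
        # Rescale the guided prediction so its per-sample std matches the
        # text-conditioned prediction's std (counters CFG overexposure).
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
        rescaled = noise_cfg * (std_text / std_cfg)
        # Blend back toward the unrescaled prediction; at guidance_rescale=0.0
        # this returns noise_cfg unchanged, which is why the guard can skip it.
        return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg

The pre-patch code also rescaled `noise_pred` (and then discarded the result) instead of `guided_target`, the value actually returned; the patch fixes both the argument and the assignment.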
From 42c9e5342fd61b6939f8268907445013cc4a6a18 Mon Sep 17 00:00:00 2001
From: plat
Date: Sat, 20 Jan 2024 10:50:09 +0000
Subject: [PATCH 3/4] chore: replace deprecated functions

---
 train_lora_xl.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/train_lora_xl.py b/train_lora_xl.py
index b729084..0848b28 100644
--- a/train_lora_xl.py
+++ b/train_lora_xl.py
@@ -42,8 +42,8 @@ def train(
     prompts: list[PromptSettings],
 ):
     metadata = {
-        "prompts": ",".join([prompt.json() for prompt in prompts]),
-        "config": config.json(),
+        "prompts": ",".join([prompt.model_dump_json() for prompt in prompts]),
+        "config": config.model_dump_json(),
     }
 
     save_path = Path(config.save.path)
@@ -76,8 +76,10 @@ def train(
     text_encoder.eval()
 
     unet.to(DEVICE_CUDA, dtype=weight_dtype)
+
     if config.other.use_xformers:
         unet.enable_xformers_memory_efficient_attention()
+
     unet.requires_grad_(False)
     unet.eval()
@@ -90,15 +92,17 @@ def train(
     ).to(DEVICE_CUDA, dtype=weight_dtype)
 
     optimizer_module = train_util.get_optimizer(config.train.optimizer)
-    #optimizer_args
+    # optimizer_args
    optimizer_kwargs = {}
     if config.train.optimizer_args is not None and len(config.train.optimizer_args) > 0:
         for arg in config.train.optimizer_args.split(" "):
             key, value = arg.split("=")
             value = ast.literal_eval(value)
             optimizer_kwargs[key] = value
-
-    optimizer = optimizer_module(network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs)
+
+    optimizer = optimizer_module(
+        network.prepare_optimizer_params(), lr=config.train.lr, **optimizer_kwargs
+    )
     lr_scheduler = train_util.get_lr_scheduler(
         config.train.lr_scheduler,
         optimizer,

From 75909762d03480f64c86b2ea2a0750b080b196f7 Mon Sep 17 00:00:00 2001
From: plat
Date: Sat, 20 Jan 2024 10:50:33 +0000
Subject: [PATCH 4/4] fix: device mismatch

---
 train_util.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/train_util.py b/train_util.py
index da596a1..ec2c738 100644
--- a/train_util.py
+++ b/train_util.py
@@ -18,7 +18,7 @@
 
 
 def get_random_noise(
-    batch_size: int, height: int, width: int, generator: torch.Generator = None
+    batch_size: int, height: int, width: int, generator: torch.Generator | None = None
 ) -> torch.Tensor:
     return torch.randn(
         (
@@ -46,13 +46,13 @@ def get_initial_latents(
     height: int,
     width: int,
     n_prompts: int,
-    generator=None,
+    generator: torch.Generator | None = None,
 ) -> torch.Tensor:
     noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
         n_prompts, 1, 1, 1
     )
 
-    latents = noise * scheduler.init_noise_sigma
+    latents = noise * scheduler.init_noise_sigma.to(noise.device)
 
     return latents
 
@@ -364,7 +364,7 @@ def get_optimizer(name: str):
         return Lion
     elif name == "prodigy":
         import prodigyopt
-        
+
         return prodigyopt.Prodigy
     else:
         raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
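Note on PATCH 3/4: `.json()` and `.dict()` are the pydantic v1 names; pydantic v2 deprecates them in favor of `.model_dump_json()` and `.model_dump()`, which are drop-in replacements for this usage.

Note on PATCH 4/4: on sigma-based schedulers (Euler, LMS, and similar), `init_noise_sigma` is a tensor that lives on whatever device the scheduler's sigmas were moved to, while the fresh noise here is created on the CPU, so the multiply can cross devices and raise a RuntimeError. A minimal sketch of the failure mode and the patched form; the sigma value and devices are illustrative assumptions, not taken from the repo:

    import torch

    noise = torch.randn(1, 4, 64, 64)  # fresh noise, created on the CPU
    # e.g. a scheduler whose sigmas were moved to CUDA:
    init_noise_sigma = torch.tensor(14.6146, device="cuda")

    # noise * init_noise_sigma would raise a device-mismatch RuntimeError here
    latents = noise * init_noise_sigma.to(noise.device)  # patched form: move first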