From b053053ac9047b77a4f7d9ab96c47543e3c017ea Mon Sep 17 00:00:00 2001
From: Charchit Sharma
Date: Mon, 15 Jan 2024 17:03:22 +0530
Subject: [PATCH] Make InstructPix2Pix Training Script torch.compile
 compatible (#6558)

* added torch.compile for pix2pix

* required changes
---
 .../train_instruct_pix2pix.py                 | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 78cb7bc2f9d0..2af858cfd0ee 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -49,6 +49,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -489,6 +490,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -845,7 +851,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                 # Predict the noise residual and compute loss
-                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 # Gather the losses across all processes for logging (if we use distributed training).
@@ -919,9 +925,9 @@ def collate_fn(examples):
                         # The models need unwrapping because for compatibility in distributed training mode.
                         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                             args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet),
-                            text_encoder=accelerator.unwrap_model(text_encoder),
-                            vae=accelerator.unwrap_model(vae),
+                            unet=unwrap_model(unet),
+                            text_encoder=unwrap_model(text_encoder),
+                            vae=unwrap_model(vae),
                             revision=args.revision,
                             variant=args.variant,
                             torch_dtype=weight_dtype,
@@ -965,14 +971,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
 
         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            text_encoder=accelerator.unwrap_model(text_encoder),
-            vae=accelerator.unwrap_model(vae),
+            text_encoder=unwrap_model(text_encoder),
+            vae=unwrap_model(vae),
             unet=unet,
             revision=args.revision,
             variant=args.variant,
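
Note (illustrative, not part of the patch): the `unwrap_model` helper reaches through the `torch.compile` wrapper because compiling an `nn.Module` returns an `OptimizedModule` that keeps the original module under the `_orig_mod` attribute and registers it as a submodule; saving the wrapper's state dict therefore produces keys with an `_orig_mod.` prefix that plain checkpoints and `from_pretrained` do not expect. `is_compiled_module` simply detects that wrapper type. A minimal self-contained sketch of the behavior, assuming PyTorch >= 2.0 (the `Linear` model here is illustrative only):

    import torch

    model = torch.nn.Linear(4, 4)
    compiled = torch.compile(model)

    # torch.compile returns a wrapper module that keeps the original
    # module under the `_orig_mod` attribute.
    print(type(compiled).__name__)      # OptimizedModule
    print(compiled._orig_mod is model)  # True

    # The wrapper registers the original as a submodule, so its state
    # dict keys gain an `_orig_mod.` prefix that checkpoint formats
    # expecting the plain module's keys do not handle.
    print(next(iter(compiled.state_dict())))  # _orig_mod.weight
    print(next(iter(model.state_dict())))     # weight

The `return_dict=False` change serves the same goal: indexing the returned tuple with `[0]` avoids the attribute access on the output dataclass (`.sample`), which at the time of this patch could cause graph breaks under `torch.compile`.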