Skip to content

Commit

Permalink
Make InstructPix2Pix Training Script torch.compile compatible (huggin…
Browse files Browse the repository at this point in the history
…gface#6558)

* added torch.compile for pix2pix

* required changes
  • Loading branch information
charchit7 authored Jan 15, 2024
1 parent 08702fc commit b053053
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions examples/instruct_pix2pix/train_instruct_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
Expand Down Expand Up @@ -489,6 +490,11 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
Expand Down Expand Up @@ -845,7 +851,7 @@ def collate_fn(examples):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

# Predict the noise residual and compute loss
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

# Gather the losses across all processes for logging (if we use distributed training).
Expand Down Expand Up @@ -919,9 +925,9 @@ def collate_fn(examples):
# The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unwrap_model(unet),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
Expand Down Expand Up @@ -965,14 +971,14 @@ def collate_fn(examples):
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())

pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
unet=unet,
revision=args.revision,
variant=args.variant,
Expand Down

0 comments on commit b053053

Please sign in to comment.