@@ -49,6 +49,7 @@
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
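The new import is the compiled-module check used by the `unwrap_model` helper added below. For reference, a minimal sketch of how such a check can be implemented (the real helper lives in `diffusers.utils.torch_utils`; the name `is_compiled_module_sketch` and the exact version guard here are assumptions, not the library's verbatim code):

```python
import torch

def is_compiled_module_sketch(module) -> bool:
    # torch.compile() wraps an nn.Module in an OptimizedModule from
    # torch._dynamo, keeping the original module reachable via `_orig_mod`.
    if not hasattr(torch, "_dynamo"):  # assumption: PyTorch < 2.0 exposes no _dynamo
        return False
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
```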
@@ -489,6 +490,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
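Why the extra `_orig_mod` hop matters: `accelerator.unwrap_model` only strips accelerate's own wrappers (e.g. DDP), while `torch.compile` adds a second wrapper of its own. A minimal standalone illustration, assuming PyTorch >= 2.0:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)

# torch.compile() returns an OptimizedModule wrapper; the original module is
# kept on `_orig_mod`, which is exactly what the new unwrap_model() recovers.
assert compiled._orig_mod is net
```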
@@ -845,7 +851,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                 # Predict the noise residual and compute loss
-                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 # Gather the losses across all processes for logging (if we use distributed training).
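`return_dict=False` makes the UNet forward return a plain tuple instead of a `UNet2DConditionOutput` dataclass; indexing `[0]` reads the same prediction, and skipping the dataclass construction is friendlier to `torch.compile` tracing. A toy demonstration with a randomly initialized tiny UNet (the config values are illustrative, not the training config):

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny, randomly initialized UNet purely for this demonstration.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
)
latents = torch.randn(1, 4, 32, 32)
timesteps = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

# `.sample` on the output dataclass and `[0]` on the tuple are the same tensor.
a = unet(latents, timesteps, encoder_hidden_states).sample
b = unet(latents, timesteps, encoder_hidden_states, return_dict=False)[0]
assert torch.allclose(a, b)
```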
@@ -919,9 +925,9 @@ def collate_fn(examples):
                 # The models need unwrapping because for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
-                    text_encoder=accelerator.unwrap_model(text_encoder),
-                    vae=accelerator.unwrap_model(vae),
+                    unet=unwrap_model(unet),
+                    text_encoder=unwrap_model(text_encoder),
+                    vae=unwrap_model(vae),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
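The same unwrapping is needed before handing modules to a pipeline because the wrapper changes how the model serializes. A quick way to see the problem `torch.compile` introduces (assuming PyTorch >= 2.0):

```python
import torch
import torch.nn as nn

net = nn.Linear(2, 2)
compiled = torch.compile(net)

# The wrapper's state_dict keys gain an `_orig_mod.` prefix, so saving the
# wrapped model produces checkpoints that don't load back into the plain
# architecture; unwrapping first avoids that.
print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']
```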
@@ -965,14 +971,14 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

         pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            text_encoder=accelerator.unwrap_model(text_encoder),
-            vae=accelerator.unwrap_model(vae),
+            text_encoder=unwrap_model(text_encoder),
+            vae=unwrap_model(vae),
             unet=unet,
             revision=args.revision,
             variant=args.variant,
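The final save path unwraps the UNet first so that `ema_unet.copy_to(...)` writes the averaged weights into the bare model that gets serialized. A small sketch of that EMA pattern with diffusers' `EMAModel` (the `nn.Linear` stand-in is illustrative):

```python
import torch.nn as nn
from diffusers.training_utils import EMAModel

net = nn.Linear(2, 2)
ema = EMAModel(net.parameters(), decay=0.999)

# During training, the shadow weights track the live ones:
#   ema.step(net.parameters())  # after each optimizer update
# Before saving, copy the averaged weights into the live model, mirroring
# the script's `ema_unet.copy_to(unet.parameters())`.
ema.copy_to(net.parameters())
```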