From a8494e2876cc4d6788c8b853911d69d84aa122b7 Mon Sep 17 00:00:00 2001 From: jlrosende <10084165+jlrosende@users.noreply.github.com> Date: Tue, 24 Jan 2023 10:49:12 +0100 Subject: [PATCH] Fix save model after training The model is not save to output model. With this change, the model now is saved again. --- examples/dreambooth/train_dreambooth.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d3dbd9287bd9..0631cc5dee54 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -783,6 +783,17 @@ def bar(prg): txt_dir=args.output_dir + "/text_encoder_trained" if os.path.exists(txt_dir): subprocess.call('rm -r '+txt_dir, shell=True) + else: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + ) + frz_dir=args.output_dir + "/text_encoder_frozen" + pipeline.save_pretrained(args.output_dir) + if args.train_text_encoder and os.path.exists(frz_dir): + subprocess.call('mv -f '+frz_dir +'/*.* '+ args.output_dir+'/text_encoder', shell=True) + subprocess.call('rm -r '+ frz_dir, shell=True) if os.path.exists(args.captions_dir+'off'): subprocess.call('mv '+args.captions_dir+'off '+args.captions_dir, shell=True)