@@ -3,6 +3,7 @@
 import itertools
 import random
 import json
+import logging
 import math
 import os
 from contextlib import nullcontext
@@ -19,6 +20,7 @@
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
@@ -111,7 +113,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--save_infer_steps",
         type=int,
-        default=50,
+        default=20,
         help="The number of inference steps used when saving sample images.",
     )
     parser.add_argument(
@@ -252,6 +254,11 @@ def parse_args(input_args=None):
         default=None,
         help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
     )
+    parser.add_argument(
+        "--read_prompts_from_txts",
+        action="store_true",
+        help="Use a separate prompt per image. Put prompt files next to the images, e.g. for image.png create image.png.txt.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -280,19 +287,26 @@ def __init__(
         center_crop=False,
         num_class_images=None,
         pad_tokens=False,
-        hflip=False
+        hflip=False,
+        read_prompts_from_txts=False,
     ):
         self.size = size
         self.center_crop = center_crop
         self.tokenizer = tokenizer
         self.with_prior_preservation = with_prior_preservation
         self.pad_tokens = pad_tokens
+        self.read_prompts_from_txts = read_prompts_from_txts
 
         self.instance_images_path = []
         self.class_images_path = []
 
         for concept in concepts_list:
-            inst_img_path = [(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()]
+            inst_img_path = [
+                (x, concept["instance_prompt"])
+                for x in Path(concept["instance_data_dir"]).iterdir()
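+                # exclude the companion .txt prompt files so they are not loaded as images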
+                if x.is_file() and not str(x).endswith(".txt")
+            ]
             self.instance_images_path.extend(inst_img_path)
 
         if with_prior_preservation:
@@ -320,9 +333,16 @@ def __len__(self):
     def __getitem__(self, index):
         example = {}
         instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
+
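+        # with --read_prompts_from_txts, each image's prompt lives in a companion file (image.png -> image.png.txt)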
+        if self.read_prompts_from_txts:
+            with open(str(instance_path) + ".txt") as f:
+                instance_prompt = f.read().strip()
+
         instance_image = Image.open(instance_path)
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
+
         example["instance_images"] = self.image_transforms(instance_image)
         example["instance_prompt_ids"] = self.tokenizer(
             instance_prompt,
@@ -410,6 +429,13 @@ def main(args):
         logging_dir=logging_dir,
     )
 
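+    # timestamped log records make interleaved accelerate/diffusers output easier to follow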
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -457,6 +482,11 @@ def main(args):
             safety_checker=None,
             revision=args.revision
         )
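+        # DDIM produces usable class images in far fewer steps than the pipeline's stock
+        # scheduler, which is presumably why --save_infer_steps now defaults to 20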
+        pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+        if is_xformers_available():
+            pipeline.enable_xformers_memory_efficient_attention()
         pipeline.set_progress_bar_config(disable=True)
         pipeline.to(accelerator.device)
 
@@ -472,7 +500,11 @@ def main(args):
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
-                images = pipeline(example["prompt"]).images
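+                # pass --save_infer_steps through; otherwise the pipeline falls back to its 50-step default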
+                images = pipeline(
+                    example["prompt"],
+                    num_inference_steps=args.save_infer_steps
+                ).images
 
                 for i, image in enumerate(images):
                     hash_image = hashlib.sha1(image.tobytes()).hexdigest()
@@ -518,6 +549,13 @@ def main(args):
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
 
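+    # memory-efficient attention is a large VRAM saving during training when xformers is installed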
+    if is_xformers_available():
+        vae.enable_xformers_memory_efficient_attention()
+        unet.enable_xformers_memory_efficient_attention()
+    else:
+        logger.warning("xformers is not available. Make sure it is installed correctly")
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
@@ -562,7 +599,8 @@ def main(args):
         center_crop=args.center_crop,
         num_class_images=args.num_class_images,
         pad_tokens=args.pad_tokens,
-        hflip=args.hflip
+        hflip=args.hflip,
+        read_prompts_from_txts=args.read_prompts_from_txts,
     )
 
     def collate_fn(examples):
@@ -679,24 +717,28 @@ def save_weights(step):
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             if args.train_text_encoder:
-                text_enc_model = accelerator.unwrap_model(text_encoder)
+                text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
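+                # keep_fp32_wrapper=True preserves accelerate's mixed-precision forward wrapper,
+                # so saving a checkpoint mid-run does not disturb the models still being trained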
             else:
                 text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
-            scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
             pipeline = StableDiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
-                unet=accelerator.unwrap_model(unet),
+                unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
                 text_encoder=text_enc_model,
                 vae=AutoencoderKL.from_pretrained(
                     args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                     subfolder=None if args.pretrained_vae_name_or_path else "vae",
                     revision=None if args.pretrained_vae_name_or_path else args.revision,
                 ),
                 safety_checker=None,
-                scheduler=scheduler,
                 torch_dtype=torch.float16,
                 revision=args.revision,
             )
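+            # derive DDIM from the model's own scheduler config instead of hard-coded betas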
+            pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+            if is_xformers_available():
+                pipeline.enable_xformers_memory_efficient_attention()
             save_dir = os.path.join(args.output_dir, f"{step}")
             pipeline.save_pretrained(save_dir)
             with open(os.path.join(save_dir, "args.json"), "w") as f:
@@ -765,23 +804,33 @@ def save_weights(step):
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
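+                # v-prediction checkpoints (e.g. Stable Diffusion 2.x 768 models) predict velocity rather than noise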
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                 if args.with_prior_preservation:
-                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                    # Chunk the target and model_pred into two parts and compute the loss on each part separately.
+                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                    target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
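+                    # "mean" reduction equals the old per-sample mean([1, 2, 3]).mean() when all samples share a shape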
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                 accelerator.backward(loss)
                 # if accelerator.sync_gradients: