
Commit 1e17b7b

fix template

1 parent 31fed0b

3 files changed: 69 additions & 20 deletions


launch.sh

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ conda activate diffusers
 python heictojpg.py "./data/dog"
 
 accelerate launch train_dreambooth.py \
-  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
+  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
   --instance_data_dir="./data/dog" \
   --output_dir="fine-tuned-model-output" \
   --instance_prompt="adamsmith" \
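After training finishes, the saved checkpoint can be smoke-tested directly with diffusers. A minimal sketch, assuming a pipeline was saved to a step-numbered subfolder of fine-tuned-model-output (the step "800" and the prompt are illustrative) and that a CUDA GPU is available:

import torch
from diffusers import StableDiffusionPipeline

# Load the fine-tuned DreamBooth output; the subfolder name is illustrative.
pipe = StableDiffusionPipeline.from_pretrained(
    "fine-tuned-model-output/800", torch_dtype=torch.float16
).to("cuda")
image = pipe("a photo of adamsmith in a park", num_inference_steps=20).images[0]
image.save("sample.png")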

requirements.txt

Lines changed: 5 additions & 3 deletions
@@ -1,6 +1,8 @@
-accelerate==0.12.0
+accelerate
 torchvision
-transformers>=4.21.0
+transformers>=4.25.1
 ftfy
 tensorboard
-modelcards
+Jinja2
+safetensors
+xformers
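Jinja2, safetensors, and xformers replace the pinned accelerate and the modelcards entries. A minimal sanity check, run after installing the requirements, that diffusers can actually see the xformers build used below in train_dreambooth.py:

# Verify that diffusers detects the newly added xformers dependency.
from diffusers.utils.import_utils import is_xformers_available

print("xformers available:", is_xformers_available())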

train_dreambooth.py

Lines changed: 63 additions & 16 deletions
@@ -3,6 +3,7 @@
 import itertools
 import random
 import json
+import logging
 import math
 import os
 from contextlib import nullcontext
@@ -19,6 +20,7 @@
 from accelerate.utils import set_seed
 from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from torchvision import transforms
@@ -111,7 +113,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--save_infer_steps",
         type=int,
-        default=50,
+        default=20,
         help="The number of inference steps for save sample.",
     )
     parser.add_argument(
@@ -252,6 +254,11 @@ def parse_args(input_args=None):
         default=None,
         help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
     )
+    parser.add_argument(
+        "--read_prompts_from_txts",
+        action="store_true",
+        help="Use prompt per image. Put prompts in the same directory as images, e.g. for image.png create image.png.txt.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
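The new flag expects one caption file per image, named by appending ".txt" to the image filename, which is exactly what the dataset changes below look for. A minimal sketch for preparing such files, assuming PNG images under ./data/dog; the fallback prompt text is illustrative:

from pathlib import Path

# Write a sibling "<image>.txt" caption next to every image so it is picked up
# by --read_prompts_from_txts; directory and prompt text are illustrative.
data_dir = Path("./data/dog")
for img in sorted(data_dir.glob("*.png")):
    caption_file = img.with_name(img.name + ".txt")  # e.g. photo1.png.txt
    if not caption_file.exists():
        caption_file.write_text("adamsmith")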
@@ -280,19 +287,25 @@ def __init__(
         center_crop=False,
         num_class_images=None,
         pad_tokens=False,
-        hflip=False
+        hflip=False,
+        read_prompts_from_txts=False,
     ):
         self.size = size
         self.center_crop = center_crop
         self.tokenizer = tokenizer
         self.with_prior_preservation = with_prior_preservation
         self.pad_tokens = pad_tokens
+        self.read_prompts_from_txts = read_prompts_from_txts
 
         self.instance_images_path = []
         self.class_images_path = []
 
         for concept in concepts_list:
-            inst_img_path = [(x, concept["instance_prompt"]) for x in Path(concept["instance_data_dir"]).iterdir() if x.is_file()]
+            inst_img_path = [
+                (x, concept["instance_prompt"])
+                for x in Path(concept["instance_data_dir"]).iterdir()
+                if x.is_file() and not str(x).endswith(".txt")
+            ]
             self.instance_images_path.extend(inst_img_path)
 
             if with_prior_preservation:
@@ -320,9 +333,15 @@ def __len__(self):
     def __getitem__(self, index):
         example = {}
         instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
+
+        if self.read_prompts_from_txts:
+            with open(str(instance_path) + ".txt") as f:
+                instance_prompt = f.read().strip()
+
         instance_image = Image.open(instance_path)
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
+
         example["instance_images"] = self.image_transforms(instance_image)
         example["instance_prompt_ids"] = self.tokenizer(
             instance_prompt,
@@ -410,6 +429,12 @@ def main(args):
         logging_dir=logging_dir,
     )
 
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
     # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
@@ -457,6 +482,9 @@ def main(args):
             safety_checker=None,
             revision=args.revision
         )
+        pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+        if is_xformers_available():
+            pipeline.enable_xformers_memory_efficient_attention()
         pipeline.set_progress_bar_config(disable=True)
         pipeline.to(accelerator.device)
 
@@ -472,7 +500,10 @@ def main(args):
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
-                images = pipeline(example["prompt"]).images
+                images = pipeline(
+                    example["prompt"],
+                    num_inference_steps=args.save_infer_steps
+                ).images
 
                 for i, image in enumerate(images):
                     hash_image = hashlib.sha1(image.tobytes()).hexdigest()
@@ -518,6 +549,12 @@ def main(args):
     if not args.train_text_encoder:
         text_encoder.requires_grad_(False)
 
+    if is_xformers_available():
+        vae.enable_xformers_memory_efficient_attention()
+        unet.enable_xformers_memory_efficient_attention()
+    else:
+        logger.warning("xformers is not available. Make sure it is installed correctly")
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
@@ -562,7 +599,8 @@ def main(args):
         center_crop=args.center_crop,
         num_class_images=args.num_class_images,
         pad_tokens=args.pad_tokens,
-        hflip=args.hflip
+        hflip=args.hflip,
+        read_prompts_from_txts=args.read_prompts_from_txts,
     )
 
     def collate_fn(examples):
@@ -679,24 +717,25 @@ def save_weights(step):
        # Create the pipeline using using the trained modules and save it.
        if accelerator.is_main_process:
            if args.train_text_encoder:
-               text_enc_model = accelerator.unwrap_model(text_encoder)
+               text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
            else:
                text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
-           scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
            pipeline = StableDiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
-               unet=accelerator.unwrap_model(unet),
+               unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
                text_encoder=text_enc_model,
                vae=AutoencoderKL.from_pretrained(
                    args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
                    subfolder=None if args.pretrained_vae_name_or_path else "vae",
                    revision=None if args.pretrained_vae_name_or_path else args.revision,
                ),
                safety_checker=None,
-               scheduler=scheduler,
                torch_dtype=torch.float16,
                revision=args.revision,
            )
+           pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+           if is_xformers_available():
+               pipeline.enable_xformers_memory_efficient_attention()
            save_dir = os.path.join(args.output_dir, f"{step}")
            pipeline.save_pretrained(save_dir)
            with open(os.path.join(save_dir, "args.json"), "w") as f:
@@ -765,23 +804,31 @@ def save_weights(step):
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                # Predict the noise residual
-               noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+               model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+               # Get the target for loss depending on the prediction type
+               if noise_scheduler.config.prediction_type == "epsilon":
+                   target = noise
+               elif noise_scheduler.config.prediction_type == "v_prediction":
+                   target = noise_scheduler.get_velocity(latents, noise, timesteps)
+               else:
+                   raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                if args.with_prior_preservation:
-                   # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                   noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                   noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                   # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                   model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                   target, target_prior = torch.chunk(target, 2, dim=0)
 
                    # Compute instance loss
-                   loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                   loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                    # Compute prior loss
-                   prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                   prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
-                   loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                   loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
                accelerator.backward(loss)
                # if accelerator.sync_gradients:
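The loss target now follows the scheduler's prediction_type: "epsilon" checkpoints regress the sampled noise, while "v_prediction" checkpoints (such as the 768-px Stable Diffusion 2 models) regress the velocity returned by noise_scheduler.get_velocity. A rough sketch of that quantity, paraphrasing the usual definition v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0 rather than quoting the library's code:

import torch

def velocity_target(latents, noise, alphas_cumprod, timesteps):
    # Paraphrase of the v-prediction target; alphas_cumprod is the scheduler's
    # cumulative alpha schedule, indexed here by the sampled timesteps.
    a = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)
    s = (1.0 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
    return a * noise - s * latents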
