[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) (huggingface#6514)

* fix: training resume from fp16.

* add: comment

* remove residue from another branch.

* remove more residues.

* thanks to Younes; no hacks.

* style.

* clean things a bit and modularize _set_state_dict_into_text_encoder

* add a comment detailing the fix.
sayakpaul authored Jan 12, 2024
1 parent 7d63182 commit 79df503
Showing 3 changed files with 80 additions and 25 deletions.
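
In short: with `--mixed_precision="fp16"` only the LoRA parameters are trainable and must be kept in float32, but on resume `accelerator.load_state` restores them in the base models' `weight_dtype` (fp16), so training silently continued on half-precision adapter weights. The fix loads the adapter state dicts via PEFT inside the load hook and re-upcasts the trainable parameters there. A minimal sketch of the upcasting pattern the diff applies (the helper name is illustrative, not part of this commit):

import torch

def upcast_trainable_params(models, dtype=torch.float32):
    # Only the trainable (LoRA) parameters are upcast; the frozen base
    # weights stay in fp16, so memory usage is essentially unchanged.
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)
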
74 changes: 51 additions & 23 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -34,7 +34,7 @@
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -53,8 +53,13 @@
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

@@ -997,17 +1002,6 @@ def main(args):
        text_encoder_one.add_adapter(text_lora_config)
        text_encoder_two.add_adapter(text_lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

    if args.train_text_encoder:
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                        else unet_lora_parameters
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
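
For context, the two hooks registered in the script are what Accelerate invokes around checkpointing; `load_model_hook` is where the fix takes effect. A hedged sketch of how they are exercised, assuming the `save_model_hook`/`load_model_hook` defined above:

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# save_model_hook runs inside save_state; load_model_hook runs inside
# load_state, restoring the LoRA weights and re-upcasting them to fp32
# so the resumed optimizer does not keep training fp16 copies.
accelerator.save_state("checkpoint-500")
accelerator.load_state("checkpoint-500")
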
1 change: 0 additions & 1 deletion src/diffusers/loaders/lora.py
@@ -581,7 +581,6 @@ def load_lora_into_text_encoder(
            lora_config_kwargs = get_peft_kwargs(
                rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
            )
-
            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
30 changes: 29 additions & 1 deletion src/diffusers/training_utils.py
@@ -8,12 +8,21 @@
from torchvision import transforms

from .models import UNet2DConditionModel
-from .utils import deprecate, is_transformers_available
+from .utils import (
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
+    deprecate,
+    is_peft_available,
+    is_transformers_available,
+)


if is_transformers_available():
    import transformers

+if is_peft_available():
+    from peft import set_peft_model_state_dict
+

def set_seed(seed: int):
"""
Expand Down Expand Up @@ -112,6 +121,25 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    return lora_state_dict


+def _set_state_dict_into_text_encoder(
+    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
+):
+    """
+    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.
+
+    Args:
+        lora_state_dict: The state dictionary to be set.
+        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
+        text_encoder: Where the `lora_state_dict` is to be set.
+    """
+
+    text_encoder_state_dict = {
+        f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
+    }
+    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
+    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
+
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
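
A hedged usage sketch of the new helper (the checkpoint path is a placeholder and `text_encoder_one` is assumed to already carry a LoRA adapter, as in the training script; the leading underscore marks the helper as private API):

from diffusers.loaders import LoraLoaderMixin
from diffusers.training_utils import _set_state_dict_into_text_encoder

# The combined checkpoint stores unet and text-encoder entries under prefixes.
lora_state_dict, _ = LoraLoaderMixin.lora_state_dict("path/to/checkpoint")
_set_state_dict_into_text_encoder(
    lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one
)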
