
Commit 1cefb2a

Authored Oct 28, 2023
Better implementation for te autocast (kohya-ss#895)
* Better implementation for te
* Fix some misunderstanding
* as same as unet, add explicit convert
* Better cache TE and TE lr
* Fix with list
* Add timeout settings
* Fix arg style
1 parent 202f2c3 commit 1cefb2a

4 files changed (+41 −30)
 

library/train_util.py (+6 −1)
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+import datetime
 import importlib
 import json
 import pathlib
@@ -18,7 +19,7 @@
     Tuple,
     Union,
 )
-from accelerate import Accelerator
+from accelerate import Accelerator, InitProcessGroupKwargs
 import gc
 import glob
 import math
@@ -2855,6 +2856,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
@@ -3786,6 +3790,7 @@ def prepare_accelerator(args: argparse.Namespace):
         mixed_precision=args.mixed_precision,
         log_with=log_with,
         project_dir=logging_dir,
+        kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
     )
     return accelerator
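
Note: the new --ddp_timeout option is forwarded to torch.distributed through accelerate's InitProcessGroupKwargs handler. A minimal standalone sketch of the same wiring, with a hypothetical 120-minute timeout:

import datetime
from accelerate import Accelerator, InitProcessGroupKwargs

# Raise the process-group timeout so long single-rank work (e.g. caching text
# encoder outputs on the main process) does not abort distributed training.
# 120 is a hypothetical value; the new argument defaults to 30 minutes.
accelerator = Accelerator(
    kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=120))],
)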

sdxl_train.py (+18 −19)
@@ -287,6 +287,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         training_models.append(text_encoder2)
         # set require_grad=True later
     else:
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
         text_encoder1.requires_grad_(False)
         text_encoder2.requires_grad_(False)
         text_encoder1.eval()
@@ -295,7 +297,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     # TextEncoderの出力をキャッシュする
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
-        with torch.no_grad():
+        with torch.no_grad(), accelerator.autocast():
             train_dataset_group.cache_text_encoder_outputs(
                 (tokenizer1, tokenizer2),
                 (text_encoder1, text_encoder2),
@@ -315,25 +317,23 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         m.requires_grad_(True)
 
     if block_lrs is None:
-        params = []
-        for m in training_models:
-            params.extend(m.parameters())
-        params_to_optimize = params
-
-        # calculate number of trainable parameters
-        n_params = 0
-        for p in params:
-            n_params += p.numel()
+        params_to_optimize = [
+            {"params": list(training_models[0].parameters()), "lr": args.learning_rate},
+        ]
     else:
         params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
-        for m in training_models[1:]:  # Text Encoders if exists
-            params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
 
-        # calculate number of trainable parameters
-        n_params = 0
-        for params in params_to_optimize:
-            for p in params["params"]:
-                n_params += p.numel()
+    for m in training_models[1:]:  # Text Encoders if exists
+        params_to_optimize.append({
+            "params": list(m.parameters()),
+            "lr": args.learning_rate_te or args.learning_rate
+        })
+
+    # calculate number of trainable parameters
+    n_params = 0
+    for params in params_to_optimize:
+        for p in params["params"]:
+            n_params += p.numel()
 
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
@@ -396,8 +396,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
         (unet,) = train_util.transform_models_if_DDP([unet])
-        text_encoder1.to(weight_dtype)
-        text_encoder2.to(weight_dtype)
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -728,6 +726,7 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)
+    parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
 
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

sdxl_train_network.py (+10 −8)
@@ -70,14 +70,16 @@ def cache_text_encoder_outputs_if_needed(
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
 
-            dataset.cache_text_encoder_outputs(
-                tokenizers,
-                text_encoders,
-                accelerator.device,
-                weight_dtype,
-                args.cache_text_encoder_outputs_to_disk,
-                accelerator.is_main_process,
-            )
+            # When TE is not be trained, it will not be prepared so we need to use explicit autocast
+            with accelerator.autocast():
+                dataset.cache_text_encoder_outputs(
+                    tokenizers,
+                    text_encoders,
+                    accelerator.device,
+                    weight_dtype,
+                    args.cache_text_encoder_outputs_to_disk,
+                    accelerator.is_main_process,
+                )
 
             text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)
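
Note: when only the network/U-Net is trained, the text encoders are never passed through accelerator.prepare(), so autocast is not applied to them automatically; the explicit accelerator.autocast() supplies it during output caching. A minimal standalone sketch of the pattern with a toy encoder; the module, shapes, and the fp16 setting are assumptions, and fp16 here expects a CUDA device:

import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # fp16 assumes a CUDA device

# A toy "text encoder" kept in its weight dtype and never prepare()-d.
text_encoder = nn.Linear(16, 16).to(accelerator.device)

with torch.no_grad(), accelerator.autocast():
    tokens = torch.randn(1, 16, device=accelerator.device)
    cached_outputs = text_encoder(tokens)  # forward pass runs under autocast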

train_network.py (+7 −2)
@@ -109,6 +109,9 @@ def load_tokenizer(self, args):
     def is_text_encoder_outputs_cached(self, args):
         return False
 
+    def is_train_text_encoder(self, args):
+        return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+
     def cache_text_encoder_outputs_if_needed(
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
@@ -310,7 +313,7 @@ def train(self, args):
            args.scale_weight_norms = False
 
        train_unet = not args.network_train_text_encoder_only
-       train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+       train_text_encoder = self.is_train_text_encoder(args)
        network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
 
        if args.network_weights is not None:
@@ -403,6 +406,8 @@ def train(self, args):
            unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                unet, network, optimizer, train_dataloader, lr_scheduler
            )
+           for t_enc in text_encoders:
+               t_enc.to(accelerator.device, dtype=weight_dtype)
        elif train_text_encoder:
            if len(text_encoders) > 1:
                t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -767,7 +772,7 @@ def remove_model(old_ckpt_name):
                    latents = latents * self.vae_scale_factor
                b_size = latents.shape[0]
 
-               with torch.set_grad_enabled(train_text_encoder):
+               with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                    # Get the text embedding for conditioning
                    if args.weighted_captions:
                        text_encoder_conds = get_weighted_text_embeddings(
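
Note: the combined context managers let the same forward pass either build a graph for the text encoder or run gradient-free, while accelerator.autocast() keeps it in the configured mixed-precision dtype even when the encoder was not wrapped by accelerator.prepare(). A minimal sketch with a toy encoder; the module and input shapes are assumptions:

import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()  # mixed precision comes from the accelerate config
text_encoder = nn.Linear(32, 32).to(accelerator.device)

train_text_encoder = False  # e.g. --network_train_unet_only or cached TE outputs

tokens = torch.randn(2, 32, device=accelerator.device)
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
    text_encoder_conds = text_encoder(tokens)

print(text_encoder_conds.requires_grad)  # False: no graph is kept for the TE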
