@@ -287,6 +287,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         training_models.append(text_encoder2)
         # set require_grad=True later
     else:
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
         text_encoder1.requires_grad_(False)
         text_encoder2.requires_grad_(False)
         text_encoder1.eval()
@@ -295,7 +297,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     # TextEncoderの出力をキャッシュする
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
-        with torch.no_grad():
+        with torch.no_grad(), accelerator.autocast():
             train_dataset_group.cache_text_encoder_outputs(
                 (tokenizer1, tokenizer2),
                 (text_encoder1, text_encoder2),
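Note: with the frozen text encoders now cast to weight_dtype up front (first hunk), the caching forward pass also runs under accelerator.autocast(), so it follows the configured mixed-precision setting. Below is a minimal, self-contained sketch of that pattern only; the Linear module, dtype, and shapes are illustrative stand-ins, not values from the script.

import torch
from accelerate import Accelerator

accelerator = Accelerator()    # picks up the configured mixed precision
weight_dtype = torch.float32   # stand-in; the script derives this from --mixed_precision

# stand-in for a frozen text encoder
text_encoder = torch.nn.Linear(8, 8).to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()

dummy_input = torch.randn(2, 8, device=accelerator.device, dtype=weight_dtype)

# same shape as the diff: no gradient tracking, forward pass under autocast
with torch.no_grad(), accelerator.autocast():
    cached_output = text_encoder(dummy_input)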
@@ -315,25 +317,23 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         m.requires_grad_(True)
 
     if block_lrs is None:
-        params = []
-        for m in training_models:
-            params.extend(m.parameters())
-        params_to_optimize = params
-
-        # calculate number of trainable parameters
-        n_params = 0
-        for p in params:
-            n_params += p.numel()
+        params_to_optimize = [
+            {"params": list(training_models[0].parameters()), "lr": args.learning_rate},
+        ]
     else:
         params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
-        for m in training_models[1:]:  # Text Encoders if exists
-            params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
 
-        # calculate number of trainable parameters
-        n_params = 0
-        for params in params_to_optimize:
-            for p in params["params"]:
-                n_params += p.numel()
+    for m in training_models[1:]:  # Text Encoders if exists
+        params_to_optimize.append({
+            "params": list(m.parameters()),
+            "lr": args.learning_rate_te or args.learning_rate
+        })
+
+    # calculate number of trainable parameters
+    n_params = 0
+    for params in params_to_optimize:
+        for p in params["params"]:
+            n_params += p.numel()
 
     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")
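Note: the list of dicts built above is the standard per-parameter-group format accepted by torch optimizers, and args.learning_rate_te or args.learning_rate falls back to the base learning rate whenever --learning_rate_te is left at its default of 0.0. A minimal sketch of how those groups are consumed; the two Linear modules are stand-ins for the U-Net and a text encoder, not the real models.

import torch

unet = torch.nn.Linear(4, 4)          # stand-in for the U-Net
text_encoder = torch.nn.Linear(4, 4)  # stand-in for a text encoder

learning_rate = 1e-4
learning_rate_te = 0.0                # default: 0.0 means "use the base learning rate"

params_to_optimize = [
    {"params": list(unet.parameters()), "lr": learning_rate},
    {"params": list(text_encoder.parameters()), "lr": learning_rate_te or learning_rate},
]

# each dict becomes its own parameter group with its own learning rate
optimizer = torch.optim.AdamW(params_to_optimize)
print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 0.0001]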
@@ -396,8 +396,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
         (unet,) = train_util.transform_models_if_DDP([unet])
-        text_encoder1.to(weight_dtype)
-        text_encoder2.to(weight_dtype)
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -728,6 +726,7 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)
+    parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
 
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
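Note: the new flag defaults to 0.0, which keeps the previous behaviour (the text encoders use --learning_rate); any positive value gives them their own learning rate. A small argparse sketch of that interaction, with hypothetical values that are not taken from the repository docs.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")

args = parser.parse_args(["--learning_rate", "1e-5"])    # flag omitted
print(args.learning_rate_te or args.learning_rate)       # 1e-05 (falls back to the base LR)

args = parser.parse_args(["--learning_rate_te", "2e-5"])  # flag given
print(args.learning_rate_te or args.learning_rate)        # 2e-05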