Skip to content

Commit 7fd9594

Browse files
authored
Fix bug in wan-animate and fantasytalking multi-GPU inference && Update Training Code && Fix bug in s2v lora merging && Update qwen image quick loading (#368)
1 parent df77df0 commit 7fd9594

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+4293
-122
lines changed

examples/wan2.2/predict_s2v.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@
345345

346346
if lora_path is not None:
347347
pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
348-
pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
348+
if transformer_2 is not None:
349+
pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
349350

350351
def save_results():
351352
if not os.path.exists(save_path):

scripts/cogvideox_fun/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,9 +1229,9 @@ def collate_fn(examples):
12291229
ema_transformer3d.to(accelerator.device)
12301230

12311231
# Move text_encode and vae to gpu and cast to weight_dtype
1232-
vae.to(accelerator.device, dtype=weight_dtype)
1232+
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12331233
if not args.enable_text_encoder_in_dataloader:
1234-
text_encoder.to(accelerator.device)
1234+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12351235

12361236
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
12371237
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

scripts/cogvideox_fun/train_control.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,9 +1164,9 @@ def collate_fn(examples):
11641164
ema_transformer3d.to(accelerator.device)
11651165

11661166
# Move text_encode and vae to gpu and cast to weight_dtype
1167-
vae.to(accelerator.device, dtype=weight_dtype)
1167+
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
11681168
if not args.enable_text_encoder_in_dataloader:
1169-
text_encoder.to(accelerator.device)
1169+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
11701170

11711171
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
11721172
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

scripts/cogvideox_fun/train_lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,10 +1164,10 @@ def collate_fn(examples):
11641164
)
11651165

11661166
# Move text_encode and vae to gpu and cast to weight_dtype
1167-
vae.to(accelerator.device, dtype=weight_dtype)
1167+
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
11681168
transformer3d.to(accelerator.device, dtype=weight_dtype)
11691169
if not args.enable_text_encoder_in_dataloader:
1170-
text_encoder.to(accelerator.device)
1170+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
11711171

11721172
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
11731173
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

scripts/fantasytalking/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1357,7 +1357,7 @@ def _create_special_list(length):
13571357
# Move text_encode and vae to gpu and cast to weight_dtype
13581358
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
13591359
if not args.enable_text_encoder_in_dataloader:
1360-
text_encoder.to(accelerator.device if not args.low_vram else "cpu")
1360+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
13611361
clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
13621362
audio_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=torch.float32)
13631363

scripts/flux/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,8 +1348,8 @@ def _create_special_list(length):
13481348
# Move text_encode and vae to gpu and cast to weight_dtype
13491349
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
13501350
if not args.enable_text_encoder_in_dataloader:
1351-
text_encoder.to(accelerator.device if not args.low_vram else "cpu")
1352-
text_encoder_2.to(accelerator.device if not args.low_vram else "cpu")
1351+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
1352+
text_encoder_2.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
13531353

13541354
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
13551355
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

scripts/flux/train_lora.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,11 +1280,11 @@ def _create_special_list(length):
12801280
# text_encoder_2 = shard_fn(text_encoder_2)
12811281

12821282
# Move text_encode and vae to gpu and cast to weight_dtype
1283-
vae.to(accelerator.device, dtype=weight_dtype)
1283+
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12841284
transformer3d.to(accelerator.device, dtype=weight_dtype)
12851285
if not args.enable_text_encoder_in_dataloader:
1286-
text_encoder.to(accelerator.device)
1287-
text_encoder_2.to(accelerator.device)
1286+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
1287+
text_encoder_2.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12881288

12891289
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
12901290
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

scripts/qwenimage/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,7 +1215,7 @@ def _create_special_list(length):
12151215
# Move text_encode and vae to gpu and cast to weight_dtype
12161216
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12171217
if not args.enable_text_encoder_in_dataloader:
1218-
text_encoder.to(accelerator.device if not args.low_vram else "cpu")
1218+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12191219

12201220
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
12211221
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

scripts/qwenimage/train_edit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1260,7 +1260,7 @@ def _create_special_list(length):
12601260
# Move text_encode and vae to gpu and cast to weight_dtype
12611261
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12621262
if not args.enable_text_encoder_in_dataloader:
1263-
text_encoder.to(accelerator.device if not args.low_vram else "cpu")
1263+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12641264

12651265
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
12661266
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

scripts/qwenimage/train_edit_lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,10 +1209,10 @@ def _create_special_list(length):
12091209
text_encoder = shard_fn(text_encoder)
12101210

12111211
# Move text_encode and vae to gpu and cast to weight_dtype
1212-
vae.to(accelerator.device, dtype=weight_dtype)
1212+
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12131213
transformer3d.to(accelerator.device, dtype=weight_dtype)
12141214
if not args.enable_text_encoder_in_dataloader:
1215-
text_encoder.to(accelerator.device)
1215+
text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
12161216

12171217
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
12181218
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

0 commit comments

Comments (0)