
Commit 0c2abb1

fix for resuming
1 parent 8197b79 commit 0c2abb1

2 files changed: +13 −5 lines changed

config.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ profiling:
   profile_step: 10
 training:
 
-  load_checkpoint: False # Set this to true when you want to load from a checkpoint
+  load_checkpoint: True # Set this to true when you want to load from a checkpoint
   checkpoint_path: './checkpoints/checkpoint.pth'
   use_eye_loss: False
   use_subsampling: False # saves ram? https://github.com/johndpope/MegaPortrait-hack/issues/41
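For reference, this training block is what gates resuming: a loader reads config.yaml and, when load_checkpoint is true, hands checkpoint_path to the trainer. A minimal sketch of that gate, assuming the config is parsed with PyYAML into a plain dict, that a trainer object exposing the load_checkpoint method from train.py already exists, and that load_checkpoint returns the epoch to resume from (the actual config loader in the repo may differ):

import yaml

# Parse the YAML config; config["training"] holds the keys shown in the diff above.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

training_cfg = config["training"]
if training_cfg["load_checkpoint"]:
    # Resume from the saved checkpoint (default: './checkpoints/checkpoint.pth')
    start_epoch = trainer.load_checkpoint(training_cfg["checkpoint_path"])
else:
    start_epoch = 0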

train.py

Lines changed: 12 additions & 4 deletions
@@ -332,17 +332,22 @@ def save_checkpoint(self, epoch, is_final=False):
 
     def load_checkpoint(self, checkpoint_path):
         try:
-            checkpoint = self.accelerator.load(checkpoint_path)
+            checkpoint = torch.load(checkpoint_path, map_location=self.accelerator.device)
 
-            self.model.load_state_dict(checkpoint['model_state_dict'])
-            self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
+            # Unwrap the models before loading state dict
+            unwrapped_model = self.accelerator.unwrap_model(self.model)
+            unwrapped_discriminator = self.accelerator.unwrap_model(self.discriminator)
+
+            unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
+            unwrapped_discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
             self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
             self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
             self.scheduler_g.load_state_dict(checkpoint['scheduler_g_state_dict'])
             self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict'])
 
             if self.ema and 'ema_state_dict' in checkpoint:
-                self.ema.load_state_dict(checkpoint['ema_state_dict'])
+                unwrapped_ema = self.accelerator.unwrap_model(self.ema)
+                unwrapped_ema.load_state_dict(checkpoint['ema_state_dict'])
 
             start_epoch = checkpoint['epoch'] + 1
             print(f"Loaded checkpoint from epoch {start_epoch - 1}")
@@ -398,6 +403,9 @@ def main():
         collate_fn=gpu_padded_collate
     )
 
+    print("using float32 for onnx training....")
+    torch.set_default_dtype(torch.float32)
+
 
     trainer = IMFTrainer(config, model, discriminator, dataloader, accelerator)
     # Check if a checkpoint path is provided in the config
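Why the unwrapping matters: after accelerator.prepare, self.model and self.discriminator may come back wrapped (for example in DistributedDataParallel), so their state_dict keys gain a wrapper prefix and no longer match a checkpoint saved from the bare modules. accelerator.unwrap_model recovers the underlying module before load_state_dict, and torch.load(..., map_location=self.accelerator.device) places the tensors on the right device. The added torch.set_default_dtype(torch.float32) call simply makes newly created tensors default to float32 before the trainer is built. Below is a minimal save/load round trip illustrating the same pattern outside the IMFTrainer class, using a toy model; the names here are illustrative, not taken from the repo:

import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = accelerator.prepare(nn.Linear(4, 4))  # may come back wrapped, e.g. in DDP

# Save: unwrap first so the state_dict keys carry no wrapper prefix such as 'module.'
unwrapped = accelerator.unwrap_model(model)
accelerator.save({"model_state_dict": unwrapped.state_dict(), "epoch": 3}, "checkpoint.pth")

# Load: read the file onto the current device, then load into the unwrapped module again
checkpoint = torch.load("checkpoint.pth", map_location=accelerator.device)
accelerator.unwrap_model(model).load_state_dict(checkpoint["model_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # resume from the next epoch, as in load_checkpoint above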
