@@ -332,17 +332,22 @@ def save_checkpoint(self, epoch, is_final=False):
 
     def load_checkpoint(self, checkpoint_path):
         try:
-            checkpoint = self.accelerator.load(checkpoint_path)
+            checkpoint = torch.load(checkpoint_path, map_location=self.accelerator.device)
 
-            self.model.load_state_dict(checkpoint['model_state_dict'])
-            self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
+            # Unwrap the models before loading state dict
+            unwrapped_model = self.accelerator.unwrap_model(self.model)
+            unwrapped_discriminator = self.accelerator.unwrap_model(self.discriminator)
+
+            unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
+            unwrapped_discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
             self.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
             self.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
             self.scheduler_g.load_state_dict(checkpoint['scheduler_g_state_dict'])
             self.scheduler_d.load_state_dict(checkpoint['scheduler_d_state_dict'])
 
             if self.ema and 'ema_state_dict' in checkpoint:
-                self.ema.load_state_dict(checkpoint['ema_state_dict'])
+                unwrapped_ema = self.accelerator.unwrap_model(self.ema)
+                unwrapped_ema.load_state_dict(checkpoint['ema_state_dict'])
 
             start_epoch = checkpoint['epoch'] + 1
             print(f"Loaded checkpoint from epoch {start_epoch - 1}")
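Why the unwrapping matters: when Accelerate prepares the models for distributed training it can wrap them (for example in DistributedDataParallel), which prefixes state-dict keys with "module.", so loading a plain checkpoint into the wrapped object fails with key-mismatch errors. Unwrapping first loads into the underlying module regardless of how the run is distributed. For context, below is a minimal sketch of the matching save-side pattern, reusing the same checkpoint keys the load code above expects; it is an illustration only, not the repository's actual save_checkpoint.

    # Sketch only: mirrors the keys used by load_checkpoint above.
    def save_checkpoint_sketch(self, epoch, path):
        checkpoint = {
            'epoch': epoch,
            # Unwrap so the saved keys carry no DDP 'module.' prefix.
            'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
            'discriminator_state_dict': self.accelerator.unwrap_model(self.discriminator).state_dict(),
            'optimizer_g_state_dict': self.optimizer_g.state_dict(),
            'optimizer_d_state_dict': self.optimizer_d.state_dict(),
            'scheduler_g_state_dict': self.scheduler_g.state_dict(),
            'scheduler_d_state_dict': self.scheduler_d.state_dict(),
        }
        if self.ema:
            checkpoint['ema_state_dict'] = self.accelerator.unwrap_model(self.ema).state_dict()
        if self.accelerator.is_main_process:  # write once, not once per process
            torch.save(checkpoint, path)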
@@ -398,6 +403,9 @@ def main():
         collate_fn=gpu_padded_collate
     )
 
+    print("using float32 for onnx training....")
+    torch.set_default_dtype(torch.float32)
+
 
     trainer = IMFTrainer(config, model, discriminator, dataloader, accelerator)
     # Check if a checkpoint path is provided in the config
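On the second hunk: torch.set_default_dtype(torch.float32) makes float32 the dtype used for floating-point tensors and module parameters created afterwards, which is what the ONNX export path generally expects. Float32 is already PyTorch's default, so the call mainly guards against an earlier override elsewhere in the script. A minimal illustration, with names chosen only for the example:

    import torch

    torch.set_default_dtype(torch.float32)
    layer = torch.nn.Linear(4, 4)        # parameters are created as float32
    x = torch.randn(2, 4)                # new floating-point tensors default to float32
    print(layer.weight.dtype, x.dtype)   # torch.float32 torch.float32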