diff --git a/scripts/train.py b/scripts/train.py index 61b60e02..1531646e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -4,6 +4,7 @@ import os import sys import time +import copy from collections import defaultdict @@ -324,9 +325,9 @@ def main(args): # Save another checkpoint with model weights and # optimizer state - checkpoint['g_state'] = generator.state_dict() + checkpoint['g_state'] = copy.deepcopy(generator.state_dict()) checkpoint['g_optim_state'] = optimizer_g.state_dict() - checkpoint['d_state'] = discriminator.state_dict() + checkpoint['d_state'] = copy.deepcopy(discriminator.state_dict()) checkpoint['d_optim_state'] = optimizer_d.state_dict() checkpoint_path = os.path.join( args.output_dir, '%s_with_model.pt' % args.checkpoint_name @@ -460,9 +461,9 @@ def check_accuracy( ): d_losses = [] metrics = {} - g_l2_losses_abs, g_l2_losses_rel = ([],) * 2 - disp_error, disp_error_l, disp_error_nl = ([],) * 3 - f_disp_error, f_disp_error_l, f_disp_error_nl = ([],) * 3 + g_l2_losses_abs, g_l2_losses_rel = [], [] + disp_error, disp_error_l, disp_error_nl = [], [], [] + f_disp_error, f_disp_error_l, f_disp_error_nl = [], [], [] total_traj, total_traj_l, total_traj_nl = 0, 0, 0 loss_mask_sum = 0 generator.eval() @@ -578,3 +579,4 @@ def cal_fde( if __name__ == '__main__': args = parser.parse_args() main(args) +