@@ -2,7 +2,7 @@
 Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
 Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
 """
-from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images
+from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
 import argparse
 from torch.autograd import Variable
 from trainer import MUNIT_Trainer, UNIT_Trainer
@@ -41,10 +41,10 @@
     sys.exit("Only support MUNIT|UNIT")
 trainer.cuda()
 train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
-train_display_images_a = Variable(torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda())
-train_display_images_b = Variable(torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda())
-test_display_images_a = Variable(torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda())
-test_display_images_b = Variable(torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda())
+train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda()
+train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda()
+test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda()
+test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda()

 # Setup logger and output folders
 model_name = os.path.splitext(os.path.basename(opts.config))[0]
@@ -58,11 +58,13 @@
 while True:
     for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
         trainer.update_learning_rate()
-        images_a, images_b = Variable(images_a.cuda()), Variable(images_b.cuda())
+        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

-        # Main training code
-        trainer.dis_update(images_a, images_b, config)
-        trainer.gen_update(images_a, images_b, config)
+        with Timer("Elapsed time in update: %f"):
+            # Main training code
+            trainer.dis_update(images_a, images_b, config)
+            trainer.gen_update(images_a, images_b, config)
+            torch.cuda.synchronize()

         # Dump training stats in log file
         if (iterations + 1) % config['log_iter'] == 0:
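For reference, the newly imported Timer comes from utils; its implementation is not shown in this diff. A minimal sketch of a context manager that would behave the way the training loop uses it (print the format string with the elapsed seconds on exit) could look like this:

# Sketch only: the actual Timer in utils may differ.
import time

class Timer:
    def __init__(self, msg):
        self.msg = msg          # format string with a single %f placeholder
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Report wall-clock time spent inside the with-block
        print(self.msg % (time.time() - self.start_time))

The torch.cuda.synchronize() call inside the with-block matters for the measurement: CUDA kernels are launched asynchronously, so without it the timer would mostly capture kernel launch overhead rather than the full discriminator and generator updates.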