Skip to content

Commit f972e42

Browse files
committed
Fixed speed issue with pytorch 0.4
1 parent 3c47af2 commit f972e42

File tree

3 files changed

+35
-12
lines changed

3 files changed

+35
-12
lines changed

networks.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,8 @@ def __init__(self, input_dim ,output_dim, kernel_size, stride,
305305
if norm == 'bn':
306306
self.norm = nn.BatchNorm2d(norm_dim)
307307
elif norm == 'in':
308-
self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
308+
#self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
309+
self.norm = nn.InstanceNorm2d(norm_dim)
309310
elif norm == 'ln':
310311
self.norm = LayerNorm(norm_dim)
311312
elif norm == 'adain':
@@ -474,6 +475,7 @@ def forward(self, x):
474475
def __repr__(self):
475476
return self.__class__.__name__ + '(' + str(self.num_features) + ')'
476477

478+
477479
class LayerNorm(nn.Module):
478480
def __init__(self, num_features, eps=1e-5, affine=True):
479481
super(LayerNorm, self).__init__()
@@ -487,11 +489,16 @@ def __init__(self, num_features, eps=1e-5, affine=True):
487489

488490
def forward(self, x):
489491
shape = [-1] + [1] * (x.dim() - 1)
490-
mean = x.view(x.size(0), -1).mean(1).view(*shape)
491-
std = x.view(x.size(0), -1).std(1).view(*shape)
492+
# print(x.size())
493+
# mean = x.view(x.size(0), -1).mean(1).view(*shape)
494+
# std = x.view(x.size(0), -1).std(1).view(*shape)
495+
mean = x.view(-1).mean().view(*shape)
496+
std = x.view(-1).std().view(*shape)
497+
492498
x = (x - mean) / (std + self.eps)
493499

494500
if self.affine:
495501
shape = [1, -1] + [1] * (x.dim() - 2)
496502
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
497503
return x
504+

train.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
33
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
44
"""
5-
from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images
5+
from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
66
import argparse
77
from torch.autograd import Variable
88
from trainer import MUNIT_Trainer, UNIT_Trainer
@@ -41,10 +41,10 @@
4141
sys.exit("Only support MUNIT|UNIT")
4242
trainer.cuda()
4343
train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
44-
train_display_images_a = Variable(torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda())
45-
train_display_images_b = Variable(torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda())
46-
test_display_images_a = Variable(torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda())
47-
test_display_images_b = Variable(torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda())
44+
train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda()
45+
train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda()
46+
test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda()
47+
test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda()
4848

4949
# Setup logger and output folders
5050
model_name = os.path.splitext(os.path.basename(opts.config))[0]
@@ -58,11 +58,13 @@
5858
while True:
5959
for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
6060
trainer.update_learning_rate()
61-
images_a, images_b = Variable(images_a.cuda()), Variable(images_b.cuda())
61+
images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()
6262

63-
# Main training code
64-
trainer.dis_update(images_a, images_b, config)
65-
trainer.gen_update(images_a, images_b, config)
63+
with Timer("Elapsed time in update: %f"):
64+
# Main training code
65+
trainer.dis_update(images_a, images_b, config)
66+
trainer.gen_update(images_a, images_b, config)
67+
torch.cuda.synchronize()
6668

6769
# Dump training stats in log file
6870
if (iterations + 1) % config['log_iter'] == 0:

utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import yaml
1717
import numpy as np
1818
import torch.nn.init as init
19+
import time
1920
# Methods
2021
# get_all_data_loaders : primary data loader interface (load trainA, testA, trainB, testB)
2122
# get_data_loader_list : list-based data loader
@@ -277,3 +278,16 @@ def init_fun(m):
277278
init.constant_(m.bias.data, 0.0)
278279

279280
return init_fun
281+
282+
283+
class Timer:
284+
def __init__(self, msg):
285+
self.msg = msg
286+
self.start_time = None
287+
288+
def __enter__(self):
289+
self.start_time = time.time()
290+
291+
def __exit__(self, exc_type, exc_value, exc_tb):
292+
print(self.msg % (time.time() - self.start_time))
293+

0 commit comments

Comments
 (0)