Commit 537016b: Update engine.py
Parent: ad24d62

1 file changed, 27 additions (+), 16 deletions (-)

engine.py (+27, -16)
@@ -9,7 +9,7 @@
 import torchvision.transforms as transforms
 import torch.nn as nn
 from util import *
-import neptune
+# import neptune
 
 tqdm.monitor_interval = 0
 class Engine(object):
@@ -173,28 +173,39 @@ def learning(self, model, criterion, train_dataset, val_dataset, optimizer=None)
                 batch_size=self.state['batch_size_test'], shuffle=False,
                 num_workers=self.state['workers'])
 
+        if self.state['use_gpu']:
+            train_loader.pin_memory = True
+            val_loader.pin_memory = True
+            cudnn.benchmark = False
+            model = torch.nn.DataParallel(model, device_ids=self.state['device_ids']).cuda()
+            # model = model.cuda()
+
+            criterion = criterion.cuda()
+
         # optionally resume from a checkpoint
         if self._state('resume') is not None:
             if os.path.isfile(self.state['resume']):
                 print("=> loading checkpoint '{}'".format(self.state['resume']))
                 checkpoint = torch.load(self.state['resume'])
                 self.state['start_epoch'] = checkpoint['epoch']
                 self.state['best_score'] = checkpoint['best_score']
-                model.load_state_dict(checkpoint['state_dict'])
+                state_dict = checkpoint['state_dict']
+                from collections import OrderedDict
+                new_state_dict = OrderedDict()
+
+                for k, v in state_dict.items():
+                    if 'module' not in k:
+                        k = 'module.' + k
+                    else:
+                        k = k.replace('features.module.', 'module.features.')
+                    new_state_dict[k] = v
+
+                model.load_state_dict(new_state_dict)
                 print("=> loaded checkpoint '{}' (epoch {})"
                       .format(self.state['evaluate'], checkpoint['epoch']))
             else:
                 print("=> no checkpoint found at '{}'".format(self.state['resume']))
 
-
-        if self.state['use_gpu']:
-            train_loader.pin_memory = True
-            val_loader.pin_memory = True
-            cudnn.benchmark = False
-            model = torch.nn.DataParallel(model, device_ids=self.state['device_ids']).cuda()
-
-            criterion = criterion.cuda()
-
         if self.state['evaluate']:
             self.validate(val_loader, model, criterion)
             return
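Note on the checkpoint change above: because the model is now wrapped in torch.nn.DataParallel before the checkpoint is loaded, every parameter of the wrapped model is registered under a "module." prefix, so a checkpoint saved from the bare model no longer matches. The added loop rewrites the checkpoint keys to fit. A minimal, self-contained sketch of that remapping idea (the tiny Sequential model and checkpoint dict here are purely illustrative, not from this repository):

    from collections import OrderedDict

    import torch
    import torch.nn as nn

    # Pretend this checkpoint was saved from an unwrapped model: keys look like "0.weight".
    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    checkpoint = {'state_dict': model.state_dict()}

    # Wrapping in DataParallel prefixes every parameter key with "module." (e.g. "module.0.weight").
    wrapped = nn.DataParallel(model)

    new_state_dict = OrderedDict()
    for k, v in checkpoint['state_dict'].items():
        if not k.startswith('module.'):
            k = 'module.' + k          # add the prefix expected by the wrapped model
        new_state_dict[k] = v

    wrapped.load_state_dict(new_state_dict)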
@@ -254,7 +265,7 @@ def train(self, data_loader, model, criterion, optimizer, epoch):
             self.on_start_batch(True, model, criterion, data_loader, optimizer)
 
             if self.state['use_gpu']:
-                self.state['target'] = self.state['target'].cuda(async=True)
+                self.state['target'] = self.state['target'].cuda()
 
             self.on_forward(True, model, criterion, data_loader, optimizer)
 
@@ -290,7 +301,7 @@ def validate(self, data_loader, model, criterion):
             self.on_start_batch(False, model, criterion, data_loader)
 
             if self.state['use_gpu']:
-                self.state['target'] = self.state['target'].cuda(async=True)
+                self.state['target'] = self.state['target'].cuda()
 
             self.on_forward(False, model, criterion, data_loader)
 
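The two hunks above (train and validate) drop the async=True argument. async became a reserved keyword in Python 3.7, so tensor.cuda(async=True) is a syntax error on modern interpreters; current PyTorch spells the same non-blocking host-to-device copy as non_blocking=True. A small sketch, assuming a CUDA device is available (the tensor itself is illustrative):

    import torch

    if torch.cuda.is_available():
        target = torch.zeros(8, 10).pin_memory()        # pinned host memory enables async copies
        # Pre-Python-3.7 spelling, now invalid syntax: target = target.cuda(async=True)
        target = target.cuda(non_blocking=True)         # current PyTorch equivalent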
@@ -440,7 +451,7 @@ def on_forward(self, training, model, criterion, data_loader, optimizer=None, di
             # compute output
             self.state['output'] = model(feature_var, inp_var)
             self.state['loss'] = criterion(self.state['output'], target_var)
-
+
             optimizer.zero_grad()
             self.state['loss'].backward()
             nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
@@ -459,7 +470,7 @@ def on_forward(self, training, model, criterion, data_loader, optimizer=None, di
             self.state['loss'] = criterion(self.state['output'], target_var)
             with open('out.csv', 'a') as f:
                 f.write(to_csv(self.state['output'].detach().cpu().numpy()[0]) + '\n')
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
 
 
     def on_start_batch(self, training, model, criterion, data_loader, optimizer=None, display=True):
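For context on the torch.cuda.empty_cache() call kept above: it only returns blocks held by PyTorch's caching allocator to the CUDA driver; tensors that are still referenced are not freed, so calling it mainly helps when other processes need the GPU memory. A short illustrative sketch (the tensor size is arbitrary):

    import torch

    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device='cuda')
        del x                          # the tensor is freed, but the allocator keeps its block cached
        torch.cuda.empty_cache()       # hand the cached, unused block back to the CUDA driver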
@@ -472,4 +483,4 @@ def on_start_batch(self, training, model, criterion, data_loader, optimizer=None
         self.state['feature'] = input[0]
         self.state['out'] = input[1]
         self.state['input'] = input[2]
-
+