 import torchvision.transforms as transforms
 import torch.nn as nn
 from util import *
- import neptune
+ # import neptune

 tqdm.monitor_interval = 0
 class Engine(object):
@@ -173,28 +173,39 @@ def learning(self, model, criterion, train_dataset, val_dataset, optimizer=None)
                                                  batch_size=self.state['batch_size_test'], shuffle=False,
                                                  num_workers=self.state['workers'])

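+        # Set up GPU execution (pinned memory, DataParallel wrapping) before the
+        # checkpoint is loaded, so restored 'module.'-prefixed keys match the model.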
+        if self.state['use_gpu']:
+            train_loader.pin_memory = True
+            val_loader.pin_memory = True
+            cudnn.benchmark = False
+            model = torch.nn.DataParallel(model, device_ids=self.state['device_ids']).cuda()
+            # model = model.cuda()
+
+            criterion = criterion.cuda()
+
         # optionally resume from a checkpoint
         if self._state('resume') is not None:
             if os.path.isfile(self.state['resume']):
                 print("=> loading checkpoint '{}'".format(self.state['resume']))
                 checkpoint = torch.load(self.state['resume'])
                 self.state['start_epoch'] = checkpoint['epoch']
                 self.state['best_score'] = checkpoint['best_score']
-                model.load_state_dict(checkpoint['state_dict'])
+                state_dict = checkpoint['state_dict']
+                from collections import OrderedDict
+                new_state_dict = OrderedDict()
+
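+                # Rename checkpoint keys: add a 'module.' prefix to weights saved without
+                # DataParallel, and rewrite 'features.module.*' keys to the wrapped
+                # model's 'module.features.*' layout.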
+                for k, v in state_dict.items():
+                    if 'module' not in k:
+                        k = 'module.' + k
+                    else:
+                        k = k.replace('features.module.', 'module.features.')
+                    new_state_dict[k] = v
+
+                model.load_state_dict(new_state_dict)
                 print("=> loaded checkpoint '{}' (epoch {})"
                       .format(self.state['evaluate'], checkpoint['epoch']))
             else:
                 print("=> no checkpoint found at '{}'".format(self.state['resume']))

-
-        if self.state['use_gpu']:
-            train_loader.pin_memory = True
-            val_loader.pin_memory = True
-            cudnn.benchmark = False
-            model = torch.nn.DataParallel(model, device_ids=self.state['device_ids']).cuda()
-
-            criterion = criterion.cuda()
-
         if self.state['evaluate']:
             self.validate(val_loader, model, criterion)
             return
@@ -254,7 +265,7 @@ def train(self, data_loader, model, criterion, optimizer, epoch):
             self.on_start_batch(True, model, criterion, data_loader, optimizer)

             if self.state['use_gpu']:
-                self.state['target'] = self.state['target'].cuda(async=True)
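+                # 'async' became a reserved keyword in Python 3.7, so the argument is
+                # dropped; .cuda(non_blocking=True) would keep the asynchronous copy.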
+                self.state['target'] = self.state['target'].cuda()

             self.on_forward(True, model, criterion, data_loader, optimizer)
@@ -290,7 +301,7 @@ def validate(self, data_loader, model, criterion):
             self.on_start_batch(False, model, criterion, data_loader)

             if self.state['use_gpu']:
-                self.state['target'] = self.state['target'].cuda(async=True)
+                self.state['target'] = self.state['target'].cuda()

             self.on_forward(False, model, criterion, data_loader)
@@ -440,7 +451,7 @@ def on_forward(self, training, model, criterion, data_loader, optimizer=None, di
             # compute output
             self.state['output'] = model(feature_var, inp_var)
             self.state['loss'] = criterion(self.state['output'], target_var)
-
+
             optimizer.zero_grad()
             self.state['loss'].backward()
             nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
@@ -459,7 +470,7 @@ def on_forward(self, training, model, criterion, data_loader, optimizer=None, di
             self.state['loss'] = criterion(self.state['output'], target_var)
             with open('out.csv', 'a') as f:
                 f.write(to_csv(self.state['output'].detach().cpu().numpy()[0]) + '\n')
-            torch.cuda.empty_cache()
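+            # release cached GPU memory after writing each evaluation batch's output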
+            torch.cuda.empty_cache()


    def on_start_batch(self, training, model, criterion, data_loader, optimizer=None, display=True):
@@ -472,4 +483,4 @@ def on_start_batch(self, training, model, criterion, data_loader, optimizer=None
         self.state['feature'] = input[0]
         self.state['out'] = input[1]
         self.state['input'] = input[2]
-
+