@@ -247,8 +247,8 @@ def main():
247
247
help = 'learning rate (default: 1.0)' )
248
248
parser .add_argument ('--gamma' , type = float , default = 0.7 , metavar = 'M' ,
249
249
help = 'Learning rate step gamma (default: 0.7)' )
250
- parser .add_argument ('--accel' , action = 'store_true' ,
251
- help = 'use accelerator' )
250
+ parser .add_argument ('--no- accel' , action = 'store_true' ,
251
+ help = 'disables accelerator' )
252
252
parser .add_argument ('--dry-run' , action = 'store_true' , default = False ,
253
253
help = 'quickly check a single pass' )
254
254
parser .add_argument ('--seed' , type = int , default = 1 , metavar = 'S' ,
@@ -258,16 +258,13 @@ def main():
258
258
parser .add_argument ('--save-model' , action = 'store_true' , default = False ,
259
259
help = 'For Saving the current Model' )
260
260
args = parser .parse_args ()
261
+
262
+ use_accel = not args .no_accel and torch .accelerator .is_available ()
261
263
262
264
torch .manual_seed (args .seed )
263
265
264
- if args .accel and not torch .accelerator .is_available ():
265
- print ("ERROR: accelerator is not available, try running on CPU" )
266
- sys .exit (1 )
267
- if not args .accel and torch .accelerator .is_available ():
268
- print ("WARNING: accelerator is available, run with --accel to enable it" )
269
266
270
- if args . accel :
267
+ if use_accel :
271
268
device = torch .accelerator .current_accelerator ()
272
269
else :
273
270
device = torch .device ("cpu" )
@@ -276,12 +273,12 @@ def main():
276
273
277
274
train_kwargs = {'batch_size' : args .batch_size }
278
275
test_kwargs = {'batch_size' : args .test_batch_size }
279
- if device == "cuda" :
280
- cuda_kwargs = {'num_workers' : 1 ,
276
+ if use_accel :
277
+ accel_kwargs = {'num_workers' : 1 ,
281
278
'pin_memory' : True ,
282
279
'shuffle' : True }
283
- train_kwargs .update (cuda_kwargs )
284
- test_kwargs .update (cuda_kwargs )
280
+ train_kwargs .update (accel_kwargs )
281
+ test_kwargs .update (accel_kwargs )
285
282
286
283
train_dataset = APP_MATCHER ('../data' , train = True , download = True )
287
284
test_dataset = APP_MATCHER ('../data' , train = False )
0 commit comments