@@ -181,6 +181,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
         debug_data_loader: (Optional) Debugging flag for data loader.
         train_async_data_loader_queue_size: (Optional) Size of train async data loader queue.
         val_async_data_loader_queue_size: (Optional) Size of val async data loader queue.
+        use_gpu: Whether to use the GPU for training. Defaults to True.
     """

     input_shapes = Param(Params._dummy(), 'input_shapes', 'input layer shapes')
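The new `use_gpu` flag surfaces below as a constructor argument. A minimal sketch of how a caller might use it to force CPU-only training, assuming the existing `TorchEstimator` API; the model, store path, and column names are placeholders, not taken from this diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from horovod.spark.common.store import LocalStore
from horovod.spark.torch import TorchEstimator

model = nn.Sequential(nn.Linear(8, 2))   # placeholder model

est = TorchEstimator(
    num_proc=2,                           # number of parallel training processes
    store=LocalStore('/tmp/work_dir'),    # where intermediate data is staged (placeholder path)
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    loss=lambda input, target: F.cross_entropy(input, target),
    input_shapes=[[-1, 8]],
    feature_cols=['features'],            # placeholder column names
    label_cols=['label'],
    batch_size=32,
    epochs=1,
    use_gpu=False)                        # new flag: stay on CPU even if CUDA is available
```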
@@ -189,10 +190,6 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
     train_minibatch_fn = Param(Params._dummy(), 'train_minibatch_fn',
                                'functions that construct the minibatch train function for torch')

-    inmemory_cache_all = Param(Params._dummy(), 'inmemory_cache_all',
-                               'Cache the data in memory for training and validation.',
-                               typeConverter=TypeConverters.toBoolean)
-
     num_gpus = Param(Params._dummy(), 'num_gpus',
                      'Number of gpus per process, default to 1 when CUDA is available in the backend, otherwise 0.')
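For context, these declarations follow the standard `pyspark.ml` pattern: each attribute is declared as a `Param` against `Params._dummy()`, optionally with a type converter, given a default via `_setDefault`, and read back with `getOrDefault`. A self-contained sketch of that pattern (generic pyspark.ml, not Horovod-specific):

```python
from pyspark.ml.param import Param, Params, TypeConverters

class ExampleParams(Params):
    # Declared like the estimator params above; the type converter
    # coerces user input to bool when the param is set.
    use_gpu = Param(Params._dummy(), 'use_gpu',
                    'Whether to use the GPU for training.',
                    typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super(ExampleParams, self).__init__()
        self._setDefault(use_gpu=True)

ex = ExampleParams()
print(ex.getOrDefault(ex.use_gpu))  # True until overridden via ex._set(use_gpu=...)
```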
@@ -266,14 +263,14 @@ def __init__(self,
                  profiler=None,
                  debug_data_loader=False,
                  train_async_data_loader_queue_size=None,
-                 val_async_data_loader_queue_size=None):
+                 val_async_data_loader_queue_size=None,
+                 use_gpu=True):

         super(TorchEstimator, self).__init__()
         self._setDefault(loss_constructors=None,
                          input_shapes=None,
                          train_minibatch_fn=None,
                          transformation_fn=None,
-                         inmemory_cache_all=False,
                          num_gpus=None,
                          logger=None,
                          log_every_n_steps=50,
@@ -315,12 +312,6 @@ def setLossConstructors(self, value):
     def getLossConstructors(self):
         return self.getOrDefault(self.loss_constructors)

-    def setInMemoryCacheAll(self, value):
-        return self._set(inmemory_cache_all=value)
-
-    def getInMemoryCacheAll(self):
-        return self.getOrDefault(self.inmemory_cache_all)
-
     def setNumGPUs(self, value):
         return self._set(num_gpus=value)
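The remaining setters return the estimator via `_set`, so configuration calls can be chained after construction. A small usage sketch, reusing the `est` object from the earlier example and assuming a `getNumGPUs` getter that mirrors `getLossConstructors` (the getter is not shown in this hunk):

```python
est.setNumGPUs(1)        # _set returns the estimator, so setter calls chain
print(est.getNumGPUs())  # assumed getter, following the pattern above
```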