Skip to content

Commit ff78b5e

Browse files
0.9.12
1 parent 6815d55 commit ff78b5e

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

torchstudio/datasets/randomgenerator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class RandomGenerator(Dataset):
1010
Size of the dataset (number of samples)
1111
tensors:
1212
A list of tuples defining tensor properties: shape, type, range
13-
All properties are optionals. Defaults are null, float, [0,1]
13+
All properties are optionals. Defaults are null, torch.float, [0,1]
1414
"""
1515

1616
def __init__(self, size:int=256, tensors=[(3,64,64), (int,[0,9])]):
@@ -29,12 +29,12 @@ def __getitem__(self, idx):
2929
sample = []
3030
for properties in self.tensors:
3131
shape=[]
32-
dtype=float
32+
dtype=torch.float
3333
drange=[0,1]
3434
for property in properties:
3535
if type(property)==int:
3636
shape.append(property)
37-
elif inspect.isclass(property):
37+
elif type(property)==type or type(property)==torch.dtype:
3838
dtype=property
3939
elif type(property) is list:
4040
drange=property

torchstudio/modeltrain.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ def deepcopy_cpu(value):
237237
scaler = torch.cuda.amp.GradScaler()
238238
if mode=='BF16':
239239
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" #https://discuss.pytorch.org/t/bfloat16-has-worse-performance-than-float16-for-conv2d/154373
240+
train_type=None
241+
if mode=='FP16':
242+
train_type=torch.float16
243+
if mode=='BF16':
244+
train_type=torch.bfloat16
240245
print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr)
241246

242247
if msg_type == 'TrainOneEpoch' and modules_valid:
@@ -252,7 +257,7 @@ def deepcopy_cpu(value):
252257
targets = [tensors[i].to(device) for i in output_tensors_id]
253258
optimizer.zero_grad()
254259

255-
with torch.autocast(device_type='cuda' if 'cuda' in device_id else 'cpu', dtype=torch.bfloat16 if mode=='BF16' else torch.float16, enabled=True if '16' in mode else False):
260+
with torch.autocast(device_type='cuda' if 'cuda' in device_id else 'cpu', dtype=train_type, enabled=True if train_type else False):
256261
outputs = model(*inputs)
257262
outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
258263
loss = 0

0 commit comments

Comments
 (0)