Skip to content

Commit bf36e45

Browse files
torchstudio 0.9.8
see 0.9.8 release changelog for more details
1 parent cf6bb6c commit bf36e45

12 files changed

+132
-70
lines changed

torchstudio/LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 Robin Lobel
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

torchstudio/metrics/accuracy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def update(self, preds, target):
2727
raise ValueError("prediction and target have different shapes or aren't compatible with multiclass prediction")
2828

2929
self.num_correct += torch.sum(correct)
30-
self.num_samples += correct.shape[0]
30+
self.num_samples += torch.tensor(correct.shape[0])
3131

3232
def compute(self):
3333
if self.num_samples == 0:
3434
raise ValueError("Accuracy must have at least one sample before it can be computed.")
35-
return self.num_correct / self.num_samples
35+
return self.num_correct.float() / self.num_samples.float()
3636

3737
def reset(self):
3838
self.num_correct = 0

torchstudio/metrics/fscore.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def update(self, preds, target):
2525
if self.normalize:
2626
preds=F.softmax(preds, dim=1)
2727
tp = torch.sum(torch.eq(torch.argmax(preds, dim=1), target).view(-1))
28-
tpfp = tp.shape[0]
29-
tpfn = tp.shape[0]
28+
tpfp = torch.tensor(tp.shape[0])
29+
tpfn = torch.tensor(tp.shape[0])
3030
elif preds.shape==target.shape:
3131
if self.normalize:
3232
preds=F.sigmoid(preds)
@@ -40,8 +40,8 @@ def update(self, preds, target):
4040
self.tpfp += tpfp
4141
self.tp += tp
4242
def compute(self):
43-
precision = self.tp / self.tpfp
44-
recall = self.tp / self.tpfn
43+
precision = self.tp.float() / self.tpfp.float()
44+
recall = self.tp.float() / self.tpfn.float()
4545
fscore = (1.0+self.beta_square)*(precision*recall)/(self.beta_square*precision+recall)
4646
return fscore
4747

torchstudio/metrics/precision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def update(self, preds, target):
2222
if self.normalize:
2323
preds=F.softmax(preds, dim=1)
2424
tp = torch.sum(torch.eq(torch.argmax(preds, dim=1), target).view(-1))
25-
tpfp = tp.shape[0]
25+
tpfp = torch.tensor(tp.shape[0])
2626
elif preds.shape==target.shape:
2727
if self.normalize:
2828
preds=F.sigmoid(preds)
@@ -34,7 +34,7 @@ def update(self, preds, target):
3434
self.tpfp += tpfp
3535
self.tp += tp
3636
def compute(self):
37-
return self.tp / self.tpfp
37+
return self.tp.float() / self.tpfp.float()
3838

3939
def reset(self):
4040
self.tpfp = 0

torchstudio/metrics/recall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def update(self, preds, target):
2222
if self.normalize:
2323
preds=F.softmax(preds, dim=1)
2424
tp = torch.sum(torch.eq(torch.argmax(preds, dim=1), target).view(-1))
25-
tpfn = tp.shape[0]
25+
tpfn = torch.tensor(tp.shape[0])
2626
elif preds.shape==target.shape:
2727
if self.normalize:
2828
preds=F.sigmoid(preds)
@@ -34,7 +34,7 @@ def update(self, preds, target):
3434
self.tpfn += tpfn
3535
self.tp += tp
3636
def compute(self):
37-
return self.tp / self.tpfn
37+
return self.tp.float() / self.tpfn.float()
3838

3939
def reset(self):
4040
self.tpfn = 0

torchstudio/modelbuild.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def level_trace(root):
234234
for tensor in output_tensors:
235235
metric.append("Accuracy")
236236

237-
tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([128,0,100,0]))
237+
tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([128,0,100,1,1,1]))
238238
tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step']))
239239

240240
if msg_type == 'Exit':

torchstudio/modeltrain.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import io
1212
import tempfile
1313
from tqdm.auto import tqdm
14+
from collections.abc import Iterable
1415

1516

1617
class CachedDataset(Dataset):
@@ -44,6 +45,16 @@ def __getitem__(self, id):
4445
sample=self.index[id]
4546
return sample
4647

48+
def deepcopy_cpu(value):
49+
if isinstance(value, torch.Tensor):
50+
value = value.to("cpu")
51+
return value
52+
elif isinstance(value, dict):
53+
return {k: deepcopy_cpu(v) for k, v in value.items()}
54+
elif isinstance(value, Iterable):
55+
return type(value)(deepcopy_cpu(v) for v in value)
56+
else:
57+
return value
4758

4859
modules_valid=True
4960

@@ -65,19 +76,23 @@ def __getitem__(self, id):
6576
pin_memory = True if 'cuda' in device_id else False
6677

6778
if msg_type == 'SetTorchScriptModel' and modules_valid:
68-
print("Setting torchscript model...\n", file=sys.stderr)
69-
buffer=io.BytesIO(msg_data)
70-
model = torch.jit.load(buffer, map_location=device)
79+
if msg_data:
80+
print("Setting torchscript model...\n", file=sys.stderr)
81+
buffer=io.BytesIO(msg_data)
82+
model = torch.jit.load(buffer)
7183

7284
if msg_type == 'SetPackageModel' and modules_valid:
73-
print("Setting package model...\n", file=sys.stderr)
74-
buffer=io.BytesIO(msg_data)
75-
model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl', map_location=device)
85+
if msg_data:
86+
print("Setting package model...\n", file=sys.stderr)
87+
buffer=io.BytesIO(msg_data)
88+
model = torch.package.PackageImporter(buffer).load_pickle('model', 'model.pkl')
7689

7790
if msg_type == 'SetModelState' and modules_valid:
7891
if model is not None:
79-
buffer=io.BytesIO(msg_data)
80-
model.load_state_dict(torch.load(buffer,map_location=device))
92+
if msg_data:
93+
buffer=io.BytesIO(msg_data)
94+
model.load_state_dict(torch.load(buffer))
95+
model.to(device)
8196

8297
if msg_type == 'SetLossCodes' and modules_valid:
8398
print("Setting loss code...\n", file=sys.stderr)
@@ -116,9 +131,11 @@ def __getitem__(self, id):
116131
tc.send_msg(app_socket, 'TrainingError')
117132
else:
118133
optimizer = optimizer_env['optimizer']
134+
119135
if msg_type == 'SetOptimizerState' and modules_valid:
120-
buffer=io.BytesIO(msg_data)
121-
optimizer.load_state_dict(torch.load(buffer,map_location=device))
136+
if msg_data:
137+
buffer=io.BytesIO(msg_data)
138+
optimizer.load_state_dict(torch.load(buffer))
122139

123140
if msg_type == 'SetSchedulerCode' and modules_valid:
124141
print("Setting scheduler code...\n", file=sys.stderr)
@@ -131,9 +148,11 @@ def __getitem__(self, id):
131148
scheduler = scheduler_env['scheduler']
132149

133150
if msg_type == 'SetHyperParametersValues' and modules_valid: #set other hyperparameters values
134-
batch_size, shuffle, epochs, early_stop = tc.decode_ints(msg_data)
135-
early_stop=True if early_stop==1 else False
151+
batch_size, shuffle, epochs, early_stop, monitor_metric, restore_best = tc.decode_ints(msg_data)
136152
shuffle=True if shuffle==1 else False
153+
early_stop=True if early_stop==1 else False
154+
monitor_metric=True if monitor_metric==1 else False
155+
restore_best=True if restore_best==1 else False
137156

138157
if msg_type == 'StartTrainingServer' and modules_valid:
139158
print("Caching...\n", file=sys.stderr)
@@ -267,11 +286,11 @@ def __getitem__(self, id):
267286
tc.send_msg(app_socket, 'ValidationMetric', tc.encode_floats(valid_metrics))
268287

269288
buffer=io.BytesIO()
270-
torch.save(model.state_dict(), buffer)
289+
torch.save(deepcopy_cpu(model.state_dict()), buffer)
271290
tc.send_msg(app_socket, 'ModelState', buffer.getvalue())
272291

273292
buffer=io.BytesIO()
274-
torch.save(optimizer.state_dict(), buffer)
293+
torch.save(deepcopy_cpu(optimizer.state_dict()), buffer)
275294
tc.send_msg(app_socket, 'OptimizerState', buffer.getvalue())
276295

277296
tc.send_msg(app_socket, 'Trained')
@@ -280,7 +299,7 @@ def __getitem__(self, id):
280299
if train_bar is not None:
281300
train_bar.bar_format='{desc} epoch {n_fmt} | {remaining} left |{rate_fmt}\n\n'
282301
else:
283-
train_bar = tqdm(total=epochs, desc='Training...', bar_format='{desc} epoch '+str(scheduler.last_epoch)+'\n\n')
302+
train_bar = tqdm(total=epochs, desc='Training...', bar_format='{desc} epoch '+str(scheduler.last_epoch)+'\n\n', initial=scheduler.last_epoch-1)
284303
train_bar.update(1)
285304

286305
if msg_type == 'StopTraining' and modules_valid:

torchstudio/pythoncheck.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
print("Checking Python version...\n", file=sys.stderr)
44

5+
import platform
56
import argparse
67
import importlib
78

@@ -23,8 +24,8 @@
2324
checked_modules = ["torch", "torchvision"]
2425
required_packages = ["pytorch", "torchvision"]
2526
if not args.remote:
26-
checked_modules += ["torchaudio", "matplotlib", "graphviz"]
27-
required_packages += ["torchaudio", "matplotlib-base", "python-graphviz"]
27+
checked_modules += ["torchaudio", "torchtext", "matplotlib", "graphviz"]
28+
required_packages += ["torchaudio", "torchtext", "matplotlib-base", "python-graphviz"]
2829
missing_modules = []
2930
for module_check in checked_modules:
3031
module = importlib.util.find_spec(module_check)
@@ -46,6 +47,7 @@
4647
#warn about missing modules
4748
print("Error: Missing Python modules:", file=sys.stderr)
4849
print(*missing_modules, sep = " ", file=sys.stderr)
50+
print("", file=sys.stderr)
4951
print("The following packages are required:", file=sys.stderr)
5052
print(' '.join(required_packages), file=sys.stderr)
5153
exit(1)
@@ -75,12 +77,15 @@
7577
devices['cpu'] = {'name': 'CPU', 'pin_memory': False}
7678
for i in range(torch.cuda.device_count()):
7779
devices['cuda:'+str(i)] = {'name': torch.cuda.get_device_name(i), 'pin_memory': True}
80+
if pytorch_version>=(1,12):
81+
if torch.backends.mps.is_available():
82+
devices['mps'] = {'name': 'Metal Acceleration', 'pin_memory': False}
7883
#other possible devices:
7984
#'hpu' (https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html)
8085
#'dml' (https://docs.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows)
8186
devices_string_list=[]
8287
for id in devices:
8388
devices_string_list.append(devices[id]['name']+" ("+id+")")
84-
print(("Online and functional " if args.remote else "Functional environment ")+"(Python "+str(python_version[0])+"."+str(python_version[1])+", PyTorch "+str(pytorch_version[0])+"."+str(pytorch_version[1])+", Devices: "+", ".join(devices_string_list)+")");
89+
print(("Online and functional " if args.remote else "Functional ")+"("+platform.platform()+", Python "+str(python_version[0])+"."+str(python_version[1])+", PyTorch "+str(pytorch_version[0])+"."+str(pytorch_version[1])+", Devices: "+", ".join(devices_string_list)+")");
8590

8691

torchstudio/pythoninstall.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import importlib.util
44
import argparse
55
parser = argparse.ArgumentParser()
6+
parser.add_argument("--base", help="install base packages", action="store_true", default=False)
67
parser.add_argument("--gpu", help="install nvidia gpu support", action="store_true", default=False)
8+
parser.add_argument("--package", help="install specific package", action='append', nargs='+', default=[])
79
args, unknown = parser.parse_known_args()
810

911
if importlib.util.find_spec("conda") is None:
@@ -12,37 +14,32 @@
1214

1315
import conda.cli.python_api as Conda
1416

15-
# datasets(+huggingface_hub) required by hugging face hub
16-
# scipy required by torchvision: Caltech ImageNet SBD SVHN datasets and Inception v3 GoogLeNet models
17-
# pandas required by the dataset tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
18-
# matplotlib-base required by torchstudio renderers
19-
# python-graphviz required by torchstudio graph
20-
# paramiko required for ssh connections
21-
# pysoundfile required on windows by torchaudio: https://pytorch.org/audio/stable/backend.html#soundfile-backend
22-
if sys.platform.startswith('win'):
23-
if args.gpu:
24-
conda_install="pytorch torchvision torchaudio cudatoolkit=11.3 datasets scipy pandas matplotlib-base python-graphviz paramiko pysoundfile"
25-
else:
26-
conda_install="pytorch torchvision torchaudio cpuonly datasets scipy pandas matplotlib-base python-graphviz paramiko pysoundfile"
27-
elif sys.platform.startswith('darwin'):
28-
# force a pytorch/mkl version, because pytorch 1.10.2+ depends on mkl 2022 which is incompatible with Rosetta 2 in M1 macs, and update cffi 1.15.0-py39hc55c11b_1 to 1.15.0-py39he338e87_0+ to avoid paramiko error
29-
conda_install="pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 mkl==2021.4.0 datasets scipy pandas matplotlib-base python-graphviz paramiko cffi"
30-
elif sys.platform.startswith('linux'):
31-
if args.gpu:
32-
conda_install="pytorch torchvision torchaudio cudatoolkit=11.3 datasets scipy pandas matplotlib-base python-graphviz paramiko"
33-
else:
34-
conda_install="pytorch torchvision torchaudio cpuonly datasets scipy pandas matplotlib-base python-graphviz paramiko"
35-
else:
36-
print("Error: Unsupported platform.", file=sys.stderr)
37-
print("Windows, macOS or Linux is required.", file=sys.stderr)
38-
exit()
17+
conda_install=""
18+
if args.base:
19+
# scipy required by torchvision: Caltech ImageNet SBD SVHN datasets and Inception v3 GoogLeNet models
20+
# pandas required by the dataset tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
21+
# matplotlib-base required by torchstudio renderers
22+
# python-graphviz required by torchstudio graph
23+
# paramiko required for ssh connections
24+
# pysoundfile required by torchaudio datasets: https://pytorch.org/audio/stable/backend.html#soundfile-backend
25+
# datasets(+huggingface_hub) is required by hugging face hub
26+
conda_install="pytorch torchvision torchaudio torchtext scipy pandas matplotlib-base python-graphviz paramiko pysoundfile datasets"
27+
if (sys.platform.startswith('win') or sys.platform.startswith('linux')) and not args.gpu:
28+
conda_install+=" cpuonly"
29+
if sys.platform.startswith('darwin'):
30+
conda_install+=" cffi"
31+
32+
if args.package:
33+
if args.base:
34+
conda_install+=" "
35+
conda_install+=" ".join(args.package[0])
3936

40-
print("Downloading and installing PyTorch and additional packages:")
37+
print("Downloading and installing conda packages:")
4138
print(conda_install)
4239
print("")
4340

44-
# channels: pytorch for pytorch torchvision torchaudio, nvidia for cudatoolkit=11.1 on Linux, huggingface for datasets(+huggingface_hub), conda-forge for everything else except anaconda for python-graphviz
45-
conda_install+=" -c pytorch -c nvidia -c huggingface -c conda-forge -c anaconda"
41+
# channels: pytorch for pytorch torchvision torchaudio, conda-forge for everything else
42+
conda_install+=" -c pytorch -c conda-forge"
4643

4744
# https://stackoverflow.com/questions/41767340/using-conda-install-within-a-python-script
4845
(stdout_str, stderr_str, return_code_int) = Conda.run_command(Conda.Commands.INSTALL,conda_install.split(),stdout=sys.stdout,stderr=sys.stderr)

torchstudio/pythonparse.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,17 @@ def gather_parameters(node):
2929
elif inspect.isclass(param.default) or inspect.isfunction(param.default):
3030
params.append(param.default.__module__+'.'+param.default.__name__)
3131
else:
32-
params.append(repr(param.default))
32+
value=repr(param.default)
33+
if "<class '" in value:
34+
value=value.replace("<class '","")
35+
value=value.replace("'>","")
36+
params.append(value)
3337
return params
3438

3539
def gather_objects(module):
3640
objects=[]
3741
for name, obj in inspect.getmembers(module):
38-
if (inspect.isclass(obj) or inspect.isfunction(obj)) and not hasattr(obj, '_fields') and obj.__module__.find('.utils')==-1: #filter unwanted torch objects
42+
if ((inspect.isclass(obj) and hasattr(obj, '__mro__') and ("torch.nn.modules.module.Module" in str(obj.__mro__) or "torch.utils.data.dataset.Dataset" in str(obj.__mro__))) or inspect.isfunction(obj)): #filter unwanted torch objects
3943
object={}
4044
object['type']='class' if inspect.isclass(obj) else 'function'
4145
object['name']=name
@@ -357,13 +361,17 @@ def scan_folder(path):
357361
for i in range(len(objects_batch)):
358362
objects_batch[i]['code']=code #set whole source code for each object, as we don't know the dependencies
359363
objects.extend(objects_batch)
364+
else:
365+
print("Error parsing code:", error_msg, "\n", file=sys.stderr)
360366
else:
361367
#parse module
362368
error_msg, module = safe_exec(importlib.import_module,(path,))
363369
if error_msg is None and module is not None:
364370
objects=gather_objects(module)
365371
for i, object in enumerate(objects):
366372
objects[i]['code']=generate_code(path,object) #generate inherited source code
373+
else:
374+
print("Error parsing module:", error_msg, "\n", file=sys.stderr)
367375

368376
tc.send_msg(app_socket, 'ObjectsBegin', tc.encode_strings(path))
369377
for object in objects:

0 commit comments

Comments
 (0)