Skip to content

Commit c38b279

Browse files
0.9.16
1 parent 0ebbf2a commit c38b279

File tree

5 files changed

+45
-58
lines changed

5 files changed

+45
-58
lines changed

torchstudio/datasets/genericloader.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,16 @@ def to_tensors(self, path:str):
130130
tensors = []
131131
if path.endswith('.jpg') or path.endswith('.jpeg') or path.endswith('.png') or path.endswith('.webp') or path.endswith('.tif') or path.endswith('.tiff'):
132132
img=Image.open(path)
133-
for i in range(img.n_frames):
133+
frames = 1
134+
if hasattr( img, 'n_frames'):
135+
frames = img.n_frames
136+
for i in range(frames):
134137
if img.mode=='1' or img.mode=='L' or img.mode=='P':
135138
tensors.append(torch.from_numpy(np.array(img, dtype=np.uint8)))
136139
else:
137140
trans=torchvision.transforms.ToTensor()
138141
tensors.append(trans(img))
139-
if i<(img.n_frames-1):
142+
if i<(frames-1):
140143
img.seek(img.tell()+1)
141144

142145
if path.endswith('.mp3') or path.endswith('.wav') or path.endswith('.ogg') or path.endswith('.flac'):

torchstudio/graphdraw.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,8 @@
561561
filtered_nodes[id]['inputs'].remove(input)
562562
filtered_nodes[id]['inputs'].append(sub_input)
563563
filtered_nodes[id]['input_shape'][sub_input]=nodes[input]['output_shape']
564-
del filtered_nodes[input]
564+
if input in filtered_nodes:
565+
del filtered_nodes[input]
565566
#del non-referenced getitems
566567
nodes=copy.deepcopy(filtered_nodes)
567568
for id, node in nodes.items():
@@ -599,8 +600,6 @@
599600
output_shape=filtered_nodes[input]['output_shape']
600601
if input in node['input_shape']:
601602
output_shape=node['input_shape'][input]
602-
if batch==1:
603-
output_shape=('N,' if output_shape else 'N')+output_shape
604603
graph.edge(input,id," "+output_shape.replace(',','\u00d7')) #replace comma by multiplication sign
605604

606605
if legend==1:

torchstudio/modelbuild.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def level_trace(root):
163163
output_shape=output_shape[:-1]
164164
else:
165165
output_dtype=str(node.meta['tensor_meta'].dtype)
166-
output_shape=','.join([str(i) for i in list(node.meta['tensor_meta'].shape)[1:]])
166+
output_shape=','.join([str(i) for i in list(node.meta['tensor_meta'].shape)])
167167

168168
if node.op == 'placeholder':
169169
node_type='input'

torchstudio/pythoncheck.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,12 @@
5858
else:
5959
#install ssh support if necessary
6060
if not importlib.util.find_spec("paramiko"):
61-
if importlib.util.find_spec("conda"):
62-
print("Installing Paramiko (SSH for Python) using Conda...", file=sys.stderr)
63-
import conda.cli.python_api as Conda
64-
Conda.run_command(Conda.Commands.INSTALL,['paramiko','-c','conda-forge', "--quiet", "--quiet"]) #2 times quiet to remove all verbose
65-
elif importlib.util.find_spec("pip"):
61+
if importlib.util.find_spec("pip"):
6662
print("Installing Paramiko (SSH for Python) using Pip...", file=sys.stderr)
6763
import subprocess
6864
subprocess.run([sys.executable, "-m", "pip", "install", "paramiko", "--quiet", "--quiet"]) #2 times quiet to remove all verbose
6965
else:
70-
print("Error: Conda or Pip is required to install Paramiko (SSH for Python).", file=sys.stderr)
66+
print("Error: Pip is required to install Paramiko (SSH for Python).", file=sys.stderr)
7167
exit(1)
7268

7369
#finally, list available devices

torchstudio/pythoninstall.py

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,60 @@
1-
#otherwise conda install may fail
2-
del __file__
3-
__package__=None
4-
__spec__=None
5-
61
import sys
72
import importlib
83
import importlib.util
94
import argparse
5+
import subprocess
106
parser = argparse.ArgumentParser()
11-
parser.add_argument("--channel", help="pytorch channel", type=str, default='pytorch')
127
parser.add_argument("--cuda", help="install nvidia gpu support", action="store_true", default=False)
138
parser.add_argument("--package", help="install specific package", action='append', nargs='+', default=[])
149
args, unknown = parser.parse_known_args()
1510

16-
if importlib.util.find_spec("conda") is None:
17-
print("Error: A Conda environment is required to install the required packages.", file=sys.stderr)
11+
if importlib.util.find_spec("pip") is None:
12+
print("Error: Pip is required to install the required packages.", file=sys.stderr)
1813
exit()
1914

20-
import conda.cli.python_api as Conda
21-
22-
#increase rows (from default 20 when no terminal is found) to display all parallel packages downloads at once
23-
from tqdm import tqdm
24-
init_source=tqdm.__init__
25-
def init_patch(self, **kwargs):
26-
kwargs['ncols']=80
27-
kwargs['nrows']=80
28-
init_source(self, **kwargs)
29-
tqdm.__init__=init_patch
30-
3115
if not args.package:
32-
#https://edcarp.github.io/introduction-to-conda-for-data-scientists/03-using-packages-and-channels/index.html#alternative-syntax-for-installing-packages-from-specific-channels
33-
conda_install=f"pytorch torchvision torchaudio torchtext"
16+
#first install python-graphviz as it only exist as a conda package, and conda is recommended before pip: https://www.anaconda.com/blog/using-pip-in-a-conda-environment
17+
if importlib.util.find_spec("conda") is None:
18+
print("Error: Conda is required to install the graphviz package.", file=sys.stderr)
19+
exit()
20+
else:
21+
print("Downloading and installing the graphviz package...")
22+
print("")
23+
result = subprocess.run([sys.executable, "-m", "conda", "install", "-y", "python-graphviz", "-c", "conda-forge"])
24+
if result.returncode != 0:
25+
exit(result.returncode)
26+
print("")
27+
28+
pip_install="torch torchvision torchaudio torchtext"
3429
if (sys.platform.startswith('win') or sys.platform.startswith('linux')):
3530
if args.cuda:
3631
print("Checking the latest supported CUDA version...")
37-
highest_cuda_version=(11,6) #highest supported cuda version for PyTorch 1.12
32+
highest_cuda_version=118 #11.8 highest supported cuda version for PyTorch 2.0
3833
import requests
3934
try:
40-
pytorch_repo = requests.get("https://anaconda.org/"+args.channel+"/pytorch/files")
35+
pytorch_repo = requests.get("https://download.pytorch.org/whl/torch")
4136
except:
4237
print("Could not retrieve the latest supported CUDA version")
4338
else:
4439
import re
45-
regex_request=re.compile("cuda([0-9]+.[0-9]+)")
40+
regex_request=re.compile("cu([0-9]+)")
4641
results = re.findall(regex_request, pytorch_repo.text)
47-
highest_cuda_version=(11,6)
42+
highest_cuda_version=118
4843
for cuda_string in results:
49-
cuda_version=tuple(int(i) for i in cuda_string.split('.'))
44+
cuda_version=int(cuda_string)
5045
if cuda_version > highest_cuda_version:
5146
highest_cuda_version = cuda_version
52-
highest_cuda_string='.'.join([str(value) for value in highest_cuda_version])
47+
highest_cuda_string=str(highest_cuda_version)[:2]+"."+str(highest_cuda_version)[2:]
5348
print("Using CUDA "+highest_cuda_string)
5449
print("")
55-
conda_install+=" pytorch-cuda="+highest_cuda_string+" -c "+args.channel+" -c nvidia"
56-
else:
57-
conda_install+=" cpuonly -c "+args.channel
58-
else:
59-
conda_install+=" -c "+args.channel
60-
print(f"Downloading and installing {args.channel} packages...")
50+
pip_install+=" --index-url https://download.pytorch.org/whl/cu"+str(highest_cuda_version)
51+
52+
print("Downloading and installing pytorch packages...")
6153
print("")
62-
# https://stackoverflow.com/questions/41767340/using-conda-install-within-a-python-script
63-
(stdout_str, stderr_str, return_code_int) = Conda.run_command(Conda.Commands.INSTALL,conda_install.split(),use_exception_handler=True,stdout=sys.stdout,stderr=sys.stderr)
64-
if return_code_int!=0:
65-
exit(return_code_int)
54+
55+
result = subprocess.run([sys.executable, "-m", "pip", "install"]+pip_install.split())
56+
if result.returncode != 0:
57+
exit(result.returncode)
6658
print("")
6759

6860
# onnx required for onnx export
@@ -73,16 +65,13 @@ def init_patch(self, **kwargs):
7365
# python-graphviz required by torchstudio graph
7466
# paramiko required for ssh connections (+updated cffi required on intel mac)
7567
# pysoundfile required by torchaudio datasets: https://pytorch.org/audio/stable/backend.html#soundfile-backend
76-
conda_install="onnx datasets scipy pandas matplotlib-base python-graphviz paramiko pysoundfile"
77-
if sys.platform.startswith('darwin'):
78-
conda_install+=" cffi"
68+
pip_install="onnx datasets scipy pandas matplotlib paramiko pysoundfile"
7969

8070
else:
81-
conda_install=" ".join(args.package[0])
71+
pip_install=" ".join(args.package[0])
8272

83-
print("Downloading and installing conda-forge packages...")
73+
print("Downloading and installing additional packages...")
8474
print("")
85-
conda_install+=" -c conda-forge"
86-
(stdout_str, stderr_str, return_code_int) = Conda.run_command(Conda.Commands.INSTALL,conda_install.split(),use_exception_handler=True,stdout=sys.stdout,stderr=sys.stderr)
87-
if return_code_int!=0:
88-
exit(return_code_int)
75+
result = subprocess.run([sys.executable, "-m", "pip", "install"]+pip_install.split())
76+
if result.returncode != 0:
77+
exit(result.returncode)

0 commit comments

Comments
 (0)