Skip to content

Commit 6815d55

Browse files
0.9.11
1 parent bde996c commit 6815d55

File tree

5 files changed

+92
-42
lines changed

5 files changed

+92
-42
lines changed

torchstudio/datasetload.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def __getitem__(self, id):
231231

232232
sshclient = paramiko.SSHClient()
233233
sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
234-
sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
234+
sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=10)
235235
worker_socket = socket.socket()
236236
worker_socket.bind(('localhost', 0))
237237
freeport=worker_socket.getsockname()[1]
@@ -240,7 +240,7 @@ def __getitem__(self, id):
240240
port=freeport
241241

242242
try:
243-
worker_socket = tc.connect((address,port))
243+
worker_socket = tc.connect((address,port),timeout=10)
244244
num_samples=(len(meta_dataset.train()) if train_set else 0) + (len(meta_dataset.valid()) if valid_set else 0)
245245
tc.send_msg(worker_socket, 'NumSamples', tc.encode_ints(num_samples))
246246
tc.send_msg(worker_socket, 'InputTensorsID', tc.encode_ints(input_tensors_id))
@@ -273,7 +273,7 @@ def __getitem__(self, id):
273273

274274
if sshaddress and sshport and username:
275275
try:
276-
forward_tunnel.stop()
276+
del forward_tunnel
277277
except:
278278
pass
279279
try:

torchstudio/modeltrain.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def deepcopy_cpu(value):
7979
device = torch.device(device_id)
8080
pin_memory = True if 'cuda' in device_id else False
8181

82+
if msg_type == 'SetMode':
83+
print("Setting mode...\n", file=sys.stderr)
84+
mode=tc.decode_strings(msg_data)[0]
85+
8286
if msg_type == 'SetTorchScriptModel' and modules_valid:
8387
if msg_data:
8488
print("Setting torchscript model...\n", file=sys.stderr)
@@ -224,9 +228,19 @@ def deepcopy_cpu(value):
224228
valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size, shuffle=False, pin_memory=pin_memory)
225229

226230
if msg_type == 'StartTraining' and modules_valid:
231+
scaler = None
232+
if 'cuda' in device_id:
233+
#https://pytorch.org/docs/stable/notes/cuda.html
234+
torch.backends.cuda.matmul.allow_tf32 = True if mode=='TF32' else False
235+
torch.backends.cudnn.allow_tf32 = True
236+
if mode=='FP16':
237+
scaler = torch.cuda.amp.GradScaler()
238+
if mode=='BF16':
239+
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" #https://discuss.pytorch.org/t/bfloat16-has-worse-performance-than-float16-for-conv2d/154373
227240
print("Training... epoch "+str(scheduler.last_epoch)+"\n", file=sys.stderr)
228241

229242
if msg_type == 'TrainOneEpoch' and modules_valid:
243+
230244
#training
231245
model.train()
232246
train_loss = 0
@@ -237,13 +251,23 @@ def deepcopy_cpu(value):
237251
inputs = [tensors[i].to(device) for i in input_tensors_id]
238252
targets = [tensors[i].to(device) for i in output_tensors_id]
239253
optimizer.zero_grad()
240-
outputs = model(*inputs)
241-
outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
242-
loss = 0
243-
for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
244-
loss = loss + criterion(output, target)
245-
loss.backward()
246-
optimizer.step()
254+
255+
with torch.autocast(device_type='cuda' if 'cuda' in device_id else 'cpu', dtype=torch.bfloat16 if mode=='BF16' else torch.float16, enabled=True if '16' in mode else False):
256+
outputs = model(*inputs)
257+
outputs = outputs if type(outputs) is not torch.Tensor else [outputs]
258+
loss = 0
259+
for output, target, criterion in zip(outputs, targets, criteria): #https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440
260+
loss = loss + criterion(output, target)
261+
262+
if scaler:
263+
# Accumulates scaled gradients.
264+
scaler.scale(loss).backward()
265+
scaler.step(optimizer)
266+
scaler.update()
267+
else:
268+
loss.backward()
269+
optimizer.step()
270+
247271
train_loss += loss.item() * inputs[0].size(0)
248272

249273
with torch.set_grad_enabled(False):

torchstudio/pythoncheck.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
args, unknown = parser.parse_known_args()
1616

1717
#check python version first
18-
python_version=(sys.version_info.major,sys.version_info.minor)
19-
min_python_version=(3,7) if args.remote else (3,8) #3.7 required for ordered dicts and stdout/stderr utf8 encoding, 3.8 required for python parsing
18+
python_version=(sys.version_info.major,sys.version_info.minor,sys.version_info.micro)
19+
min_python_version=(3,7,0) if args.remote else (3,8,0) #3.7 required for ordered dicts and stdout/stderr utf8 encoding, 3.8 required for python parsing
2020
if python_version<min_python_version:
21-
print("Error: Python "+str(min_python_version[0])+"."+str(min_python_version[1])+" minimum is required.", file=sys.stderr)
22-
print("This environment has Python "+str(python_version[0])+"."+str(python_version[1])+".", file=sys.stderr)
21+
print("Error: Python "+'.'.join((str(i) for i in min_python_version))+" minimum is required.", file=sys.stderr)
22+
print("This environment has Python "+'.'.join((str(i) for i in python_version))+".", file=sys.stderr)
2323
exit(1)
2424

2525
print("Checking required packages...\n", file=sys.stderr)
@@ -36,15 +36,15 @@
3636
if module is None:
3737
missing_modules.append(module_check)
3838
elif module_check=='torch':
39-
if python_version<(3,8):
39+
if python_version<(3,8,0):
4040
from importlib_metadata import version
4141
else:
4242
from importlib.metadata import version
43-
pytorch_version=tuple(int(i) for i in version('torch').split('.')[:2])
44-
min_pytorch_version=(1,9) #1.9 required for torch.package support, 1.10 preferred for stable torch.fx and profile-directed typing in torchscript
43+
pytorch_version=tuple(int(i) if i.isdigit() else 0 for i in version('torch').split('.')[:3])
44+
min_pytorch_version=(1,9,0) #1.9 required for torch.package support, 1.10 preferred for stable torch.fx and profile-directed typing in torchscript
4545
if pytorch_version<min_pytorch_version:
46-
print("Error: PyTorch "+str(min_pytorch_version[0])+"."+str(min_pytorch_version[1])+" minimum is required.", file=sys.stderr)
47-
print("This environment has PyTorch "+str(pytorch_version[0])+"."+str(pytorch_version[1])+".", file=sys.stderr)
46+
print("Error: PyTorch "+'.'.join((str(i) for i in min_pytorch_version))+" minimum is required.", file=sys.stderr)
47+
print("This environment has PyTorch "+'.'.join((str(i) for i in pytorch_version))+".", file=sys.stderr)
4848
exit(1)
4949

5050
if len(missing_modules)>0:
@@ -72,24 +72,32 @@
7272

7373
#finally, list available devices
7474
print("Loading PyTorch...\n", file=sys.stderr)
75-
7675
import torch
7776

7877
print("Listing devices...\n", file=sys.stderr)
79-
8078
devices = {}
81-
devices['cpu'] = {'name': 'CPU', 'pin_memory': False}
79+
devices['cpu'] = {'name': 'CPU', 'pin_memory': False, 'modes': ['FP32']}
8280
for i in range(torch.cuda.device_count()):
83-
devices['cuda:'+str(i)] = {'name': torch.cuda.get_device_name(i), 'pin_memory': True}
84-
if pytorch_version>=(1,12):
81+
modes = ['FP32']
82+
#same as torch.cuda.is_bf16_supported() but compatible with PyTorch<1.10, and not limited to current cuda device only
83+
cu_vers = torch.version.cuda
84+
if cu_vers is not None:
85+
cuda_maj_decide = int(cu_vers.split('.')[0]) >= 11
86+
else:
87+
cuda_maj_decide = False
88+
compute_capability=torch.cuda.get_device_properties(torch.cuda.device(i)).major #https://developer.nvidia.com/cuda-gpus
89+
if compute_capability>=8 and cuda_maj_decide: #RTX 3000 and higher
90+
modes+=['TF32','FP16','BF16']
91+
if compute_capability==7: #RTX 2000
92+
modes+=['FP16']
93+
devices['cuda:'+str(i)] = {'name': torch.cuda.get_device_name(i), 'pin_memory': True, 'modes': modes}
94+
if pytorch_version>=(1,12,0):
8595
if torch.backends.mps.is_available():
86-
devices['mps'] = {'name': 'Metal Acceleration', 'pin_memory': False}
96+
devices['mps'] = {'name': 'Metal', 'pin_memory': False, 'modes': ['FP32']}
8797
#other possible devices:
8898
#'hpu' (https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html)
8999
#'dml' (https://docs.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows)
90100
devices_string_list=[]
91101
for id in devices:
92-
devices_string_list.append(devices[id]['name']+" ("+id+")")
93-
print(("Online and functional " if args.remote else "Functional ")+"("+platform.platform()+", Python "+str(python_version[0])+"."+str(python_version[1])+", PyTorch "+str(pytorch_version[0])+"."+str(pytorch_version[1])+", Devices: "+", ".join(devices_string_list)+")");
94-
95-
102+
devices_string_list.append(id+' "'+devices[id]['name']+'" ('+'/'.join(devices[id]['modes'])+')')
103+
print("Ready ("+platform.platform()+", Python "+'.'.join((str(i) for i in python_version))+", PyTorch "+'.'.join((str(i) for i in pytorch_version))+", Devices: "+", ".join(devices_string_list)+")");

torchstudio/sshtunnel.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import io
55

66
# Port forwarding from https://github.com/skyleronken/sshrat/blob/master/tunnels.py
7-
# improved with dynamic local port allocation feedback for reverse tunnel with a null local port
7+
# improved with:
8+
# dynamic local port allocation feedback for reverse tunnel with a null local port
9+
# blocking connections to avoid connection lost with poor cloud servers
10+
# more explicit error messages
811
import threading
912
import socket
1013
import selectors
@@ -322,14 +325,15 @@ def finish(self):
322325
else:
323326
print("Executing remote command...", file=sys.stderr)
324327
stdin, stdout, stderr = sshclient.exec_command("cd TorchStudio&&"+' '.join([args.command]+other_args))
325-
while True:
326-
time.sleep(.1)
328+
329+
while not stdout.channel.exit_status_ready():
330+
time.sleep(.01) #lower CPU usage
327331
if stdout.channel.recv_stderr_ready():
328-
sys.stderr.write(str(stdout.channel.recv_stderr(1024).replace(b'\r\n',b'\n'),'utf-8'))
332+
sys.stderr.buffer.write(stdout.channel.recv_stderr(1024).replace(b'\r\n',b'\n'))
333+
time.sleep(.01) #for stdout/stderr sync
329334
if stdout.channel.recv_ready():
330-
sys.stdout.write(str(stdout.channel.recv(1024).replace(b'\r\n',b'\n'),'utf-8'))
331-
if stdout.channel.exit_status_ready():
332-
break
335+
sys.stdout.buffer.write(stdout.channel.recv(1024).replace(b'\r\n',b'\n'))
336+
time.sleep(.01) #for stdout/stderr sync
333337
else:
334338
if args.script:
335339
print("Error: no python environment set.", file=sys.stderr)

torchstudio/tcpcodec.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -15,41 +15,55 @@ def start_server(server):
1515
conn, addr = server.accept()
1616
return conn
1717

18-
def connect(server_address=None):
18+
def connect(server_address=None, timeout=0):
1919
if server_address==None and len(sys.argv)<3:
20-
print("Missing socket address and port")
20+
print("Missing socket address and port", file=sys.stderr)
2121
exit()
2222

2323
if not server_address:
2424
import argparse
2525
parser = argparse.ArgumentParser()
2626
parser.add_argument("--address", help="server address", type=str, default='localhost')
2727
parser.add_argument("--port", help="local port to which the script must connect", type=int, default=0)
28+
parser.add_argument("--timeout", help="max number of seconds without incoming messages before quitting", type=int, default=0)
2829
args, unknown = parser.parse_known_args()
2930
server_address = (args.address, args.port)
31+
timeout=args.timeout
3032
else:
3133
server_address = (server_address[0], int(server_address[1]))
3234

3335
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
3436
try:
3537
sock.connect(server_address)
3638
except socket.error as serr:
37-
print("Connection error: %s" % str(serr))
39+
print("Connection error: %s" % str(serr), file=sys.stderr)
3840
exit()
39-
41+
if timeout>0:
42+
sock.settimeout(timeout)
4043
return sock
4144

4245
def send_msg(sock, type, data = bytearray()):
4346
type_bytes=bytes(type, 'utf-8')
4447
type_size=len(type_bytes)
4548
msg = struct.pack(f'<B{type_size}sI', type_size, type_bytes, len(data)) + data
46-
sock.sendall(msg)
49+
try:
50+
sock.sendall(msg)
51+
except:
52+
print("Lost connection", file=sys.stderr)
53+
exit()
4754

4855
def recv_msg(sock):
4956
def recvall(sock, n):
5057
data = bytearray()
5158
while len(data) < n:
52-
packet = sock.recv(n - len(data))
59+
try:
60+
packet = sock.recv(n - len(data))
61+
except:
62+
print("Lost connection", file=sys.stderr)
63+
exit()
64+
if len(packet)==0:
65+
print("Lost connection", file=sys.stderr)
66+
exit()
5367
data.extend(packet)
5468
return data
5569
type_size = struct.unpack('<B', recvall(sock, 1))[0]

0 commit comments

Comments
 (0)