Skip to content

Commit a514e8f

Browse files
0.13
1 parent ff78b5e commit a514e8f

File tree

5 files changed

+27
-16
lines changed

5 files changed

+27
-16
lines changed

torchstudio/modeltrain.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def deepcopy_cpu(value):
184184

185185
sshclient = paramiko.SSHClient()
186186
sshclient.set_missing_host_key_policy(paramiko.AutoAddPolicy())
187-
sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=5)
187+
sshclient.connect(hostname=sshaddress, port=int(sshport), username=username, password=password, pkey=pkey, timeout=10)
188188

189189
reverse_tunnel = sshtunnel.Tunnel(sshclient, sshtunnel.ReverseTunnel, 'localhost', 0, 'localhost', int(address[1]))
190190
address[1]=str(reverse_tunnel.lport)
@@ -319,6 +319,7 @@ def deepcopy_cpu(value):
319319

320320
buffer=io.BytesIO()
321321
torch.save(deepcopy_cpu(model.state_dict()), buffer)
322+
print("Training... epoch "+str(scheduler.last_epoch-1)+' | (save)\n\n', file=sys.stderr)
322323
tc.send_msg(app_socket, 'ModelState', buffer.getvalue())
323324

324325
buffer=io.BytesIO()

torchstudio/pythoninstall.cmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ if EXIST "%pythonpath%" (
227227
)
228228
)
229229

230-
set PATH=%PATH%;%pythonpath%;%pythonpath%\Library\mingw-w64\bin;%pythonpath%\Library\bin
230+
set PATH=%PATH%;%pythonpath%;%pythonpath%\Library\mingw-w64\bin;%pythonpath%\Library\bin;%pythonpath%\bin
231231
"%pythonpath%\python" -u -B -X utf8 -m torchstudio.pythoninstall --channel %channel% %cuda% %packages%
232232
if ERRORLEVEL 1 (
233233
echo. 1>&2

torchstudio/pythoninstall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@
3838
highest_cuda_string='.'.join([str(value) for value in highest_cuda_version])
3939
print("Using CUDA "+highest_cuda_string)
4040
print("")
41-
conda_install+=" fastchan::cudatoolkit="+highest_cuda_string
41+
conda_install+=f" {args.channel}::pytorch-cuda="+highest_cuda_string+" -c nvidia"
4242
else:
4343
conda_install+=f" {args.channel}::cpuonly"
4444
print(f"Downloading and installing {args.channel} packages...")
4545
print("")
46-
conda_install+=" -k" #adding fastchan for cudatoolkit, faster than the conda-forge channel
46+
conda_install+=" -k" #allow insecure ssl connections
4747
# https://stackoverflow.com/questions/41767340/using-conda-install-within-a-python-script
4848
(stdout_str, stderr_str, return_code_int) = Conda.run_command(Conda.Commands.INSTALL,conda_install.split(),use_exception_handler=True,stdout=sys.stdout,stderr=sys.stderr)
4949
if return_code_int!=0:

torchstudio/sshtunnel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def finish(self):
258258

259259
print("Connecting to remote server...", file=sys.stderr)
260260
try:
261-
sshclient.connect(hostname=args.sshaddress, port=args.sshport, username=args.username, password=args.password, pkey=paramiko.RSAKey.from_private_key_file(args.keyfile) if args.keyfile else None, timeout=5)
261+
sshclient.connect(hostname=args.sshaddress, port=args.sshport, username=args.username, password=args.password, pkey=paramiko.RSAKey.from_private_key_file(args.keyfile) if args.keyfile else None, timeout=10)
262262
except:
263263
print("Error: could not connect to remote server", file=sys.stderr)
264264
exit(1)

torchstudio/tcpcodec.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,25 @@ def connect(server_address=None, timeout=0):
4242
sock.settimeout(timeout)
4343
return sock
4444

45-
def send_msg(sock, type, data = bytearray()):
46-
type_bytes=bytes(type, 'utf-8')
45+
def send_msg(sock, data_type, data=b''):
    """Send one framed message over ``sock``.

    Wire format (little-endian, unpadded), matching what ``recv_msg`` unpacks:
    a 1-byte length of the UTF-8 encoded type name, the type name itself,
    a 4-byte unsigned payload length, then the payload bytes.

    Args:
        sock: connected object exposing ``send(buffer) -> int`` (number of
            bytes accepted).
        data_type: message type name (str); its UTF-8 encoding length is
            packed as a single unsigned byte.
        data: payload as ``bytes`` or ``bytearray``; defaults to empty.

    Raises:
        SystemExit: after printing a "Lost connection" line to stderr, when
            ``sock.send`` raises ``OSError`` (this includes socket timeouts)
            or reports that 0 bytes were sent.
        struct.error: if the type name or payload length does not fit the
            1-byte / 4-byte header fields (nothing is sent in that case).
    """
    chunk_size = 1048576  # hand at most 1MB to each send() call

    def sendall(sock, buffer):
        # A memoryview slice is zero-copy, so a large payload is never
        # re-copied after each partial send; only an offset advances.
        view = memoryview(buffer)
        offset = 0
        while offset < len(view):
            try:
                sent = sock.send(view[offset:offset + chunk_size])
            except OSError:
                # Only socket-level failures are treated as a lost connection;
                # KeyboardInterrupt and programming errors propagate untouched.
                print("Lost connection (send timeout)", file=sys.stderr)
                sys.exit()
            if sent == 0:
                print("Lost connection (send null)", file=sys.stderr)
                sys.exit()
            offset += sent

    type_bytes = bytes(data_type, 'utf-8')
    type_size = len(type_bytes)
    # Header and payload stay in one buffer so small messages still leave in
    # a single send() instead of a tiny header write followed by the payload.
    msg = struct.pack(f'<B{type_size}sI', type_size, type_bytes, len(data)) + data

    sendall(sock, msg)
5464

5565
def recv_msg(sock):
5666
def recvall(sock, n):
@@ -59,17 +69,17 @@ def recvall(sock, n):
5969
try:
6070
packet = sock.recv(n - len(data))
6171
except:
62-
print("Lost connection", file=sys.stderr)
72+
print("Lost connection (receive timeout)", file=sys.stderr)
6373
exit()
6474
if len(packet)==0:
65-
print("Lost connection", file=sys.stderr)
75+
print("Lost connection (receive null)", file=sys.stderr)
6676
exit()
6777
data.extend(packet)
6878
return data
6979
type_size = struct.unpack('<B', recvall(sock, 1))[0]
70-
type = struct.unpack(f'<{type_size}s', recvall(sock, type_size))[0]
80+
data_type = struct.unpack(f'<{type_size}s', recvall(sock, type_size))[0]
7181
datalen = struct.unpack('<I', recvall(sock, 4))[0]
72-
return str(type, 'utf-8'), recvall(sock, datalen)
82+
return str(data_type, 'utf-8'), recvall(sock, datalen)
7383

7484

7585

0 commit comments

Comments
 (0)