Skip to content

Commit 9b16b81

Browse files
committed
Add support for forwarding incoming sessions over an SSH tunnel
This commit makes it easier for applications to accept incoming tunneled session, connection, or TUN/TAP requests on an SSHServerConnection and forward them over an upstream SSHClientConnection. The methods on SSHServer to accept these requests can now return a SSHClientConnection object to forward the traffic over, instead of having to accept the request, open a corresponding upstream session, and then relay data between the two SSH sessions.
1 parent 00b98eb commit 9b16b81

File tree

5 files changed

+282
-24
lines changed

5 files changed

+282
-24
lines changed

asyncssh/connection.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6305,6 +6305,9 @@ def _process_session_open(self, packet: SSHPacket) -> \
63056305
if not result:
63066306
raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Session refused')
63076307

6308+
if isinstance(result, SSHClientConnection):
6309+
result = self.forward_tunneled_session(result)
6310+
63086311
if isinstance(result, tuple):
63096312
chan, result = result
63106313
else:
@@ -6360,6 +6363,10 @@ def _process_direct_tcpip_open(self, packet: SSHPacket) -> \
63606363
if result is True:
63616364
result = cast(SSHTCPSession[bytes],
63626365
self.forward_connection(dest_host, dest_port))
6366+
elif isinstance(result, SSHClientConnection):
6367+
result = cast(Awaitable[SSHTCPSession[bytes]],
6368+
self.forward_tunneled_connection(
6369+
result, dest_host, dest_port))
63636370

63646371
if isinstance(result, tuple):
63656372
chan, result = result
@@ -6506,6 +6513,10 @@ def _process_direct_streamlocal_at_openssh_dot_com_open(
65066513
if result is True:
65076514
result = cast(SSHUNIXSession[bytes],
65086515
self.forward_unix_connection(dest_path))
6516+
elif isinstance(result, SSHClientConnection):
6517+
result = cast(Awaitable[SSHUNIXSession[bytes]],
6518+
self.forward_tunneled_unix_connection(
6519+
result, dest_path))
65096520

65106521
if isinstance(result, tuple):
65116522
chan, result = result
@@ -6621,10 +6632,14 @@ def _process_tun_at_openssh_dot_com_open(
66216632
result = False
66226633

66236634
if not result:
6624-
raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Connection refused')
6635+
raise ChannelOpenError(OPEN_CONNECT_FAILED,
6636+
'TUN/TAP request refused')
66256637

66266638
if result is True:
66276639
result = cast(SSHTunTapSession, self.forward_tuntap(mode, unit))
6640+
elif isinstance(result, SSHClientConnection):
6641+
result = cast(Awaitable[SSHTunTapSession],
6642+
self.forward_tunneled_tuntap(result, mode, unit))
66286643

66296644
if isinstance(result, tuple):
66306645
chan, result = result
@@ -7179,6 +7194,76 @@ async def open_agent_connection(self) -> \
71797194

71807195
return SSHReader[bytes](session, chan), SSHWriter[bytes](session, chan)
71817196

7197+
async def forward_tunneled_session(
7198+
self, conn: SSHClientConnection) -> SSHServerProcess:
7199+
"""Forward a tunneled session between SSH connections"""
7200+
7201+
async def process_factory(process: SSHServerProcess) -> None:
7202+
"""Return an upstream process used to forward the session"""
7203+
7204+
encoding, errors = process.channel.get_encoding()
7205+
7206+
upstream_process: SSHClientProcess = await conn.create_process(
7207+
command=process.command, subsystem=process.subsystem,
7208+
env=process.env, term_type=process.term_type,
7209+
term_size=process.term_size, term_modes=process.term_modes,
7210+
encoding=encoding, errors=errors, stdin=process.stdin,
7211+
stdout=process.stdout, stderr=process.stderr)
7212+
7213+
await upstream_process.wait_closed()
7214+
7215+
self.logger.info(' Forwarding session via SSH tunnel')
7216+
7217+
return SSHServerProcess(process_factory, None, MIN_SFTP_VERSION, False)
7218+
7219+
async def forward_tunneled_connection(
7220+
self, conn: SSHClientConnection,
7221+
dest_host: str, dest_port: int) -> SSHForwarder:
7222+
"""Forward a tunneled TCP connection between SSH connections"""
7223+
7224+
_, peer = await conn.create_connection(
7225+
cast(SSHTCPSessionFactory[bytes], SSHForwarder),
7226+
dest_host, dest_port)
7227+
7228+
self.logger.info(' Forwarding TCP connection to %s via SSH tunnel',
7229+
(dest_host, dest_port))
7230+
7231+
return SSHForwarder(cast(SSHForwarder, peer))
7232+
7233+
async def forward_tunneled_unix_connection(
7234+
self, conn: SSHClientConnection,
7235+
dest_path: str) -> SSHForwarder:
7236+
"""Forward a tunneled UNIX connection between SSH connections"""
7237+
7238+
_, peer = await conn.create_unix_connection(
7239+
cast(SSHUNIXSessionFactory[bytes], SSHForwarder), dest_path)
7240+
7241+
self.logger.info(' Forwarding UNIX connection to %s via SSH tunnel',
7242+
dest_path)
7243+
7244+
return SSHForwarder(cast(SSHForwarder, peer))
7245+
7246+
async def forward_tunneled_tuntap(
7247+
self, conn: SSHClientConnection,
7248+
mode: int, unit: Optional[int]) -> SSHForwarder:
7249+
"""Forward a TUN/TAP connection between SSH connections"""
7250+
7251+
if mode == SSH_TUN_MODE_POINTTOPOINT:
7252+
create_func = conn.create_tun
7253+
layer = 3
7254+
else:
7255+
create_func = conn.create_tap
7256+
layer = 2
7257+
7258+
transport, peer = await create_func(
7259+
cast(SSHTunTapSessionFactory, SSHForwarder), unit)
7260+
interface = transport.get_extra_info('interface')
7261+
7262+
self.logger.info(' Forwarding layer %d traffic to %s via SSH tunnel',
7263+
layer, interface)
7264+
7265+
return SSHForwarder(cast(SSHForwarder, peer))
7266+
71827267

71837268
class SSHConnectionOptions(Options, Generic[_Options]):
71847269
"""SSH connection options"""

asyncssh/server.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,36 @@
3131

3232
if TYPE_CHECKING:
3333
# pylint: disable=cyclic-import
34-
from .connection import SSHServerConnection, SSHAcceptHandler
34+
from .connection import SSHClientConnection, SSHServerConnection
35+
from .connection import SSHAcceptHandler
3536
from .channel import SSHServerChannel, SSHTCPChannel, SSHUNIXChannel
3637
from .channel import SSHTunTapChannel
3738
from .session import SSHServerSession, SSHTCPSession, SSHUNIXSession
3839
from .session import SSHTunTapSession
3940

4041

41-
_NewSession = \
42-
Union[bool, MaybeAwait['SSHServerSession'], SSHServerSessionFactory,
42+
_NewSession = Union[
43+
bool, 'SSHClientConnection',
44+
MaybeAwait['SSHServerSession'], SSHServerSessionFactory,
4345
Tuple['SSHServerChannel', MaybeAwait['SSHServerSession']],
4446
Tuple['SSHServerChannel', SSHServerSessionFactory]]
45-
_NewTCPSession = \
46-
Union[bool, MaybeAwait['SSHTCPSession'], SSHSocketSessionFactory,
47+
_NewTCPSession = Union[
48+
bool, 'SSHClientConnection',
49+
MaybeAwait['SSHTCPSession'], SSHSocketSessionFactory,
4750
Tuple['SSHTCPChannel', MaybeAwait['SSHTCPSession']],
4851
Tuple['SSHTCPChannel', SSHSocketSessionFactory]]
49-
_NewUNIXSession = \
50-
Union[bool, MaybeAwait['SSHUNIXSession'], SSHSocketSessionFactory,
52+
_NewUNIXSession = Union[
53+
bool, 'SSHClientConnection',
54+
MaybeAwait['SSHUNIXSession'], SSHSocketSessionFactory,
5155
Tuple['SSHUNIXChannel', MaybeAwait['SSHUNIXSession']],
5256
Tuple['SSHUNIXChannel', SSHSocketSessionFactory]]
53-
_NewTunTapSession = \
54-
Union[bool, MaybeAwait['SSHTunTapSession'], SSHSocketSessionFactory,
57+
_NewTunTapSession = Union[
58+
bool, 'SSHClientConnection',
59+
MaybeAwait['SSHTunTapSession'], SSHSocketSessionFactory,
5560
Tuple['SSHTunTapChannel', MaybeAwait['SSHTunTapSession']],
5661
Tuple['SSHTunTapChannel', SSHSocketSessionFactory]]
57-
_NewListener = Union[bool, 'SSHAcceptHandler', MaybeAwait[SSHListener]]
62+
_NewTCPListener = Union[bool, 'SSHAcceptHandler', MaybeAwait[SSHListener]]
63+
_NewUNIXListener = Union[bool, MaybeAwait[SSHListener]]
5864

5965

6066
class SSHServer:
@@ -749,6 +755,11 @@ def connection_requested(self, dest_host: str, dest_port: int,
749755
:exc:`ChannelOpenError` exception with the reason for
750756
the failure.
751757
758+
If the application wishes to tunnel the connection over
759+
another SSH connection, this method should return an
760+
:class:`SSHClientConnection` connected to the desired
761+
tunnel host.
762+
752763
If the application wishes to process the data on the
753764
connection itself, this method should return either an
754765
:class:`SSHTCPSession` object which can be used to process the
@@ -802,7 +813,7 @@ def connection_requested(self, dest_host: str, dest_port: int,
802813
return False # pragma: no cover
803814

804815
def server_requested(self, listen_host: str,
805-
listen_port: int) -> MaybeAwait[_NewListener]:
816+
listen_port: int) -> MaybeAwait[_NewTCPListener]:
806817
"""Handle a request to listen on a TCP/IP address and port
807818
808819
This method is called when a client makes a request to
@@ -864,6 +875,11 @@ def unix_connection_requested(self, dest_path: str) -> _NewUNIXSession:
864875
:exc:`ChannelOpenError` exception with the reason for
865876
the failure.
866877
878+
If the application wishes to tunnel the connection over
879+
another SSH connection, this method should return an
880+
:class:`SSHClientConnection` connected to the desired
881+
tunnel host.
882+
867883
If the application wishes to process the data on the
868884
connection itself, this method should return either an
869885
:class:`SSHUNIXSession` object which can be used to process the
@@ -908,7 +924,7 @@ def unix_connection_requested(self, dest_path: str) -> _NewUNIXSession:
908924
return False # pragma: no cover
909925

910926
def unix_server_requested(self, listen_path: str) -> \
911-
MaybeAwait[_NewListener]:
927+
MaybeAwait[_NewUNIXListener]:
912928
"""Handle a request to listen on a UNIX domain socket
913929
914930
This method is called when a client makes a request to
@@ -958,14 +974,19 @@ def tun_requested(self, unit: Optional[int]) -> _NewTunTapSession:
958974
by the server. Applications wishing to accept such tunnels must
959975
override this method.
960976
961-
To allow standard path forwarding of data on the connection to the
977+
To allow standard forwarding of data on the connection to the
962978
requested TUN device, this method should return `True`.
963979
964980
To reject this request, this method should return `False`
965981
to send back a "Connection refused" response or raise an
966982
:exc:`ChannelOpenError` exception with the reason for
967983
the failure.
968984
985+
If the application wishes to tunnel the data over another
986+
SSH connection, this method should return an
987+
:class:`SSHClientConnection` connected to the desired
988+
tunnel host.
989+
969990
If the application wishes to process the data on the
970991
connection itself, this method should return either an
971992
:class:`SSHTunTapSession` object which can be used to process the
@@ -1016,14 +1037,19 @@ def tap_requested(self, unit: Optional[int]) -> _NewTunTapSession:
10161037
by the server. Applications wishing to accept such tunnels must
10171038
override this method.
10181039
1019-
To allow standard path forwarding of data on the connection to the
1020-
requested TUN device, this method should return `True`.
1040+
To allow standard forwarding of data on the connection to the
1041+
requested TAP device, this method should return `True`.
10211042
10221043
To reject this request, this method should return `False`
10231044
to send back a "Connection refused" response or raise an
10241045
:exc:`ChannelOpenError` exception with the reason for
10251046
the failure.
10261047
1048+
If the application wishes to tunnel the data over another
1049+
SSH connection, this method should return an
1050+
:class:`SSHClientConnection` connected to the desired
1051+
tunnel host.
1052+
10271053
If the application wishes to process the data on the
10281054
connection itself, this method should return either an
10291055
:class:`SSHTunTapSession` object which can be used to process the

tests/test_forward.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,25 @@ async def unix_server_requested(self, listen_path):
229229
return listen_path != 'fail'
230230

231231

232+
class _UpstreamForwardingServer(Server):
233+
"""Server for testing forwarding between SSH connections"""
234+
235+
def __init__(self, upstream_conn):
236+
super().__init__()
237+
238+
self._upstream_conn = upstream_conn
239+
240+
def connection_requested(self, dest_host, dest_port, orig_host, orig_port):
241+
"""Handle a request to create a new connection"""
242+
243+
return self._upstream_conn
244+
245+
def unix_connection_requested(self, dest_path):
246+
"""Handle a request to create a new UNIX domain connection"""
247+
248+
return self._upstream_conn
249+
250+
232251
class _CheckForwarding(ServerTestCase):
233252
"""Utility functions for AsyncSSH forwarding unit tests"""
234253

@@ -296,8 +315,8 @@ class _TestTCPForwarding(_CheckForwarding):
296315
async def start_server(cls):
297316
"""Start an SSH server which supports TCP connection forwarding"""
298317

299-
return (await cls.create_server(
300-
_TCPConnectionServer, authorized_client_keys='authorized_keys'))
318+
return await cls.create_server(
319+
_TCPConnectionServer, authorized_client_keys='authorized_keys')
301320

302321
async def _check_connection(self, conn, dest_host='',
303322
dest_port=7, **kwargs):
@@ -876,6 +895,25 @@ async def test_cancel_forward_remote_port_invalid_unicode(self):
876895

877896
self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE)
878897

898+
@asynctest
899+
async def test_upstream_forward_local_port(self):
900+
"""Test upstream forwarding of a local port"""
901+
902+
def upstream_server():
903+
"""Return a server capable of forwarding between SSH connections"""
904+
905+
return _UpstreamForwardingServer(upstream_conn)
906+
907+
async with self.connect() as upstream_conn:
908+
upstream_listener = await self.create_server(upstream_server)
909+
upstream_port = upstream_listener.get_port()
910+
911+
async with self.connect('127.0.0.1', upstream_port) as conn:
912+
async with conn.forward_local_port('', 0, '', 7) as listener:
913+
await self._check_local_connection(listener.get_port())
914+
915+
upstream_listener.close()
916+
879917
@asynctest
880918
async def test_add_channel_after_close(self):
881919
"""Test opening a connection after a close"""
@@ -963,8 +1001,8 @@ class _TestUNIXForwarding(_CheckForwarding):
9631001
async def start_server(cls):
9641002
"""Start an SSH server which supports UNIX connection forwarding"""
9651003

966-
return (await cls.create_server(
967-
_UNIXConnectionServer, authorized_client_keys='authorized_keys'))
1004+
return await cls.create_server(
1005+
_UNIXConnectionServer, authorized_client_keys='authorized_keys')
9681006

9691007
async def _check_unix_connection(self, conn, dest_path='/echo', **kwargs):
9701008
"""Open a UNIX connection and test if an input line is echoed back"""
@@ -1233,6 +1271,25 @@ async def test_cancel_forward_remote_path_invalid_unicode(self):
12331271

12341272
self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE)
12351273

1274+
@asynctest
1275+
async def test_upstream_forward_local_path(self):
1276+
"""Test upstream forwarding of a local path"""
1277+
1278+
def upstream_server():
1279+
"""Return a server capable of forwarding between SSH connections"""
1280+
1281+
return _UpstreamForwardingServer(upstream_conn)
1282+
1283+
async with self.connect() as upstream_conn:
1284+
upstream_listener = await self.create_server(upstream_server)
1285+
upstream_port = upstream_listener.get_port()
1286+
1287+
async with self.connect('127.0.0.1', upstream_port) as conn:
1288+
async with conn.forward_local_path('local', '/echo'):
1289+
await self._check_local_unix_connection('local')
1290+
1291+
upstream_listener.close()
1292+
12361293

12371294
class _TestAsyncUNIXForwarding(_TestUNIXForwarding):
12381295
"""Unit tests for AsyncSSH UNIX connection forwarding with async return"""
@@ -1253,8 +1310,8 @@ class _TestSOCKSForwarding(_CheckForwarding):
12531310
async def start_server(cls):
12541311
"""Start an SSH server which supports TCP connection forwarding"""
12551312

1256-
return (await cls.create_server(
1257-
_TCPConnectionServer, authorized_client_keys='authorized_keys'))
1313+
return await cls.create_server(
1314+
_TCPConnectionServer, authorized_client_keys='authorized_keys')
12581315

12591316
async def _check_early_error(self, reader, writer, data):
12601317
"""Check errors in the initial SOCKS message"""

0 commit comments

Comments
 (0)