Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ec2-instance-connect] add more cleanup to websockets #9346

Open
wants to merge 1 commit into
base: v2
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 162 additions & 38 deletions awscli/customizations/ec2instanceconnect/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,16 @@ def has_data_to_read(self):
return False

def read(self, amt) -> bytes:
return sys.stdin.buffer.read1(amt)
try:
data = sys.stdin.buffer.read1(amt)
# Empty data indicates EOF (pipe closed)
if not data:
logger.debug("Stdin returned empty data (EOF). Input is closed.")
raise InputClosedError()
return data
except (OSError, IOError) as e:
logger.debug(f"IO error reading from stdin: {str(e)}")
raise InputClosedError()

def write(self, data):
sys.stdout.buffer.write(data)
Expand All @@ -88,38 +97,70 @@ def close(self):

class WindowsStdinStdoutIO(StdinStdoutIO):
def has_data_to_read(self):
return True
# For Windows, we can't reliably check stdin without blocking
# We'll rely on the read method to detect when input is closed
# by catching EOF errors in the calling code
try:
if sys.stdin.closed:
return False
return True
except (OSError, ValueError, IOError):
return False


class TCPSocketIO(BaseWebsocketIO):
def __init__(self, conn):
self.conn = conn
self._is_closed = False

def has_data_to_read(self):
return True
if self._is_closed:
return False

# Use select with a timeout to check if there's data
try:
read_ready, _, _ = select.select([self.conn], [], [], _SELECT_TIMEOUT)
return bool(read_ready)
except (OSError, ValueError, socket.error):
self._is_closed = True
return False

def read(self, amt) -> bytes:
data = self.conn.recv(amt)
# In listener mode the user can press CTRL+C during host verification, which kills the client TCP connection.
# When this happens we are able to disconnect successfully because has_data_to_read always returns True.
# This checks whether data is empty and, if so, raises InputClosedError.
#
# recv() relies on the underlying system call which returns empty bytes when the connection is closed.
# Linux: https://manpages.debian.org/bullseye/manpages-dev/recv.2.en.html
# Windows: https://learn.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recv
if not data:
try:
data = self.conn.recv(amt)
# In listener mode the user can press CTRL+C during host verification, which kills the client TCP connection.
# When this happens we are able to disconnect successfully because has_data_to_read always returns True.
# This checks whether data is empty and, if so, raises InputClosedError.
#
# recv() relies on the underlying system call which returns empty bytes when the connection is closed.
# Linux: https://manpages.debian.org/bullseye/manpages-dev/recv.2.en.html
# Windows: https://learn.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recv
if not data:
self._is_closed = True
raise InputClosedError()
return data
except (OSError, socket.error):
self._is_closed = True
raise InputClosedError()
return data

def write(self, data):
self.conn.sendall(data)
if self._is_closed:
raise InputClosedError()
try:
self.conn.sendall(data)
except (OSError, socket.error):
self._is_closed = True
raise InputClosedError()

def close(self):
try:
self.conn.close()
# On Windows, we could receive an OSError if the tcp conn is already closed.
except OSError:
pass
if not self._is_closed:
self._is_closed = True
try:
self.conn.shutdown(socket.SHUT_RDWR)
self.conn.close()
# On Windows, we could receive an OSError if the tcp conn is already closed.
except OSError:
pass


class Websocket:
Expand Down Expand Up @@ -217,9 +258,25 @@ def write_data_from_input(self):
try:
# Start writing data to the websocket connection and block current thread.
self._write_data_from_input()
except Exception as e:
logger.error(f"Unexpected error in write_data_from_input: {str(e)}")
finally:
# Make sure to clean up on exit
logger.debug("Exiting write_data_from_input, cleaning up")
self.close()

# If we're a stdin/stdout websocket and input was closed,
# ensure the process exits cleanly
if isinstance(self.websocketio, StdinStdoutIO) or isinstance(self.websocketio, WindowsStdinStdoutIO):
logger.debug("Stdin/stdout websocket closed, exiting process")
# This is a bit drastic but necessary to ensure the process exits
# when stdin is closed in pipe mode
import os
import signal
# Send SIGTERM to ourselves to initiate clean shutdown
# This is more reliable than sys.exit() which can be caught
os.kill(os.getpid(), signal.SIGTERM)

if self._exception:
raise self._exception

Expand All @@ -231,25 +288,52 @@ def close(self):

def _write_data_from_input(self):
while not self._shutdown_event.is_set():
# Check if websocket is still valid
if not self._websocket:
logger.debug('Websocket is closed or invalid. Exiting write loop.')
self.close()
return

# Wait until there's some data to read
if not self.websocketio.has_data_to_read():
time.sleep(self._WAIT_INTERVAL_FOR_INPUT)
continue
try:
if not self.websocketio.has_data_to_read():
time.sleep(self._WAIT_INTERVAL_FOR_INPUT)
continue
except Exception as e:
logger.debug(f'Error checking for data: {str(e)}. Shutting down websocket.')
self.close()
return

try:
data = self.websocketio.read(self._MAX_BYTES_PER_FRAME)
# Skip empty data (shouldn't happen, but as a safeguard)
if not data:
logger.debug('Received empty data. Skipping frame.')
continue
except InputClosedError as e:
logger.debug('Input closed. Shutting down websocket.')
self.close()
return
except Exception as e:
logger.debug(f'Error reading data: {str(e)}. Shutting down websocket.')
self.close()
return

try:
self._websocket.send_frame(
opcode=Opcode.BINARY,
payload=data,
on_complete=self._on_send_frame_complete_data,
)
# Block until send_frame on_complete
self._send_frame_results_queue.get()
# Block until send_frame on_complete with a timeout
try:
result = self._send_frame_results_queue.get(timeout=5.0)
if result and hasattr(result, 'exception') and result.exception:
raise result.exception
except Exception as e:
logger.debug(f'Timeout or error waiting for frame completion: {str(e)}')
self.close()
return
except RuntimeError as e:
crt_exceptions = [
"AWS_ERROR_HTTP_WEBSOCKET_CLOSE_FRAME_SENT",
Expand All @@ -261,8 +345,15 @@ def _write_data_from_input(self):
f"Received exception when sending websocket frame: {e.args}"
)
self.close()
return
else:
logger.debug(f"Unhandled runtime error: {e.args}")
self.close()
raise e
except Exception as e:
logger.debug(f'Unexpected error sending frame: {str(e)}')
self.close()
return

def _on_connection(self, data: OnConnectionSetupData) -> None:
request_id_header = [
Expand Down Expand Up @@ -354,17 +445,33 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
for _, web_socket in self._inflight_futures_and_websockets:
# Close the websocket handlers.
web_socket.close()
logger.debug("Shutting down WebsocketManager")
# First set the RUNNING event so any remaining loops exit
self.RUNNING.set()

# Close all websocket handlers
for future, web_socket in self._inflight_futures_and_websockets:
try:
web_socket.close()
# Try to cancel any still-running futures
if not future.done():
future.cancel()
except Exception as e:
logger.debug(f"Error closing websocket: {str(e)}")

# Close server socket if exists
if self._socket:
try:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
# On Windows, if the socket is already closed, we will get an OSError.
except OSError:
pass
self._executor.shutdown()

# Shut down the executor without waiting for running tasks (wait=False; no timeout is applied)
logger.debug("Shutting down executor")
self._executor.shutdown(wait=False)
logger.debug("WebsocketManager shutdown complete")

# Used to break out of while loop in tests.
RUNNING = threading.Event()
Expand All @@ -375,11 +482,20 @@ def run(self):
websocketio = (
WindowsStdinStdoutIO() if is_windows else StdinStdoutIO()
)
future = self._open_websocket_connection(
Websocket(websocketio, websocket_id=None)
)
# Block until the future completes.
future.result()
web_socket = Websocket(websocketio, websocket_id=None)
try:
future = self._open_websocket_connection(web_socket)
# Block until the future completes.
future.result()
except WebsocketException as e:
logger.error(f"Websocket error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
finally:
# Make sure everything is closed and we can exit
web_socket.close()
# Force shutdown the executor to ensure the process can exit
self._executor.shutdown(wait=False)
else:
self._listen_on_port()

Expand Down Expand Up @@ -424,13 +540,21 @@ def _listen_on_port(self):
)

def _open_websocket_connection(self, web_socket):
presigned_url = self._eice_request_signer.get_presigned_url()
web_socket.connect(presigned_url, self._user_agent)
try:
presigned_url = self._eice_request_signer.get_presigned_url()
web_socket.connect(presigned_url, self._user_agent)

future = self._executor.submit(web_socket.write_data_from_input)
# Submit the write loop to the executor (no done callback is attached here)
future = self._executor.submit(web_socket.write_data_from_input)

self._inflight_futures_and_websockets.append((future, web_socket))
return future
# Store for cleanup
self._inflight_futures_and_websockets.append((future, web_socket))

return future
except Exception as e:
logger.error(f"Failed to open websocket connection: {str(e)}")
web_socket.close()
raise

def _print_tcp_conn_closed(self, web_socket):
def _on_done_callback(future):
Expand Down