From 5e01a421e46a22c0c6d37e926841339986591a81 Mon Sep 17 00:00:00 2001 From: askdevai-bot Date: Thu, 29 May 2025 20:38:23 -0400 Subject: [PATCH] Fix MCPClient cleanup: separate async contexts --- tinyagent/mcp_client.py | 184 +++++++++++++--------------------------- 1 file changed, 58 insertions(+), 126 deletions(-) diff --git a/tinyagent/mcp_client.py b/tinyagent/mcp_client.py index eb1919f..87a0e2f 100644 --- a/tinyagent/mcp_client.py +++ b/tinyagent/mcp_client.py @@ -1,162 +1,94 @@ import asyncio import json import logging -from typing import Dict, List, Optional, Any, Tuple, Callable +from typing import Optional, List, Callable, Any, Coroutine -# Keep your MCPClient implementation unchanged -import asyncio -from contextlib import AsyncExitStack - -# MCP core imports -from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.session import ClientSession, StdioServerParameters -# Set up logging logger = logging.getLogger(__name__) class MCPClient: def __init__(self, logger: Optional[logging.Logger] = None): + # We'll hold each context manager separately rather than in one stack: + self._stdio_ctx = None + self._session_ctx = None + self.stdio = None + self.sock_write = None self.session = None - self.exit_stack = AsyncExitStack() + self.logger = logger or logging.getLogger(__name__) - # Simplified callback system - self.callbacks: List[callable] = [] - - self.logger.debug("MCPClient initialized") + self.callbacks: List[Callable[..., Coroutine[Any, Any, Any]]] = [] + + self.logger.debug('MCPClient initialized') - def add_callback(self, callback: callable) -> None: - """ - Add a callback function to the client. - - Args: - callback: A function that accepts (event_name, client, **kwargs) - """ + def add_callback(self, callback: Callable[..., Coroutine[Any, Any, Any]]) -> None: self.callbacks.append(callback) - - async def _run_callbacks(self, event_name: str, **kwargs) -> None: - """ - Run all registered callbacks for an event. - - Args: - event_name: The name of the event - **kwargs: Additional data for the event - """ - for callback in self.callbacks: + + async def _run_callbacks(self, event: str, **kwargs): + for cb in self.callbacks: try: - logger.debug(f"Running callback: {callback}") - if asyncio.iscoroutinefunction(callback): - logger.debug(f"Callback is a coroutine function") - await callback(event_name, self, **kwargs) - else: - # Check if the callback is a class with an async __call__ method - if hasattr(callback, '__call__') and asyncio.iscoroutinefunction(callback.__call__): - logger.debug(f"Callback is a class with an async __call__ method") - await callback(event_name, self, **kwargs) - else: - logger.debug(f"Callback is a regular function") - callback(event_name, self, **kwargs) + await cb(event, **kwargs) except Exception as e: - logger.error(f"Error in callback for {event_name}: {str(e)}") + self.logger.error(f'Error in callback {cb}: {e}') - async def connect(self, command: str, args: list[str]): - """ - Launches the MCP server subprocess and initializes the client session. - :param command: e.g. "python" or "node" - :param args: list of args to pass, e.g. ["my_server.py"] or ["build/index.js"] - """ - # Prepare stdio transport parameters + async def connect(self, command: str, args: list): params = StdioServerParameters(command=command, args=args) - # Open the stdio client transport - self.stdio, self.sock_write = await self.exit_stack.enter_async_context( - stdio_client(params) - ) - # Create and initialize the MCP client session - self.session = await self.exit_stack.enter_async_context( - ClientSession(self.stdio, self.sock_write) - ) - await self.session.initialize() + # 1) enter stdio_client context + self._stdio_ctx = stdio_client(params) + try: + self.stdio, self.sock_write = await self._stdio_ctx.__aenter__() + except Exception as e: + self.logger.error(f'Failed to start stdio_client: {e}') + raise + + # 2) enter ClientSession context + self._session_ctx = ClientSession(self.stdio, self.sock_write) + try: + self.session = await self._session_ctx.__aenter__() + await self.session.initialize() + except Exception as e: + self.logger.error(f'Failed to initialize MCP session: {e}') + # make sure we unwind the stdio context if session init fails + await self._stdio_ctx.__aexit__(None, None, None) + raise async def list_tools(self): resp = await self.session.list_tools() - print("Available tools:") + print('Available tools:') for tool in resp.tools: - print(f" • {tool.name}: {tool.description}") + print(f' • {tool.name}: {tool.description}') async def call_tool(self, name: str, arguments: dict): - """ - Invokes a named tool and returns its raw content list. - """ - # Notify tool start - await self._run_callbacks("tool_start", tool_name=name, arguments=arguments) - + await self._run_callbacks('tool_start', tool_name=name, arguments=arguments) try: resp = await self.session.call_tool(name, arguments) - - # Notify tool end - await self._run_callbacks("tool_end", tool_name=name, arguments=arguments, - result=resp.content, success=True) - + await self._run_callbacks('tool_end', tool_name=name, arguments=arguments, result=resp.content, success=True) return resp.content except Exception as e: - # Notify tool end with error - await self._run_callbacks("tool_end", tool_name=name, arguments=arguments, - error=str(e), success=False) + await self._run_callbacks('tool_end', tool_name=name, arguments=arguments, error=str(e), success=False) raise async def close(self): - """Clean up subprocess and streams.""" - if self.exit_stack: + """Clean up subprocess and streams, one context at a time.""" + # 1) teardown session + if self._session_ctx is not None: try: - await self.exit_stack.aclose() - except (RuntimeError, asyncio.CancelledError) as e: - # Log the error but don't re-raise it - self.logger.error(f"Error during client cleanup: {e}") + await self._session_ctx.__aexit__(None, None, None) + except Exception as e: + self.logger.error(f'Error closing MCP session: {e}') finally: - # Always reset these regardless of success or failure self.session = None - self.exit_stack = AsyncExitStack() + self._session_ctx = None -async def run_example(): - """Example usage of MCPClient with proper logging.""" - import sys - from tinyagent.hooks.logging_manager import LoggingManager - - # Create and configure logging manager - log_manager = LoggingManager(default_level=logging.INFO) - log_manager.set_levels({ - 'tinyagent.mcp_client': logging.DEBUG, # Debug for this module - 'tinyagent.tiny_agent': logging.INFO, - }) - - # Configure a console handler - console_handler = logging.StreamHandler(sys.stdout) - log_manager.configure_handler( - console_handler, - format_string='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - level=logging.DEBUG - ) - - # Get module-specific logger - mcp_logger = log_manager.get_logger('tinyagent.mcp_client') - - mcp_logger.debug("Starting MCPClient example") - - # Create client with our logger - client = MCPClient(logger=mcp_logger) - - try: - # Connect to a simple echo server - await client.connect("python", ["-m", "mcp.examples.echo_server"]) - - # List available tools - await client.list_tools() - - # Call the echo tool - result = await client.call_tool("echo", {"message": "Hello, MCP!"}) - mcp_logger.info(f"Echo result: {result}") - - finally: - # Clean up - await client.close() - mcp_logger.debug("Example completed") + # 2) teardown stdio + if self._stdio_ctx is not None: + try: + await self._stdio_ctx.__aexit__(None, None, None) + except Exception as e: + self.logger.error(f'Error closing stdio client: {e}') + finally: + self.stdio = None + self.sock_write = None + self._stdio_ctx = None