diff --git a/tinyagent/mcp_client.py b/tinyagent/mcp_client.py index eb1919f..059d7a9 100644 --- a/tinyagent/mcp_client.py +++ b/tinyagent/mcp_client.py @@ -1,57 +1,67 @@ +++ b/tinyagent/mcp_client.py import asyncio import json import logging +from typing import Optional, List, Coroutine, Any from typing import Dict, List, Optional, Any, Tuple, Callable -# Keep your MCPClient implementation unchanged -import asyncio -from contextlib import AsyncExitStack - -# MCP core imports -from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import stdio_client - # Set up logging -logger = logging.getLogger(__name__) - class MCPClient: - def __init__(self, logger: Optional[logging.Logger] = None): + + # We'll hold each context manager separately rather than in one stack: + self._stdio_ctx = None + self._session_ctx = None + self.stdio = None + self.sock_write = None + self.session = None self.session = None - self.exit_stack = AsyncExitStack() - self.logger = logger or logging.getLogger(__name__) - # Simplified callback system - self.callbacks: List[callable] = [] + self.callbacks: List[Callable[..., Coroutine[Any,Any,Any]]] = [] - self.logger.debug("MCPClient initialized") + for callback in self.callbacks: + try: + params = StdioServerParameters(command=command, args=args) - def add_callback(self, callback: callable) -> None: - """ - Add a callback function to the client. - - Args: - callback: A function that accepts (event_name, client, **kwargs) - """ - self.callbacks.append(callback) - - async def _run_callbacks(self, event_name: str, **kwargs) -> None: - """ - Run all registered callbacks for an event. - - Args: - event_name: The name of the event - **kwargs: Additional data for the event - """ + # 1) enter stdio_client context + self._stdio_ctx = stdio_client(params) + try: + self.stdio, self.sock_write = await self._stdio_ctx.__aenter__() + except Exception as e: + self.logger.error(f"Failed to start stdio_client: {e}") + raise + + # 2) enter ClientSession context + self._session_ctx = ClientSession(self.stdio, self.sock_write) + try: + self.session = await self._session_ctx.__aenter__() + await self.session.initialize() + except Exception as e: + self.logger.error(f"Failed to initialize MCP session: {e}") + # make sure we unwind the stdio context if session init fails + await self._stdio_ctx.__aexit__(None, None, None) + raise for callback in self.callbacks: try: - logger.debug(f"Running callback: {callback}") - if asyncio.iscoroutinefunction(callback): - logger.debug(f"Callback is a coroutine function") - await callback(event_name, self, **kwargs) - else: - # Check if the callback is a class with an async __call__ method - if hasattr(callback, '__call__') and asyncio.iscoroutinefunction(callback.__call__): - logger.debug(f"Callback is a class with an async __call__ method") + # 1) teardown session + if self._session_ctx is not None: + try: + await self._session_ctx.__aexit__(None, None, None) + except Exception as e: + self.logger.error(f"Error closing MCP session: {e}") + finally: + self.session = None + self._session_ctx = None + + # 2) teardown stdio + if self._stdio_ctx is not None: + try: + await self._stdio_ctx.__aexit__(None, None, None) + except Exception as e: + self.logger.error(f"Error closing stdio client: {e}") + finally: + self.stdio = None + self.sock_write = None + self._stdio_ctx = None await callback(event_name, self, **kwargs) else: logger.debug(f"Callback is a regular function")