diff --git a/tinyagent/mcp_client.py b/tinyagent/mcp_client.py index eb1919f..bf01f40 100644 --- a/tinyagent/mcp_client.py +++ b/tinyagent/mcp_client.py @@ -1,162 +1,140 @@ import asyncio -import json import logging -from typing import Dict, List, Optional, Any, Tuple, Callable +from typing import Dict, List, Optional, Any, Callable, Union -# Keep your MCPClient implementation unchanged -import asyncio from contextlib import AsyncExitStack - -# MCP core imports from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -# Set up logging logger = logging.getLogger(__name__) class MCPClient: - def __init__(self, logger: Optional[logging.Logger] = None): - self.session = None - self.exit_stack = AsyncExitStack() + """ + Manages multiple connections to MCP servers and aggregates tools. + """ + def __init__( + self, + server_parameters: Union[ + StdioServerParameters, + Dict[str, Any], + List[Union[StdioServerParameters, Dict[str, Any]]] + ], + logger: Optional[logging.Logger] = None, + ): self.logger = logger or logging.getLogger(__name__) - - # Simplified callback system - self.callbacks: List[callable] = [] - - self.logger.debug("MCPClient initialized") + # Normalize to list + if not isinstance(server_parameters, list): + self.server_parameters = [server_parameters] + else: + self.server_parameters = server_parameters + + self.exit_stack = AsyncExitStack() + self.sessions: List[ClientSession] = [] + self.tools: List[Dict[str, Any]] = [] + self._tool_session: Dict[str, ClientSession] = {} + self.callbacks: List[Callable] = [] - def add_callback(self, callback: callable) -> None: + def add_callback(self, callback: Callable) -> None: """ - Add a callback function to the client. - - Args: - callback: A function that accepts (event_name, client, **kwargs) + Add a callback: async or sync func(event_name, client, **kwargs) """ self.callbacks.append(callback) - + async def _run_callbacks(self, event_name: str, **kwargs) -> None: - """ - Run all registered callbacks for an event. - - Args: - event_name: The name of the event - **kwargs: Additional data for the event - """ for callback in self.callbacks: try: - logger.debug(f"Running callback: {callback}") + self.logger.debug(f"Callback: {callback}, event: {event_name}") if asyncio.iscoroutinefunction(callback): - logger.debug(f"Callback is a coroutine function") await callback(event_name, self, **kwargs) else: - # Check if the callback is a class with an async __call__ method if hasattr(callback, '__call__') and asyncio.iscoroutinefunction(callback.__call__): - logger.debug(f"Callback is a class with an async __call__ method") await callback(event_name, self, **kwargs) else: - logger.debug(f"Callback is a regular function") callback(event_name, self, **kwargs) except Exception as e: - logger.error(f"Error in callback for {event_name}: {str(e)}") + self.logger.error(f"Error in callback {event_name}: {e}") - async def connect(self, command: str, args: list[str]): + async def connect(self) -> None: """ - Launches the MCP server subprocess and initializes the client session. - :param command: e.g. "python" or "node" - :param args: list of args to pass, e.g. ["my_server.py"] or ["build/index.js"] + Connect to all MCP servers and gather tools. """ - # Prepare stdio transport parameters - params = StdioServerParameters(command=command, args=args) - # Open the stdio client transport - self.stdio, self.sock_write = await self.exit_stack.enter_async_context( - stdio_client(params) - ) - # Create and initialize the MCP client session - self.session = await self.exit_stack.enter_async_context( - ClientSession(self.stdio, self.sock_write) - ) - await self.session.initialize() + for params in self.server_parameters: + # Build parameters object + if isinstance(params, dict): + params_obj = StdioServerParameters(**params) + elif isinstance(params, StdioServerParameters): + params_obj = params + else: + raise ValueError(f"Invalid server param type: {type(params)}") + + # Enter stdio transport + stdio, sock_write = await self.exit_stack.enter_async_context( + stdio_client(params_obj) + ) + # Enter client session + session = await self.exit_stack.enter_async_context( + ClientSession(stdio, sock_write) + ) + await session.initialize() + self.sessions.append(session) - async def list_tools(self): - resp = await self.session.list_tools() + # Aggregate tools and map to sessions + for session in self.sessions: + resp = await session.list_tools() + for tool in resp.tools: + self.tools.append({ + 'name': tool.name, + 'description': tool.description, + 'schema': getattr(tool, 'schema', None), + }) + self._tool_session[tool.name] = session + + async def list_tools(self) -> None: + """ + Print all available tools. + """ print("Available tools:") - for tool in resp.tools: - print(f" • {tool.name}: {tool.description}") + for t in self.tools: + print(f" • {t['name']}: {t['description']}") - async def call_tool(self, name: str, arguments: dict): + async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: """ - Invokes a named tool and returns its raw content list. + Invoke named tool on its session. """ - # Notify tool start await self._run_callbacks("tool_start", tool_name=name, arguments=arguments) - + if name not in self._tool_session: + raise ValueError(f"Tool '{name}' not found.") + session = self._tool_session[name] try: - resp = await self.session.call_tool(name, arguments) - - # Notify tool end - await self._run_callbacks("tool_end", tool_name=name, arguments=arguments, - result=resp.content, success=True) - + resp = await session.call_tool(name, arguments) + await self._run_callbacks( + "tool_end", + tool_name=name, + arguments=arguments, + result=resp.content, + success=True, + ) return resp.content except Exception as e: - # Notify tool end with error - await self._run_callbacks("tool_end", tool_name=name, arguments=arguments, - error=str(e), success=False) + await self._run_callbacks( + "tool_end", + tool_name=name, + arguments=arguments, + error=str(e), + success=False, + ) raise - async def close(self): - """Clean up subprocess and streams.""" - if self.exit_stack: - try: - await self.exit_stack.aclose() - except (RuntimeError, asyncio.CancelledError) as e: - # Log the error but don't re-raise it - self.logger.error(f"Error during client cleanup: {e}") - finally: - # Always reset these regardless of success or failure - self.session = None - self.exit_stack = AsyncExitStack() - -async def run_example(): - """Example usage of MCPClient with proper logging.""" - import sys - from tinyagent.hooks.logging_manager import LoggingManager - - # Create and configure logging manager - log_manager = LoggingManager(default_level=logging.INFO) - log_manager.set_levels({ - 'tinyagent.mcp_client': logging.DEBUG, # Debug for this module - 'tinyagent.tiny_agent': logging.INFO, - }) - - # Configure a console handler - console_handler = logging.StreamHandler(sys.stdout) - log_manager.configure_handler( - console_handler, - format_string='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - level=logging.DEBUG - ) - - # Get module-specific logger - mcp_logger = log_manager.get_logger('tinyagent.mcp_client') - - mcp_logger.debug("Starting MCPClient example") - - # Create client with our logger - client = MCPClient(logger=mcp_logger) - - try: - # Connect to a simple echo server - await client.connect("python", ["-m", "mcp.examples.echo_server"]) - - # List available tools - await client.list_tools() - - # Call the echo tool - result = await client.call_tool("echo", {"message": "Hello, MCP!"}) - mcp_logger.info(f"Echo result: {result}") - - finally: - # Clean up - await client.close() - mcp_logger.debug("Example completed") + async def close(self) -> None: + """ + Cleanup all sessions and transports. + """ + try: + await self.exit_stack.aclose() + except Exception as e: + self.logger.error(f"Cleanup error: {e}") + finally: + self.sessions = [] + self.tools = [] + self._tool_session = {} + self.exit_stack = AsyncExitStack()