Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 100 additions & 122 deletions tinyagent/mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,162 +1,140 @@
import asyncio
import json
import logging
from typing import Dict, List, Optional, Any, Tuple, Callable
from typing import Dict, List, Optional, Any, Callable, Union

# Keep your MCPClient implementation unchanged
import asyncio
from contextlib import AsyncExitStack

# MCP core imports
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Set up logging
logger = logging.getLogger(__name__)

class MCPClient:
def __init__(self, logger: Optional[logging.Logger] = None):
self.session = None
self.exit_stack = AsyncExitStack()
"""
Manages multiple connections to MCP servers and aggregates tools.
"""
def __init__(
self,
server_parameters: Union[
StdioServerParameters,
Dict[str, Any],
List[Union[StdioServerParameters, Dict[str, Any]]]
],
logger: Optional[logging.Logger] = None,
):
self.logger = logger or logging.getLogger(__name__)

# Simplified callback system
self.callbacks: List[callable] = []

self.logger.debug("MCPClient initialized")
# Normalize to list
if not isinstance(server_parameters, list):
self.server_parameters = [server_parameters]
else:
self.server_parameters = server_parameters

self.exit_stack = AsyncExitStack()
self.sessions: List[ClientSession] = []
self.tools: List[Dict[str, Any]] = []
self._tool_session: Dict[str, ClientSession] = {}
self.callbacks: List[Callable] = []

def add_callback(self, callback: callable) -> None:
def add_callback(self, callback: Callable) -> None:
"""
Add a callback function to the client.

Args:
callback: A function that accepts (event_name, client, **kwargs)
Add a callback: async or sync func(event_name, client, **kwargs)
"""
self.callbacks.append(callback)

async def _run_callbacks(self, event_name: str, **kwargs) -> None:
"""
Run all registered callbacks for an event.

Args:
event_name: The name of the event
**kwargs: Additional data for the event
"""
for callback in self.callbacks:
try:
logger.debug(f"Running callback: {callback}")
self.logger.debug(f"Callback: {callback}, event: {event_name}")
if asyncio.iscoroutinefunction(callback):
logger.debug(f"Callback is a coroutine function")
await callback(event_name, self, **kwargs)
else:
# Check if the callback is a class with an async __call__ method
if hasattr(callback, '__call__') and asyncio.iscoroutinefunction(callback.__call__):
logger.debug(f"Callback is a class with an async __call__ method")
await callback(event_name, self, **kwargs)
else:
logger.debug(f"Callback is a regular function")
callback(event_name, self, **kwargs)
except Exception as e:
logger.error(f"Error in callback for {event_name}: {str(e)}")
self.logger.error(f"Error in callback {event_name}: {e}")

async def connect(self, command: str, args: list[str]):
async def connect(self) -> None:
"""
Launches the MCP server subprocess and initializes the client session.
:param command: e.g. "python" or "node"
:param args: list of args to pass, e.g. ["my_server.py"] or ["build/index.js"]
Connect to all MCP servers and gather tools.
"""
# Prepare stdio transport parameters
params = StdioServerParameters(command=command, args=args)
# Open the stdio client transport
self.stdio, self.sock_write = await self.exit_stack.enter_async_context(
stdio_client(params)
)
# Create and initialize the MCP client session
self.session = await self.exit_stack.enter_async_context(
ClientSession(self.stdio, self.sock_write)
)
await self.session.initialize()
for params in self.server_parameters:
# Build parameters object
if isinstance(params, dict):
params_obj = StdioServerParameters(**params)
elif isinstance(params, StdioServerParameters):
params_obj = params
else:
raise ValueError(f"Invalid server param type: {type(params)}")

# Enter stdio transport
stdio, sock_write = await self.exit_stack.enter_async_context(
stdio_client(params_obj)
)
# Enter client session
session = await self.exit_stack.enter_async_context(
ClientSession(stdio, sock_write)
)
await session.initialize()
self.sessions.append(session)

async def list_tools(self):
resp = await self.session.list_tools()
# Aggregate tools and map to sessions
for session in self.sessions:
resp = await session.list_tools()
for tool in resp.tools:
self.tools.append({
'name': tool.name,
'description': tool.description,
'schema': getattr(tool, 'schema', None),
})
self._tool_session[tool.name] = session

async def list_tools(self) -> None:
"""
Print all available tools.
"""
print("Available tools:")
for tool in resp.tools:
print(f" • {tool.name}: {tool.description}")
for t in self.tools:
print(f" • {t['name']}: {t['description']}")

async def call_tool(self, name: str, arguments: dict):
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any:
"""
Invokes a named tool and returns its raw content list.
Invoke named tool on its session.
"""
# Notify tool start
await self._run_callbacks("tool_start", tool_name=name, arguments=arguments)

if name not in self._tool_session:
raise ValueError(f"Tool '{name}' not found.")
session = self._tool_session[name]
try:
resp = await self.session.call_tool(name, arguments)

# Notify tool end
await self._run_callbacks("tool_end", tool_name=name, arguments=arguments,
result=resp.content, success=True)

resp = await session.call_tool(name, arguments)
await self._run_callbacks(
"tool_end",
tool_name=name,
arguments=arguments,
result=resp.content,
success=True,
)
return resp.content
except Exception as e:
# Notify tool end with error
await self._run_callbacks("tool_end", tool_name=name, arguments=arguments,
error=str(e), success=False)
await self._run_callbacks(
"tool_end",
tool_name=name,
arguments=arguments,
error=str(e),
success=False,
)
raise

async def close(self):
"""Clean up subprocess and streams."""
if self.exit_stack:
try:
await self.exit_stack.aclose()
except (RuntimeError, asyncio.CancelledError) as e:
# Log the error but don't re-raise it
self.logger.error(f"Error during client cleanup: {e}")
finally:
# Always reset these regardless of success or failure
self.session = None
self.exit_stack = AsyncExitStack()

async def run_example():
"""Example usage of MCPClient with proper logging."""
import sys
from tinyagent.hooks.logging_manager import LoggingManager

# Create and configure logging manager
log_manager = LoggingManager(default_level=logging.INFO)
log_manager.set_levels({
'tinyagent.mcp_client': logging.DEBUG, # Debug for this module
'tinyagent.tiny_agent': logging.INFO,
})

# Configure a console handler
console_handler = logging.StreamHandler(sys.stdout)
log_manager.configure_handler(
console_handler,
format_string='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.DEBUG
)

# Get module-specific logger
mcp_logger = log_manager.get_logger('tinyagent.mcp_client')

mcp_logger.debug("Starting MCPClient example")

# Create client with our logger
client = MCPClient(logger=mcp_logger)

try:
# Connect to a simple echo server
await client.connect("python", ["-m", "mcp.examples.echo_server"])

# List available tools
await client.list_tools()

# Call the echo tool
result = await client.call_tool("echo", {"message": "Hello, MCP!"})
mcp_logger.info(f"Echo result: {result}")

finally:
# Clean up
await client.close()
mcp_logger.debug("Example completed")
async def close(self) -> None:
"""
Cleanup all sessions and transports.
"""
try:
await self.exit_stack.aclose()
except Exception as e:
self.logger.error(f"Cleanup error: {e}")
finally:
self.sessions = []
self.tools = []
self._tool_session = {}
self.exit_stack = AsyncExitStack()