Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 173 additions & 66 deletions tinyagent/mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,95 @@
import asyncio
import json
import logging
from typing import Dict, List, Optional, Any, Tuple, Callable
#!/usr/bin/env python
# coding=utf-8

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Keep your MCPClient implementation unchanged
from __future__ import annotations

import warnings
import asyncio
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, List, Dict
from contextlib import AsyncExitStack

# MCP core imports
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

__all__ = ["MCPClient"]

if TYPE_CHECKING:
from mcp import StdioServerParameters

# Set up logging
logger = logging.getLogger(__name__)

class MCPClient:
def __init__(self, logger: Optional[logging.Logger] = None):
"""Manages the connection to an MCP server using per-instance context manager pattern.

This implementation adopts the per-instance context manager pattern from smolagents
to fix cross-talk and cancel-scope errors when multiple clients are connected concurrently.

Note: tools can only be accessed after the connection has been started with the
`connect()` method. If you don't use the context manager we strongly encourage
to use "try ... finally" to ensure the connection is cleaned up.

Args:
server_parameters (StdioServerParameters | dict[str, Any] | None):
Configuration parameters to connect to the MCP server.

- An instance of `mcp.StdioServerParameters` for connecting a Stdio MCP server
via standard input/output using a subprocess.

- A `dict` with command and args for stdio connection.

logger (Optional[logging.Logger]): Custom logger instance.

Example:
```python
# fully managed context manager
async with MCPClient() as client:
await client.connect("python", ["-m", "mcp.examples.echo_server"])
tools = await client.list_tools()

# manually manage the connection:
try:
mcp_client = MCPClient()
await mcp_client.connect("python", ["-m", "mcp.examples.echo_server"])
tools = await mcp_client.list_tools()

# use your tools here.
finally:
await mcp_client.disconnect()
```
"""

def __init__(
self,
server_parameters: "StdioServerParameters" | dict[str, Any] | None = None,
logger: Optional[logging.Logger] = None,
):
self.server_parameters = server_parameters
self.session = None
self.exit_stack = AsyncExitStack()
self.logger = logger or logging.getLogger(__name__)

# Simplified callback system
self.callbacks: List[callable] = []

self.logger.debug("MCPClient initialized")
self.logger.debug("MCPClient initialized with per-instance context manager")

def add_callback(self, callback: callable) -> None:
"""
Expand All @@ -44,49 +110,96 @@ async def _run_callbacks(self, event_name: str, **kwargs) -> None:
"""
for callback in self.callbacks:
try:
logger.debug(f"Running callback: {callback}")
self.logger.debug(f"Running callback: {callback}")
if asyncio.iscoroutinefunction(callback):
logger.debug(f"Callback is a coroutine function")
self.logger.debug(f"Callback is a coroutine function")
await callback(event_name, self, **kwargs)
else:
# Check if the callback is a class with an async __call__ method
if hasattr(callback, '__call__') and asyncio.iscoroutinefunction(callback.__call__):
logger.debug(f"Callback is a class with an async __call__ method")
self.logger.debug(f"Callback is a class with an async __call__ method")
await callback(event_name, self, **kwargs)
else:
logger.debug(f"Callback is a regular function")
self.logger.debug(f"Callback is a regular function")
callback(event_name, self, **kwargs)
except Exception as e:
logger.error(f"Error in callback for {event_name}: {str(e)}")
self.logger.error(f"Error in callback for {event_name}: {str(e)}")

async def connect(self, command: str, args: list[str]):
"""
Launches the MCP server subprocess and initializes the client session.
:param command: e.g. "python" or "node"
:param args: list of args to pass, e.g. ["my_server.py"] or ["build/index.js"]
"""
# Prepare stdio transport parameters
params = StdioServerParameters(command=command, args=args)
# Open the stdio client transport
self.stdio, self.sock_write = await self.exit_stack.enter_async_context(
stdio_client(params)
)
# Create and initialize the MCP client session
self.session = await self.exit_stack.enter_async_context(
ClientSession(self.stdio, self.sock_write)
)
await self.session.initialize()
async def connect(self, command: str = None, args: list[str] = None):
"""Connect to the MCP server and initialize the session."""
if command and args:
# Legacy support for direct command/args
params = StdioServerParameters(command=command, args=args)
elif self.server_parameters:
if isinstance(self.server_parameters, dict):
# Convert dict to StdioServerParameters
params = StdioServerParameters(
command=self.server_parameters.get('command'),
args=self.server_parameters.get('args', [])
)
else:
params = self.server_parameters
else:
raise ValueError("Either command/args or server_parameters must be provided")

try:
# Open the stdio client transport using per-instance exit stack
self.stdio, self.sock_write = await self.exit_stack.enter_async_context(
stdio_client(params)
)
# Create and initialize the MCP client session
self.session = await self.exit_stack.enter_async_context(
ClientSession(self.stdio, self.sock_write)
)
await self.session.initialize()
self.logger.debug("MCP client connected successfully")
except Exception as e:
self.logger.error(f"Failed to connect MCP client: {e}")
# Clean up on connection failure
await self._cleanup_exit_stack()
raise

async def disconnect(
self,
exc_type: type[BaseException] | None = None,
exc_value: BaseException | None = None,
exc_traceback: TracebackType | None = None,
):
"""Disconnect from the MCP server"""
await self._cleanup_exit_stack()

async def _cleanup_exit_stack(self):
"""Clean up the exit stack safely"""
if self.exit_stack:
try:
await self.exit_stack.aclose()
self.logger.debug("Exit stack closed successfully")
except Exception as e:
# Log the error but don't re-raise it to prevent cascade failures
self.logger.error(f"Error during exit stack cleanup: {e}")
finally:
# Always reset these regardless of success or failure
self.session = None
self.exit_stack = AsyncExitStack()

async def list_tools(self):
"""List available tools from the MCP server."""
if not self.session:
raise ValueError("Client not connected. Call connect() first.")

resp = await self.session.list_tools()
print("Available tools:")
self.logger.info("Available tools:")
for tool in resp.tools:
print(f" • {tool.name}: {tool.description}")
self.logger.info(f" {tool.name}: {tool.description}")
return resp.tools

async def call_tool(self, name: str, arguments: dict):
"""
Invokes a named tool and returns its raw content list.
"""
if not self.session:
raise ValueError("Client not connected. Call connect() first.")

# Notify tool start
await self._run_callbacks("tool_start", tool_name=name, arguments=arguments)

Expand All @@ -105,58 +218,52 @@ async def call_tool(self, name: str, arguments: dict):
raise

async def close(self):
"""Clean up subprocess and streams."""
if self.exit_stack:
try:
await self.exit_stack.aclose()
except (RuntimeError, asyncio.CancelledError) as e:
# Log the error but don't re-raise it
self.logger.error(f"Error during client cleanup: {e}")
finally:
# Always reset these regardless of success or failure
self.session = None
self.exit_stack = AsyncExitStack()
"""Clean up subprocess and streams. Alias for disconnect()."""
await self.disconnect()

async def __aenter__(self):
"""Connect to the MCP server and return the client directly."""
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
):
"""Disconnect from the MCP server."""
await self.disconnect(exc_type, exc_value, exc_traceback)

async def run_example():
"""Example usage of MCPClient with proper logging."""
import sys
from tinyagent.hooks.logging_manager import LoggingManager

# Create and configure logging manager
log_manager = LoggingManager(default_level=logging.INFO)
log_manager.set_levels({
'tinyagent.mcp_client': logging.DEBUG, # Debug for this module
'tinyagent.tiny_agent': logging.INFO,
})

# Configure a console handler
console_handler = logging.StreamHandler(sys.stdout)
log_manager.configure_handler(
console_handler,
format_string='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.DEBUG
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)

# Get module-specific logger
mcp_logger = log_manager.get_logger('tinyagent.mcp_client')
mcp_logger = logging.getLogger('tinyagent.mcp_client')
mcp_logger.setLevel(logging.DEBUG)

mcp_logger.debug("Starting MCPClient example")
mcp_logger.debug("Starting MCPClient example with per-instance context manager")

# Create client with our logger
client = MCPClient(logger=mcp_logger)

try:
async with MCPClient(logger=mcp_logger) as client:
# Connect to a simple echo server
await client.connect("python", ["-m", "mcp.examples.echo_server"])

# List available tools
await client.list_tools()
tools = await client.list_tools()

# Call the echo tool
result = await client.call_tool("echo", {"message": "Hello, MCP!"})
mcp_logger.info(f"Echo result: {result}")

finally:
# Clean up
await client.close()
mcp_logger.debug("Example completed")
mcp_logger.debug("Example completed")

if __name__ == "__main__":
asyncio.run(run_example())