Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 51 additions & 41 deletions tinyagent/mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,67 @@
++ b/tinyagent/mcp_client.py
import asyncio
import json
import logging
from typing import Optional, List, Coroutine, Any
from typing import Dict, List, Optional, Any, Tuple, Callable

# Keep your MCPClient implementation unchanged
import asyncio
from contextlib import AsyncExitStack

# MCP core imports
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Set up logging
logger = logging.getLogger(__name__)

class MCPClient:
def __init__(self, logger: Optional[logging.Logger] = None):

# We'll hold each context manager separately rather than in one stack:
self._stdio_ctx = None
self._session_ctx = None
self.stdio = None
self.sock_write = None
self.session = None
self.session = None
self.exit_stack = AsyncExitStack()
self.logger = logger or logging.getLogger(__name__)

# Simplified callback system
self.callbacks: List[callable] = []
self.callbacks: List[Callable[..., Coroutine[Any,Any,Any]]] = []

self.logger.debug("MCPClient initialized")
for callback in self.callbacks:
try:
params = StdioServerParameters(command=command, args=args)

def add_callback(self, callback: callable) -> None:
"""
Add a callback function to the client.

Args:
callback: A function that accepts (event_name, client, **kwargs)
"""
self.callbacks.append(callback)

async def _run_callbacks(self, event_name: str, **kwargs) -> None:
"""
Run all registered callbacks for an event.

Args:
event_name: The name of the event
**kwargs: Additional data for the event
"""
# 1) enter stdio_client context
self._stdio_ctx = stdio_client(params)
try:
self.stdio, self.sock_write = await self._stdio_ctx.__aenter__()
except Exception as e:
self.logger.error(f"Failed to start stdio_client: {e}")
raise

# 2) enter ClientSession context
self._session_ctx = ClientSession(self.stdio, self.sock_write)
try:
self.session = await self._session_ctx.__aenter__()
await self.session.initialize()
except Exception as e:
self.logger.error(f"Failed to initialize MCP session: {e}")
# make sure we unwind the stdio context if session init fails
await self._stdio_ctx.__aexit__(None, None, None)
raise
for callback in self.callbacks:
try:
logger.debug(f"Running callback: {callback}")
if asyncio.iscoroutinefunction(callback):
logger.debug(f"Callback is a coroutine function")
await callback(event_name, self, **kwargs)
else:
# Check if the callback is a class with an async __call__ method
if hasattr(callback, '__call__') and asyncio.iscoroutinefunction(callback.__call__):
logger.debug(f"Callback is a class with an async __call__ method")
# 1) teardown session
if self._session_ctx is not None:
try:
await self._session_ctx.__aexit__(None, None, None)
except Exception as e:
self.logger.error(f"Error closing MCP session: {e}")
finally:
self.session = None
self._session_ctx = None

# 2) teardown stdio
if self._stdio_ctx is not None:
try:
await self._stdio_ctx.__aexit__(None, None, None)
except Exception as e:
self.logger.error(f"Error closing stdio client: {e}")
finally:
self.stdio = None
self.sock_write = None
self._stdio_ctx = None
await callback(event_name, self, **kwargs)
else:
logger.debug(f"Callback is a regular function")
Expand Down