Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 75 additions & 44 deletions python/semantic_kernel/connectors/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,15 @@ async def __aexit__(

async def connect(self) -> None:
"""Connect to the MCP server."""
ready_event = asyncio.Event()
loop = asyncio.get_running_loop()
ready_future: asyncio.Future[None] = loop.create_future()

try:
self._current_task = asyncio.create_task(self._inner_connect(ready_event))
await ready_event.wait()
self._current_task = asyncio.create_task(self._inner_connect(ready_future))
Comment thread
aryanyk marked this conversation as resolved.
await ready_future
except KernelPluginInvalidConfigurationError:
ready_event.clear()
raise
except Exception as ex:
ready_event.clear()
await self.close()
raise FunctionExecutionException("Failed to enter context manager.") from ex

Expand All @@ -293,16 +293,24 @@ async def close(self) -> None:
self._current_task = None
self.session = None

async def _inner_connect(self, ready_event: asyncio.Event) -> None:
async def _inner_connect(self, ready_future: asyncio.Future) -> None:
if not self.session:
try:
transport = await self._exit_stack.enter_async_context(self.get_mcp_client())
except Exception as ex:
await self._exit_stack.aclose()
ready_event.set()
raise KernelPluginInvalidConfigurationError(
"Failed to connect to the MCP server. Please check your configuration."
) from ex
if not ready_future.done():
exc = KernelPluginInvalidConfigurationError(
"Failed to connect to the MCP server. Please check your configuration."
)
Comment thread
aryanyk marked this conversation as resolved.
exc.__cause__ = ex
exc.__suppress_context__ = True
ready_future.set_exception(exc)
try:
await self._exit_stack.aclose()
except Exception:
logger.warning("Failed to close exit stack during error handling")
return

try:
Comment thread
aryanyk marked this conversation as resolved.
session = await self._exit_stack.enter_async_context(
ClientSession(
Expand All @@ -315,36 +323,60 @@ async def _inner_connect(self, ready_event: asyncio.Event) -> None:
)
)
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to create a session. Please check your configuration."
) from ex
if not ready_future.done():
exc = KernelPluginInvalidConfigurationError(
"Failed to connect to the MCP server. Please check your configuration."
)
exc.__cause__ = ex
exc.__suppress_context__ = True
ready_future.set_exception(exc)
try:
await self._exit_stack.aclose()
except Exception:
logger.warning("Failed to close exit stack during error handling")
return
try:
await session.initialize()
except Exception as ex:
await self._exit_stack.aclose()
raise KernelPluginInvalidConfigurationError(
"Failed to initialize session. Please check your configuration."
) from ex
if not ready_future.done():
exc = KernelPluginInvalidConfigurationError(
"Failed to initialize session. Please check your configuration."
)
exc.__cause__ = ex
exc.__suppress_context__ = True
ready_future.set_exception(exc)
try:
await self._exit_stack.aclose()
except Exception:
logger.warning("Failed to close exit stack during error handling")
return
self.session = session
elif self.session._request_id == 0:
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
logger.debug("Connected to MCP server: %s", self.session)
if self.load_tools_flag:
await self.load_tools()
if self.load_prompts_flag:
await self.load_prompts()

if logger.level != logging.NOTSET:
try:
if self.load_tools_flag:
await self.load_tools()
if self.load_prompts_flag:
await self.load_prompts()
if logger.level != logging.NOTSET:
try:
await self.session.set_logging_level(
next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
)
except Exception:
logger.warning("Failed to set log level to %s", logger.level)
if not ready_future.done():
ready_future.set_result(None)
except Exception as ex:
if not ready_future.done():
ready_future.set_exception(ex)
try:
await self.session.set_logging_level(
next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
)
await self._exit_stack.aclose()
except Exception:
logger.warning("Failed to set log level to %s", logger.level)
# Setting up is complete, will now signal the main loop that we are ready
ready_event.set()
logger.warning("Failed to close exit stack during error handling")
return
# Create a stop event to signal the exit stack to close
self._stop_event = asyncio.Event()
await self._stop_event.wait()
Expand Down Expand Up @@ -460,17 +492,20 @@ async def message_handler(
if isinstance(message, types.ServerNotification):
match message.root.method:
case "notifications/tools/list_changed":
await self.load_tools()
try:
await self.load_tools()
except Exception as ex:
logger.warning("Failed to reload tools on notification: %s", ex)
case "notifications/prompts/list_changed":
await self.load_prompts()
try:
await self.load_prompts()
except Exception as ex:
logger.warning("Failed to reload prompts on notification: %s", ex)

async def load_prompts(self):
"""Load prompts from the MCP server."""
try:
prompt_list = await self.session.list_prompts()
except Exception:
prompt_list = None
for prompt in prompt_list.prompts if prompt_list else []:
prompt_list = await self.session.list_prompts()
for prompt in prompt_list.prompts:
local_name = _normalize_mcp_name(prompt.name)
func = kernel_function(name=local_name, description=prompt.description)(
partial(self.get_prompt, prompt.name)
Expand All @@ -480,12 +515,8 @@ async def load_prompts(self):

async def load_tools(self):
"""Load tools from the MCP server."""
try:
tool_list = await self.session.list_tools()
except Exception:
tool_list = None
# Create methods with the kernel_function decorator for each tool
for tool in tool_list.tools if tool_list else []:
tool_list = await self.session.list_tools()
for tool in tool_list.tools:
local_name = _normalize_mcp_name(tool.name)
func = kernel_function(name=local_name, description=tool.description)(partial(self.call_tool, tool.name))
func.__kernel_function_parameters__ = _get_parameter_dicts_from_mcp_tool(tool)
Expand Down
Loading