diff --git a/python/semantic_kernel/connectors/mcp.py b/python/semantic_kernel/connectors/mcp.py index 5f31886b28fa..77249fb0c85c 100644 --- a/python/semantic_kernel/connectors/mcp.py +++ b/python/semantic_kernel/connectors/mcp.py @@ -269,15 +269,15 @@ async def __aexit__( async def connect(self) -> None: """Connect to the MCP server.""" - ready_event = asyncio.Event() + loop = asyncio.get_running_loop() + ready_future: asyncio.Future[None] = loop.create_future() + try: - self._current_task = asyncio.create_task(self._inner_connect(ready_event)) - await ready_event.wait() + self._current_task = asyncio.create_task(self._inner_connect(ready_future)) + await ready_future except KernelPluginInvalidConfigurationError: - ready_event.clear() raise except Exception as ex: - ready_event.clear() await self.close() raise FunctionExecutionException("Failed to enter context manager.") from ex @@ -293,16 +293,24 @@ async def close(self) -> None: self._current_task = None self.session = None - async def _inner_connect(self, ready_event: asyncio.Event) -> None: + async def _inner_connect(self, ready_future: asyncio.Future) -> None: if not self.session: try: transport = await self._exit_stack.enter_async_context(self.get_mcp_client()) except Exception as ex: - await self._exit_stack.aclose() - ready_event.set() - raise KernelPluginInvalidConfigurationError( - "Failed to connect to the MCP server. Please check your configuration." - ) from ex + if not ready_future.done(): + exc = KernelPluginInvalidConfigurationError( + "Failed to connect to the MCP server. Please check your configuration." + ) + exc.__cause__ = ex + exc.__suppress_context__ = True + ready_future.set_exception(exc) + try: + await self._exit_stack.aclose() + except Exception: + logger.warning("Failed to close exit stack during error handling") + return + try: session = await self._exit_stack.enter_async_context( ClientSession( @@ -315,36 +323,60 @@ async def _inner_connect(self, ready_event: asyncio.Event) -> None: ) ) except Exception as ex: - await self._exit_stack.aclose() - raise KernelPluginInvalidConfigurationError( - "Failed to create a session. Please check your configuration." - ) from ex + if not ready_future.done(): + exc = KernelPluginInvalidConfigurationError( + "Failed to connect to the MCP server. Please check your configuration." + ) + exc.__cause__ = ex + exc.__suppress_context__ = True + ready_future.set_exception(exc) + try: + await self._exit_stack.aclose() + except Exception: + logger.warning("Failed to close exit stack during error handling") + return try: await session.initialize() except Exception as ex: - await self._exit_stack.aclose() - raise KernelPluginInvalidConfigurationError( - "Failed to initialize session. Please check your configuration." - ) from ex + if not ready_future.done(): + exc = KernelPluginInvalidConfigurationError( + "Failed to initialize session. Please check your configuration." + ) + exc.__cause__ = ex + exc.__suppress_context__ = True + ready_future.set_exception(exc) + try: + await self._exit_stack.aclose() + except Exception: + logger.warning("Failed to close exit stack during error handling") + return self.session = session elif self.session._request_id == 0: # If the session is not initialized, we need to reinitialize it await self.session.initialize() logger.debug("Connected to MCP server: %s", self.session) - if self.load_tools_flag: - await self.load_tools() - if self.load_prompts_flag: - await self.load_prompts() - - if logger.level != logging.NOTSET: + try: + if self.load_tools_flag: + await self.load_tools() + if self.load_prompts_flag: + await self.load_prompts() + if logger.level != logging.NOTSET: + try: + await self.session.set_logging_level( + next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level) + ) + except Exception: + logger.warning("Failed to set log level to %s", logger.level) + if not ready_future.done(): + ready_future.set_result(None) + except Exception as ex: + if not ready_future.done(): + ready_future.set_exception(ex) try: - await self.session.set_logging_level( - next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level) - ) + await self._exit_stack.aclose() except Exception: - logger.warning("Failed to set log level to %s", logger.level) - # Setting up is complete, will now signal the main loop that we are ready - ready_event.set() + logger.warning("Failed to close exit stack during error handling") + return # Create a stop event to signal the exit stack to close self._stop_event = asyncio.Event() await self._stop_event.wait() @@ -460,17 +492,20 @@ async def message_handler( if isinstance(message, types.ServerNotification): match message.root.method: case "notifications/tools/list_changed": - await self.load_tools() + try: + await self.load_tools() + except Exception as ex: + logger.warning("Failed to reload tools on notification: %s", ex) case "notifications/prompts/list_changed": - await self.load_prompts() + try: + await self.load_prompts() + except Exception as ex: + logger.warning("Failed to reload prompts on notification: %s", ex) async def load_prompts(self): """Load prompts from the MCP server.""" - try: - prompt_list = await self.session.list_prompts() - except Exception: - prompt_list = None - for prompt in prompt_list.prompts if prompt_list else []: + prompt_list = await self.session.list_prompts() + for prompt in prompt_list.prompts: local_name = _normalize_mcp_name(prompt.name) func = kernel_function(name=local_name, description=prompt.description)( partial(self.get_prompt, prompt.name) @@ -480,12 +515,8 @@ async def load_prompts(self): async def load_tools(self): """Load tools from the MCP server.""" - try: - tool_list = await self.session.list_tools() - except Exception: - tool_list = None - # Create methods with the kernel_function decorator for each tool - for tool in tool_list.tools if tool_list else []: + tool_list = await self.session.list_tools() + for tool in tool_list.tools: local_name = _normalize_mcp_name(tool.name) func = kernel_function(name=local_name, description=tool.description)(partial(self.call_tool, tool.name)) func.__kernel_function_parameters__ = _get_parameter_dicts_from_mcp_tool(tool)