diff --git a/python/semantic_kernel/agents/open_ai/responses_agent_thread_actions.py b/python/semantic_kernel/agents/open_ai/responses_agent_thread_actions.py index d2375a9f3ce4..57783f675251 100644 --- a/python/semantic_kernel/agents/open_ai/responses_agent_thread_actions.py +++ b/python/semantic_kernel/agents/open_ai/responses_agent_thread_actions.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio +import json import logging import uuid from collections.abc import AsyncIterable, Awaitable, Callable, Sequence @@ -17,6 +18,7 @@ from openai.types.responses.response_function_call_arguments_delta_event import ResponseFunctionCallArgumentsDeltaEvent from openai.types.responses.response_input_text import ResponseInputText from openai.types.responses.response_item import ResponseItem +from openai.types.responses.response_mcp_call_arguments_done_event import ResponseMcpCallArgumentsDoneEvent from openai.types.responses.response_output_item import ResponseOutputItem from openai.types.responses.response_output_item_added_event import ResponseOutputItemAddedEvent from openai.types.responses.response_output_item_done_event import ResponseOutputItemDoneEvent @@ -233,13 +235,26 @@ async def invoke( ) yield False, reasoning_message + # Extract MCP tool call + result contents (provider auto-executed) + mcp_call_contents, mcp_result_contents = cls._get_mcp_contents_from_output( + response.output, base_metadata=metadata + ) + # Check if tool calls are required function_calls = cls._get_tool_calls_from_output(response.output) # type: ignore - if (fc_count := len(function_calls)) == 0: - yield True, cls._create_response_message_content(response, agent.ai_model_id, agent.name) # type: ignore - break + fc_count: int = len(function_calls) response_message = cls._create_response_message_content(response, agent.ai_model_id, agent.name) # type: ignore + # Append MCP call + result contents so user/history can see them, but they are NOT scheduled for invocation + if mcp_call_contents or mcp_result_contents: + response_message.items.extend(mcp_call_contents) + response_message.items.extend(mcp_result_contents) + + if fc_count == 0: + yield True, response_message + break + + # Yield response with function calls (not final yet) yield False, response_message # Update both histories so subsequent requests include tool call context chat_history.add_message(message=response_message) @@ -283,7 +298,14 @@ async def invoke( response_options=response_options, ) assert isinstance(response, Response) # nosec - yield True, cls._create_response_message_content(response, agent.ai_model_id, agent.name) + mcp_call_contents, mcp_result_contents = cls._get_mcp_contents_from_output( + response.output, base_metadata=metadata + ) + final_msg = cls._create_response_message_content(response, agent.ai_model_id, agent.name) + if mcp_call_contents or mcp_result_contents: + final_msg.items.extend(mcp_call_contents) + final_msg.items.extend(mcp_result_contents) + yield True, final_msg @classmethod async def invoke_stream( @@ -396,7 +418,9 @@ async def invoke_stream( ) all_messages: list[StreamingChatMessageContent] = [] - function_call_returned = False + function_call_returned: bool = False + # Track MCP tool call information by item_id + mcp_tool_calls: dict[str, dict[str, Any]] = {} async with response as response_stream: async for event in response_stream: @@ -409,10 +433,28 @@ async def invoke_stream( # Ensure subsequent requests link to this response context previous_response_id = event.response.id case ResponseOutputItemAddedEvent(): - function_calls = cls._get_tool_calls_from_output([event.item]) # type: ignore + # MCP tool call tracking + if cls._is_mcp_tool_call(getattr(event, "item", None)): + cls._register_mcp_tool_call( + mcp_tool_calls, + event.item, + output_index=getattr(event, "output_index", None), + ) + # Skip adding MCP calls (auto-executed by provider) + continue + + function_calls: list[FunctionCallContent] = cls._get_tool_calls_from_output([event.item]) # type: ignore if function_calls: function_call_returned = True - msg = cls._build_streaming_msg( + # Add event_type metadata to function calls + for func_call in function_calls: + func_call.metadata = cls._create_event_metadata( + metadata, + event, + sequence_number=getattr(event, "sequence_number", None), + output_index=getattr(event, "output_index", None), + ) + msg: StreamingChatMessageContent = cls._build_streaming_msg( agent=agent, metadata=metadata, event=event, @@ -423,8 +465,14 @@ async def invoke_stream( case ResponseFunctionCallArgumentsDeltaEvent(): function_call = FunctionCallContent( id=event.item_id, - index=getattr(event, "index", None), + index=None, arguments=event.delta, + metadata=cls._create_event_metadata( + metadata, + event, + sequence_number=getattr(event, "sequence_number", None), + output_index=getattr(event, "output_index", None), + ), ) msg = cls._build_streaming_msg( agent=agent, @@ -434,10 +482,33 @@ async def invoke_stream( choice_index=request_index, ) all_messages.append(msg) + case ResponseMcpCallArgumentsDoneEvent(): + if on_intermediate_message: + mcp_args_done = cls._build_mcp_arguments_done_content( + item_id=event.item_id, + arguments=event.arguments, + state=mcp_tool_calls, + event=event, + base_metadata=metadata, + ) + mcp_msg = cls._build_streaming_msg( + agent=agent, + metadata=metadata, + event=event, + items=[mcp_args_done], + choice_index=request_index, + ) + await on_intermediate_message(mcp_msg) case ResponseTextDeltaEvent(): text_content = StreamingTextContent( text=event.delta, choice_index=request_index, + metadata=cls._create_event_metadata( + metadata, + event, + sequence_number=getattr(event, "sequence_number", None), + output_index=getattr(event, "output_index", None), + ), ) msg = cls._build_streaming_msg( agent=agent, @@ -452,12 +523,14 @@ async def invoke_stream( reasoning_content = StreamingReasoningContent( text=event.delta, choice_index=request_index, - metadata={ - "item_id": event.item_id, - "output_index": event.output_index, - "sequence_number": event.sequence_number, - "content_index": event.content_index, - }, + metadata=cls._create_event_metadata( + metadata, + event, + item_id=getattr(event, "item_id", None), + output_index=getattr(event, "output_index", None), + sequence_number=getattr(event, "sequence_number", None), + content_index=getattr(event, "content_index", None), + ), ) reasoning_msg = cls._build_streaming_msg( agent=agent, @@ -471,12 +544,14 @@ async def invoke_stream( if on_intermediate_message: final_reasoning_content = ReasoningContent( text=event.text, - metadata={ - "item_id": event.item_id, - "output_index": event.output_index, - "sequence_number": event.sequence_number, - "content_index": event.content_index, - }, + metadata=cls._create_event_metadata( + metadata, + event, + item_id=getattr(event, "item_id", None), + output_index=getattr(event, "output_index", None), + sequence_number=getattr(event, "sequence_number", None), + content_index=getattr(event, "content_index", None), + ), ) reasoning_msg = cls._build_streaming_msg( agent=agent, @@ -491,13 +566,15 @@ async def invoke_stream( reasoning_content = StreamingReasoningContent( text=event.delta, choice_index=request_index, - metadata={ - "item_id": event.item_id, - "output_index": event.output_index, - "sequence_number": event.sequence_number, - "summary_index": event.summary_index, - "is_summary": True, - }, + metadata=cls._create_event_metadata( + metadata, + event, + item_id=getattr(event, "item_id", None), + output_index=getattr(event, "output_index", None), + sequence_number=getattr(event, "sequence_number", None), + summary_index=getattr(event, "summary_index", None), + is_summary=True, + ), ) reasoning_msg = cls._build_streaming_msg( agent=agent, @@ -511,13 +588,15 @@ async def invoke_stream( if on_intermediate_message: final_reasoning_summary_content = ReasoningContent( text=event.text, - metadata={ - "item_id": event.item_id, - "output_index": event.output_index, - "sequence_number": event.sequence_number, - "summary_index": event.summary_index, - "is_summary": True, - }, + metadata=cls._create_event_metadata( + metadata, + event, + item_id=getattr(event, "item_id", None), + output_index=getattr(event, "output_index", None), + sequence_number=getattr(event, "sequence_number", None), + summary_index=getattr(event, "summary_index", None), + is_summary=True, + ), ) reasoning_msg = cls._build_streaming_msg( agent=agent, @@ -528,6 +607,27 @@ async def invoke_stream( ) await on_intermediate_message(reasoning_msg) case ResponseOutputItemDoneEvent(): + if ( + on_intermediate_message + and cls._is_mcp_tool_call(getattr(event, "item", None)) + and hasattr(event.item, "output") + ): + mcp_result = cls._build_mcp_result_content( + item=event.item, + state=mcp_tool_calls, + event=event, + base_metadata=metadata, + ) + if mcp_result: + mcp_result_msg = cls._build_streaming_msg( + agent=agent, + metadata=metadata, + event=event, + items=[mcp_result], + choice_index=request_index, + ) + await on_intermediate_message(mcp_result_msg) + msg = cls._create_output_item_done(agent, event.item) # type: ignore if output_messages is not None: output_messages.append(msg) @@ -543,7 +643,6 @@ async def invoke_stream( full_completion: StreamingChatMessageContent = reduce(lambda x, y: x + y, all_messages) if output_messages is not None: - # Append the content with function call content to the msgs used for the callback output_messages.append(full_completion) function_calls = [item for item in full_completion.items if isinstance(item, FunctionCallContent)] chat_history.add_message(message=full_completion) @@ -553,9 +652,6 @@ async def invoke_stream( fc_count = len(function_calls) logger.info(f"processing {fc_count} tool calls in parallel.") - # This function either updates the chat history with the function call results - # or returns the context, with terminate set to True in which case the loop will - # break and the function calls are returned. results = await asyncio.gather( *[ kernel.invoke_function_call( @@ -572,9 +668,6 @@ async def invoke_stream( ], ) - # Merge and yield the function results, regardless of the termination status - # Include the ai_model_id so we can later add two streaming messages together - # Some settings may not have an ai_model_id, so we need to check for it function_result_messages = cls._merge_streaming_function_results( messages=override_history.messages[-len(results) :], # type: ignore name=agent.name, @@ -693,6 +786,20 @@ async def get_messages( if not responses.has_more: break + @classmethod + def _create_event_metadata( + cls: type[_T], + base_metadata: dict[str, str] | None, + event: ResponseStreamEvent, + **additional_fields: Any, + ) -> dict[str, Any]: + """Create metadata for streaming events with consistent structure.""" + return { + **(base_metadata or {}), + "event_type": event.type, + **additional_fields, + } + @classmethod def _build_streaming_msg( cls: type[_T], @@ -1216,4 +1323,165 @@ def _get_tools( return tools + @classmethod + def _is_mcp_tool_call(cls: type[_T], item: Any) -> bool: + """Return True if the given item represents an MCP tool call (type == 'mcp_call').""" + return bool(getattr(item, "type", None) == "mcp_call") + + @classmethod + def _register_mcp_tool_call( + cls: type[_T], + state: dict[str, dict[str, Any]], + item: Any, + *, + output_index: int | None, + ) -> None: + """Register basic MCP tool call metadata into the state.""" + item_id = getattr(item, "id", None) + if not item_id: + return + state[item_id] = { + "name": getattr(item, "name", "") or "", + "server_label": getattr(item, "server_label", "") or "", + "output_index": output_index, + } + + @classmethod + def _build_mcp_arguments_done_content( + cls: type[_T], + *, + item_id: str, + arguments: str | dict | None, + state: dict[str, dict[str, Any]], + event: ResponseStreamEvent, + base_metadata: dict[str, str] | None, + ) -> FunctionCallContent: + """Convert an MCP 'arguments done' event into a FunctionCallContent.""" + tool_meta = state.get(item_id, {}) + tool_name = tool_meta.get("name", "") + server_label = tool_meta.get("server_label", "") + return FunctionCallContent( + id=item_id, + index=getattr(event, "output_index", None), + arguments=arguments, + name=tool_name, + function_name=tool_name, + plugin_name=server_label, + metadata=cls._create_event_metadata( + base_metadata, + event, + sequence_number=getattr(event, "sequence_number", None), + output_index=getattr(event, "output_index", None), + ), + ) + + @classmethod + def _build_mcp_result_content( + cls: type[_T], + *, + item: Any, + state: dict[str, dict[str, Any]], + event: ResponseStreamEvent, + base_metadata: dict[str, str] | None, + ) -> FunctionResultContent | None: + """Create a FunctionResultContent from a completed MCP output item.""" + if not cls._is_mcp_tool_call(item): + return None + + tool_meta = state.get(getattr(item, "id", ""), {}) + tool_name = tool_meta.get("name", "") + server_label = tool_meta.get("server_label", "") + + raw_args = getattr(item, "arguments", None) + used_arguments: dict[str, Any] = {} + if raw_args: + if isinstance(raw_args, dict): + used_arguments = raw_args + else: + try: + used_arguments = json.loads(raw_args) + except (json.JSONDecodeError, TypeError): + used_arguments = {} + + metadata = cls._create_event_metadata( + base_metadata, + event, + used_arguments=used_arguments, + arguments=used_arguments, # 表示上 arguments も展開 + sequence_number=getattr(event, "sequence_number", None), + output_index=getattr(event, "output_index", None), + ) + return FunctionResultContent( + id=getattr(item, "id", None), + name=tool_name, + function_name=tool_name, + plugin_name=server_label, + result=getattr(item, "output", None), + metadata=metadata, + ) + + @classmethod + def _get_mcp_contents_from_output( + cls: type[_T], + output: list[Any], + *, + base_metadata: dict[str, str] | None = None, + ) -> tuple[list[FunctionCallContent], list[FunctionResultContent]]: + """Extract MCP call and result contents from a non-streaming response output.""" + call_contents: list[FunctionCallContent] = [] + result_contents: list[FunctionResultContent] = [] + for item in output or []: + if not cls._is_mcp_tool_call(item): + continue + item_id = getattr(item, "id", None) + tool_name = getattr(item, "name", "") or "" + server_label = getattr(item, "server_label", "") or "" + raw_args = getattr(item, "arguments", None) + parsed_args: dict[str, Any] | str | None = None + used_arguments: dict[str, Any] = {} + if raw_args: + if isinstance(raw_args, dict): + parsed_args = raw_args + used_arguments = raw_args + else: + try: + used_arguments = json.loads(raw_args) + parsed_args = used_arguments + except (json.JSONDecodeError, TypeError): + parsed_args = raw_args # leave as-is if not JSON + call_metadata: dict[str, Any] = { + **(base_metadata or {}), + "event_type": "mcp_call", + "arguments": parsed_args, + } + call_contents.append( + FunctionCallContent( + id=item_id, + index=getattr(item, "index", None), + arguments=parsed_args, + name=tool_name, + function_name=tool_name, + plugin_name=server_label, + metadata=call_metadata, + ) + ) + if hasattr(item, "output"): + result_metadata: dict[str, Any] = { + **(base_metadata or {}), + "event_type": "mcp_call_result", + "arguments": used_arguments, + "used_arguments": used_arguments, + } + result_contents.append( + FunctionResultContent( + id=item_id, + name=tool_name, + function_name=tool_name, + plugin_name=server_label, + result=getattr(item, "output", None), + metadata=result_metadata, + ) + ) + return call_contents, result_contents + # endregion