Skip to content

Commit d47910b

Browse files
Python: Improve function call invocation parameter consistency
Thread function_behavior parameter through all internal invoke_function_call callsites for consistent behavior. Add defensive logging when parameter is not provided. Update and add unit tests for parameter handling. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 2fb749e commit d47910b

8 files changed

Lines changed: 191 additions & 3 deletions

File tree

python/semantic_kernel/agents/azure_ai/agent_thread_actions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from azure.ai.projects.aio import AIProjectClient
8181

8282
from semantic_kernel.agents.azure_ai.azure_ai_agent import AzureAIAgent
83+
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
8384
from semantic_kernel.contents.chat_history import ChatHistory
8485
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
8586
from semantic_kernel.filters.auto_function_invocation.auto_function_invocation_context import (
@@ -1071,6 +1072,7 @@ async def _invoke_function_calls(
10711072
fccs: list["FunctionCallContent"],
10721073
chat_history: "ChatHistory",
10731074
arguments: KernelArguments,
1075+
function_behavior: "FunctionChoiceBehavior | None" = None,
10741076
) -> list["AutoFunctionInvocationContext | None"]:
10751077
"""Invoke the function calls."""
10761078
return await asyncio.gather(
@@ -1079,6 +1081,7 @@ async def _invoke_function_calls(
10791081
function_call=function_call,
10801082
chat_history=chat_history,
10811083
arguments=arguments,
1084+
function_behavior=function_behavior,
10821085
)
10831086
for function_call in fccs
10841087
],

python/semantic_kernel/agents/bedrock/bedrock_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,7 @@ async def _handle_function_call_contents(
681681
chat_history=chat_history,
682682
arguments=self.arguments,
683683
function_call_count=len(function_call_contents),
684+
function_behavior=self.function_choice_behavior,
684685
)
685686
for function_call in function_call_contents
686687
],

python/semantic_kernel/agents/open_ai/assistant_thread_actions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from openai.types.beta.threads.run_create_params import AdditionalMessageAttachmentTool, TruncationStrategy
5252

5353
from semantic_kernel.agents.open_ai.openai_assistant_agent import OpenAIAssistantAgent
54+
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
5455
from semantic_kernel.contents.chat_history import ChatHistory
5556
from semantic_kernel.contents.chat_message_content import ChatMessageContent
5657
from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -658,6 +659,7 @@ async def _invoke_function_calls(
658659
fccs: list["FunctionCallContent"],
659660
chat_history: "ChatHistory",
660661
arguments: KernelArguments,
662+
function_behavior: "FunctionChoiceBehavior | None" = None,
661663
) -> list["AutoFunctionInvocationContext | None"]:
662664
"""Invoke the function calls."""
663665
return await asyncio.gather(
@@ -666,6 +668,7 @@ async def _invoke_function_calls(
666668
function_call=function_call,
667669
chat_history=chat_history,
668670
arguments=arguments,
671+
function_behavior=function_behavior,
669672
)
670673
for function_call in fccs
671674
],

python/semantic_kernel/agents/open_ai/responses_agent_thread_actions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,11 +1111,17 @@ def _collect_text_and_annotations(cls: type[_T], content_list: list[Any]) -> lis
11111111

11121112
@classmethod
11131113
async def _invoke_function_calls(
1114-
cls: type[_T], kernel: "Kernel", fccs: list["FunctionCallContent"], chat_history: "ChatHistory"
1114+
cls: type[_T],
1115+
kernel: "Kernel",
1116+
fccs: list["FunctionCallContent"],
1117+
chat_history: "ChatHistory",
1118+
function_behavior: "FunctionChoiceBehavior | None" = None,
11151119
) -> list[Any]:
11161120
"""Invoke the function calls."""
11171121
tasks = [
1118-
kernel.invoke_function_call(function_call=function_call, chat_history=chat_history)
1122+
kernel.invoke_function_call(
1123+
function_call=function_call, chat_history=chat_history, function_behavior=function_behavior
1124+
)
11191125
for function_call in fccs
11201126
]
11211127
return await asyncio.gather(*tasks)

python/semantic_kernel/connectors/ai/open_ai/services/_open_ai_realtime.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,8 @@ async def _parse_function_call_arguments_done(
481481

482482
# Step 4: Invoke the function call
483483
chat_history = ChatHistory()
484-
await self._kernel.invoke_function_call(item, chat_history)
484+
function_behavior = self._current_settings.function_choice_behavior if self._current_settings else None
485+
await self._kernel.invoke_function_call(item, chat_history, function_behavior=function_behavior)
485486
created_output: FunctionResultContent = chat_history.messages[-1].items[0] # type: ignore
486487
# Step 5: Create the function result event
487488
result = RealtimeFunctionResultEvent(

python/semantic_kernel/kernel.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,13 @@ async def invoke_function_call(
347347
raise FunctionExecutionException(
348348
f"Only functions: {allowed_functions} are allowed, {function_call.name} is not allowed."
349349
)
350+
elif function_behavior is None:
351+
logger.warning(
352+
"invoke_function_call called without function_behavior. "
353+
"No allowlist validation will be performed for function '%s'. "
354+
"Pass a FunctionChoiceBehavior with filters to enable validation.",
355+
function_call.name,
356+
)
350357
function_to_call = self.get_function(function_call.plugin_name, function_call.function_name)
351358
except Exception as exc:
352359
logger.exception(f"The function `{function_call.name}` is not part of the provided functions: {exc}.")

python/tests/unit/connectors/ai/open_ai/services/test_openai_realtime.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,86 @@ async def test_parse_function_call_arguments_done_fail(OpenAIWebsocket, kernel):
556556
iter += 1
557557

558558

559+
async def test_parse_function_call_arguments_done_passes_function_behavior(OpenAIWebsocket, kernel):
560+
"""Verify that the realtime path passes function_choice_behavior to invoke_function_call."""
561+
func_result = "result"
562+
event = ResponseFunctionCallArgumentsDoneEvent(
563+
call_id="call_id",
564+
arguments='{"x": "' + func_result + '"}',
565+
event_id="event_id",
566+
output_index=0,
567+
item_id="item_id",
568+
name="plugin_name-function_name",
569+
response_id="response_id",
570+
type="response.function_call_arguments.done",
571+
)
572+
function_behavior = FunctionChoiceBehavior.Auto(filters={"included_plugins": ["plugin_name"]})
573+
OpenAIWebsocket._current_settings = OpenAIRealtimeExecutionSettings(
574+
instructions="instructions", ai_model_id="gpt-realtime"
575+
)
576+
OpenAIWebsocket._current_settings.function_choice_behavior = function_behavior
577+
OpenAIWebsocket._call_id_to_function_map["call_id"] = "plugin_name-function_name"
578+
func = kernel_function(name="function_name", description="function_description")(lambda x: x)
579+
kernel.add_function(plugin_name="plugin_name", function_name="function_name", function=func)
580+
OpenAIWebsocket._kernel = kernel
581+
582+
# Capture the kwargs passed to invoke_function_call
583+
captured_kwargs = {}
584+
original_invoke = Kernel.invoke_function_call
585+
586+
async def spy_invoke(self, *args, **kwargs):
587+
captured_kwargs.update(kwargs)
588+
return await original_invoke(self, *args, **kwargs)
589+
590+
with (
591+
patch.object(Kernel, "invoke_function_call", spy_invoke),
592+
patch.object(OpenAIWebsocket, "_send"),
593+
):
594+
async for _ in OpenAIWebsocket._parse_function_call_arguments_done(event):
595+
pass
596+
597+
assert "function_behavior" in captured_kwargs
598+
assert captured_kwargs["function_behavior"] is function_behavior
599+
600+
601+
async def test_parse_function_call_arguments_done_filters_block_unallowed(OpenAIWebsocket, kernel):
602+
"""Verify that the realtime path blocks a function not in the allowlist."""
603+
event = ResponseFunctionCallArgumentsDoneEvent(
604+
call_id="call_id",
605+
arguments='{"url": "http://169.254.169.254/"}',
606+
event_id="event_id",
607+
output_index=0,
608+
item_id="item_id",
609+
name="HttpPlugin-GetAsync",
610+
response_id="response_id",
611+
type="response.function_call_arguments.done",
612+
)
613+
function_behavior = FunctionChoiceBehavior.Auto(filters={"included_plugins": ["SafePlugin"]})
614+
OpenAIWebsocket._current_settings = OpenAIRealtimeExecutionSettings(
615+
instructions="instructions", ai_model_id="gpt-realtime"
616+
)
617+
OpenAIWebsocket._current_settings.function_choice_behavior = function_behavior
618+
OpenAIWebsocket._call_id_to_function_map["call_id"] = "HttpPlugin-GetAsync"
619+
620+
# Register both plugins on kernel
621+
safe_func = kernel_function(name="safe_function", description="safe")(lambda: "safe")
622+
http_func = kernel_function(name="GetAsync", description="http get")(lambda url: url)
623+
kernel.add_function(plugin_name="SafePlugin", function_name="safe_function", function=safe_func)
624+
kernel.add_function(plugin_name="HttpPlugin", function_name="GetAsync", function=http_func)
625+
OpenAIWebsocket._kernel = kernel
626+
627+
events_received = []
628+
with patch.object(OpenAIWebsocket, "_send"):
629+
async for evt in OpenAIWebsocket._parse_function_call_arguments_done(event):
630+
events_received.append(evt)
631+
632+
# The function call event is yielded, then the result should contain the error
633+
assert len(events_received) >= 2
634+
result_event = events_received[-1]
635+
assert isinstance(result_event, RealtimeFunctionResultEvent)
636+
assert "not part of the provided" in str(result_event.function_result.result)
637+
638+
559639
async def test_send_audio(OpenAIWebsocket):
560640
audio_event = RealtimeAudioEvent(
561641
audio=AudioContent(data=b"audio data", mime_type="audio/wav"),

python/tests/unit/kernel/test_kernel.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,93 @@ async def test_invoke_function_call_with_missing_or_unexpected_args(kernel: Kern
626626
), "Expected fallback message not found in chat history."
627627

628628

629+
async def test_invoke_function_call_with_filters_blocks_unallowed_function(kernel: Kernel):
630+
"""Verify that when function_behavior has filters, an unallowed function is blocked."""
631+
tool_call_mock = MagicMock(spec=FunctionCallContent)
632+
tool_call_mock.name = "HttpPlugin-GetAsync"
633+
tool_call_mock.function_name = "GetAsync"
634+
tool_call_mock.plugin_name = "HttpPlugin"
635+
tool_call_mock.arguments = {"url": "http://169.254.169.254/"}
636+
tool_call_mock.ai_model_id = None
637+
tool_call_mock.metadata = {}
638+
tool_call_mock.index = 0
639+
tool_call_mock.id = "test_id"
640+
641+
chat_history = ChatHistory()
642+
643+
safe_func_meta = KernelFunctionMetadata(name="safe_function", is_prompt=False, plugin_name="SafePlugin")
644+
function_behavior = FunctionChoiceBehavior.Auto(filters={"included_plugins": ["SafePlugin"]})
645+
646+
with patch("semantic_kernel.kernel.Kernel.get_list_of_function_metadata", return_value=[safe_func_meta]):
647+
await kernel.invoke_function_call(
648+
function_call=tool_call_mock,
649+
chat_history=chat_history,
650+
function_behavior=function_behavior,
651+
)
652+
653+
# The function should have been blocked — an error message should be in chat history
654+
assert len(chat_history.messages) == 1
655+
assert "not allowed" in str(chat_history.messages[0].items[0].result) or "not part of the provided" in str(
656+
chat_history.messages[0].items[0].result
657+
)
658+
659+
660+
async def test_invoke_function_call_with_filters_allows_matching_function(kernel: Kernel, get_tool_call_mock):
661+
"""Verify that when function_behavior has filters, an allowed function proceeds (not blocked)."""
662+
tool_call_mock = get_tool_call_mock
663+
chat_history_mock = MagicMock(spec=ChatHistory)
664+
665+
func_meta = KernelFunctionMetadata(
666+
name="function", is_prompt=False, plugin_name="test", fully_qualified_name="test-function"
667+
)
668+
669+
function_behavior = FunctionChoiceBehavior.Auto(filters={"included_plugins": ["test"]})
670+
671+
with (
672+
patch("semantic_kernel.kernel.logger", autospec=True) as logger_mock,
673+
patch("semantic_kernel.kernel.Kernel.get_list_of_function_metadata", return_value=[func_meta]),
674+
patch("semantic_kernel.kernel.Kernel.get_function", return_value=MagicMock()),
675+
):
676+
await kernel.invoke_function_call(
677+
function_call=tool_call_mock,
678+
chat_history=chat_history_mock,
679+
function_behavior=function_behavior,
680+
)
681+
682+
# The warning for missing function_behavior should NOT have been logged
683+
logger_mock.warning.assert_not_called()
684+
# The exception logger should NOT have been called (function was allowed)
685+
logger_mock.exception.assert_not_called()
686+
687+
688+
async def test_invoke_function_call_without_function_behavior_logs_warning(kernel: Kernel, get_tool_call_mock):
689+
"""Verify that calling invoke_function_call without function_behavior logs a warning."""
690+
tool_call_mock = get_tool_call_mock
691+
chat_history_mock = MagicMock(spec=ChatHistory)
692+
693+
func_mock = AsyncMock(spec=KernelFunction)
694+
func_meta = KernelFunctionMetadata(name="function", is_prompt=False)
695+
func_mock.metadata = func_meta
696+
func_mock.name = "function"
697+
func_mock.parameters = []
698+
func_result = FunctionResult(value="Function result", function=func_meta)
699+
func_mock.invoke = MagicMock(return_value=func_result)
700+
701+
with (
702+
patch("semantic_kernel.kernel.logger", autospec=True) as logger_mock,
703+
patch("semantic_kernel.kernel.Kernel.get_function", return_value=func_mock),
704+
):
705+
await kernel.invoke_function_call(
706+
function_call=tool_call_mock,
707+
chat_history=chat_history_mock,
708+
# function_behavior intentionally omitted
709+
)
710+
711+
logger_mock.warning.assert_called_once()
712+
warning_msg = logger_mock.warning.call_args[0][0]
713+
assert "without function_behavior" in warning_msg
714+
715+
629716
# endregion
630717
# region Plugins
631718

0 commit comments

Comments
 (0)