Skip to content

Commit 9fc4f47

Browse files
authored
fix: fix streaming of function/tool call arguments (#239)
1 parent daf20ad commit 9fc4f47

File tree

7 files changed

+204
-27
lines changed

7 files changed

+204
-27
lines changed

aidial_sdk/chat_completion/chunks.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Dict, List, Optional, TypedDict
2+
from typing import Any, Dict, List, Literal, Optional, TypedDict
33

44
from aidial_sdk.chat_completion.enums import FinishReason, Status
55
from aidial_sdk.exceptions import HTTPException as DIALException
@@ -100,6 +100,7 @@ class FunctionToolCallChunk(BaseChunk):
100100
choice_index: int
101101
call_index: int
102102
id: Optional[str]
103+
type: Optional[Literal["function"]]
103104
name: Optional[str]
104105
arguments: Optional[str]
105106

@@ -108,12 +109,14 @@ def __init__(
108109
choice_index: int,
109110
call_index: int,
110111
id: Optional[str],
112+
type: Optional[Literal["function"]],
111113
name: Optional[str],
112114
arguments: Optional[str],
113115
):
114116
self.choice_index = choice_index
115117
self.call_index = call_index
116118
self.id = id
119+
self.type = type
117120
self.name = name
118121
self.arguments = arguments
119122

@@ -122,14 +125,15 @@ def to_dict(self):
122125
"choices": [
123126
{
124127
"index": self.choice_index,
128+
"finish_reason": None,
125129
"delta": {
126130
"content": None,
127131
"tool_calls": [
128132
remove_nones(
129133
{
130134
"index": self.call_index,
131135
"id": self.id,
132-
"type": "function",
136+
"type": self.type,
133137
"function": remove_nones(
134138
{
135139
"name": self.name,
@@ -166,6 +170,7 @@ def to_dict(self):
166170
"choices": [
167171
{
168172
"index": self.choice_index,
173+
"finish_reason": None,
169174
"delta": {
170175
"content": None,
171176
"function_call": remove_nones(

aidial_sdk/chat_completion/function_call.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@ def __init__(self, choice: ChoiceBase):
1515
def create_and_send(
1616
cls, choice: ChoiceBase, name: str, arguments: Optional[str]
1717
) -> "FunctionCall":
18-
return cls(choice)._send_function_call(name=name, arguments=arguments)
18+
return cls(choice)._send_function_call(
19+
create=True, name=name, arguments=arguments
20+
)
1921

2022
def append_arguments(self, arguments: str) -> "FunctionCall":
21-
return self._send_function_call(name=None, arguments=arguments)
23+
return self._send_function_call(
24+
create=False, name=None, arguments=arguments
25+
)
2226

2327
def _send_function_call(
24-
self, name: Optional[str], arguments: Optional[str]
28+
self, *, create: bool, name: Optional[str], arguments: Optional[str]
2529
) -> "FunctionCall":
2630
if not self._choice.opened:
2731
raise runtime_error(
@@ -31,7 +35,7 @@ def _send_function_call(
3135
raise runtime_error(
3236
"Trying to add function call to a closed choice"
3337
)
34-
if self._choice.has_function_call:
38+
if create and self._choice.has_function_call:
3539
raise runtime_error(
3640
"Trying to add function call to a choice which already has a function call"
3741
)

aidial_sdk/chat_completion/function_tool_call.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Literal, Optional
22

33
from aidial_sdk.chat_completion.choice_base import ChoiceBase
44
from aidial_sdk.chat_completion.chunks import FunctionToolCallChunk
@@ -23,14 +23,21 @@ def create_and_send(
2323
arguments: Optional[str],
2424
) -> "FunctionToolCall":
2525
return cls(choice, index)._send_tool_call(
26-
id=id, name=name, arguments=arguments
26+
id=id, type="function", name=name, arguments=arguments
2727
)
2828

2929
def append_arguments(self, arguments: str) -> "FunctionToolCall":
30-
return self._send_tool_call(id=None, name=None, arguments=arguments)
30+
return self._send_tool_call(
31+
id=None, type=None, name=None, arguments=arguments
32+
)
3133

3234
def _send_tool_call(
33-
self, id: Optional[str], name: Optional[str], arguments: Optional[str]
35+
self,
36+
*,
37+
id: Optional[str],
38+
type: Optional[Literal["function"]],
39+
name: Optional[str],
40+
arguments: Optional[str]
3441
) -> "FunctionToolCall":
3542
if not self._choice.opened:
3643
raise runtime_error("Trying to add tool call to an unopened choice")
@@ -41,6 +48,7 @@ def _send_tool_call(
4148
FunctionToolCallChunk(
4249
self._choice.index,
4350
self._index,
51+
type=type,
4452
id=id,
4553
name=name,
4654
arguments=arguments,

tests/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import pytest
2+
3+
pytest.register_assert_rewrite("tests.utils.chunks")

tests/test_function_calling.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
2+
from tests.utils.chunks import (
3+
check_sse_stream,
4+
create_function_call_chunk,
5+
create_single_choice_chunk,
6+
)
7+
from tests.utils.client import create_app_client
8+
9+
10+
class FunctionCaller(ChatCompletion):
11+
async def chat_completion(
12+
self, request: Request, response: Response
13+
) -> None:
14+
response.set_response_id("test_id")
15+
response.set_created(0)
16+
17+
with response.create_single_choice() as choice:
18+
choice.append_content("Test content")
19+
20+
function_call = choice.create_function_call("function_name")
21+
function_call.append_arguments('{"key')
22+
function_call.append_arguments('":"')
23+
function_call.append_arguments('val"}')
24+
25+
26+
def test_function_call_non_streaming():
27+
response = create_app_client(FunctionCaller()).post(
28+
"chat/completions", json={"messages": [], "stream": False}
29+
)
30+
31+
body = response.json()
32+
assert body["choices"][0]["message"]["function_call"] == {
33+
"name": "function_name",
34+
"arguments": '{"key":"val"}',
35+
}
36+
37+
38+
def test_function_call_streaming():
39+
response = create_app_client(FunctionCaller()).post(
40+
"chat/completions", json={"messages": [], "stream": True}
41+
)
42+
43+
check_sse_stream(
44+
response.iter_lines(),
45+
[
46+
create_single_choice_chunk({"role": "assistant"}),
47+
create_single_choice_chunk({"content": "Test content"}),
48+
create_function_call_chunk(name="function_name"),
49+
create_function_call_chunk(arguments='{"key'),
50+
create_function_call_chunk(arguments='":"'),
51+
create_function_call_chunk(arguments='val"}'),
52+
create_single_choice_chunk({}, "function_call"),
53+
],
54+
)

tests/test_tool_calling.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
2+
from tests.utils.chunks import (
3+
check_sse_stream,
4+
create_single_choice_chunk,
5+
create_tool_call_chunk,
6+
)
7+
from tests.utils.client import create_app_client
8+
9+
10+
class ToolCaller(ChatCompletion):
11+
async def chat_completion(
12+
self, request: Request, response: Response
13+
) -> None:
14+
response.set_response_id("test_id")
15+
response.set_created(0)
16+
17+
with response.create_single_choice() as choice:
18+
choice.append_content("Test content")
19+
20+
tool_call1 = choice.create_function_tool_call(
21+
"tool_call_id1", "tool_name"
22+
)
23+
tool_call1.append_arguments('{"key')
24+
tool_call1.append_arguments('":"')
25+
tool_call1.append_arguments('val"}')
26+
27+
choice.create_function_tool_call(
28+
"tool_call_id2", "tool_name", '{"foo":"bar"}'
29+
)
30+
31+
32+
def test_tool_call_non_streaming():
33+
response = create_app_client(ToolCaller()).post(
34+
"chat/completions", json={"messages": [], "stream": False}
35+
)
36+
37+
body = response.json()
38+
assert body["choices"][0]["message"]["tool_calls"] == [
39+
{
40+
"id": "tool_call_id1",
41+
"type": "function",
42+
"function": {
43+
"name": "tool_name",
44+
"arguments": '{"key":"val"}',
45+
},
46+
},
47+
{
48+
"id": "tool_call_id2",
49+
"type": "function",
50+
"function": {
51+
"name": "tool_name",
52+
"arguments": '{"foo":"bar"}',
53+
},
54+
},
55+
]
56+
57+
58+
def test_tool_call_streaming():
59+
response = create_app_client(ToolCaller()).post(
60+
"chat/completions", json={"messages": [], "stream": True}
61+
)
62+
63+
check_sse_stream(
64+
response.iter_lines(),
65+
[
66+
create_single_choice_chunk({"role": "assistant"}),
67+
create_single_choice_chunk({"content": "Test content"}),
68+
create_tool_call_chunk(
69+
0, type="function", id="tool_call_id1", name="tool_name"
70+
),
71+
create_tool_call_chunk(0, arguments='{"key'),
72+
create_tool_call_chunk(0, arguments='":"'),
73+
create_tool_call_chunk(0, arguments='val"}'),
74+
create_tool_call_chunk(
75+
1,
76+
type="function",
77+
id="tool_call_id2",
78+
name="tool_name",
79+
arguments='{"foo":"bar"}',
80+
),
81+
create_single_choice_chunk({}, "tool_calls"),
82+
],
83+
)

tests/utils/chunks.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,29 @@
22
import json
33
from typing import Iterable, Literal, Optional, Union
44

5+
from aidial_sdk.utils.json import remove_nones
6+
57

68
def create_chunk(
79
*,
810
choice_idx: int = 0,
911
delta: dict = {},
1012
finish_reason: Optional[str] = None,
13+
**kwargs,
1114
):
1215
return {
13-
"id": "chatcmpl-AQws8iVykPBIQJfnmCQnMEkTLLUUA",
16+
"id": "test_id",
1417
"object": "chat.completion.chunk",
15-
"created": 1730986196,
16-
"model": "gpt-4o-2024-05-13",
17-
"system_fingerprint": "fp_67802d9a6d",
18+
"created": 0,
1819
"choices": [
1920
{
2021
"index": choice_idx,
2122
"delta": delta,
2223
"finish_reason": finish_reason,
2324
}
2425
],
26+
"usage": None,
27+
**kwargs,
2528
}
2629

2730

@@ -54,23 +57,41 @@ def create_tool_call_chunk(
5457
):
5558
return create_chunk(
5659
delta={
60+
"content": None,
5761
"tool_calls": [
58-
{
59-
"index": idx,
60-
"id": id,
61-
"type": type,
62-
"function": {"name": name, "arguments": arguments},
63-
}
64-
]
62+
remove_nones(
63+
{
64+
"index": idx,
65+
"id": id,
66+
"type": type,
67+
"function": remove_nones(
68+
{"name": name, "arguments": arguments}
69+
),
70+
}
71+
)
72+
],
73+
}
74+
)
75+
76+
77+
def create_function_call_chunk(
78+
*,
79+
name: Optional[str] = None,
80+
arguments: Optional[str] = None,
81+
):
82+
return create_chunk(
83+
delta={
84+
"content": None,
85+
"function_call": remove_nones(
86+
{"name": name, "arguments": arguments}
87+
),
6588
}
6689
)
6790

6891

6992
def _check_sse_line(actual: str, expected: Union[str, dict]):
7093
if isinstance(expected, str):
71-
assert (
72-
actual == expected
73-
), f"actual line != expected line: {actual!r} != {expected!r}"
94+
assert actual == expected
7495
return
7596

7697
assert actual.startswith("data: "), f"Invalid data SSE entry: {actual!r}"
@@ -80,9 +101,8 @@ def _check_sse_line(actual: str, expected: Union[str, dict]):
80101
actual_dict = json.loads(actual)
81102
except json.JSONDecodeError:
82103
raise AssertionError(f"Invalid JSON in data SSE entry: {actual!r}")
83-
assert (
84-
actual_dict == expected
85-
), f"actual json != expected json: {actual_dict!r} != {expected!r}"
104+
105+
assert actual_dict == expected
86106

87107

88108
ExpectedSSEStream = Iterable[Union[str, dict]]

0 commit comments

Comments
 (0)