Skip to content

Commit b52c904

Browse files
committed
feat: Add reasoning content for openai provider
1 parent 8ffe24b commit b52c904

File tree

3 files changed

+35
-10
lines changed

3 files changed

+35
-10
lines changed

src/strands/models/openai.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
9999
if choice.delta.content:
100100
yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
101101

102+
if choice.delta.reasoning_content:
103+
yield {
104+
"chunk_type": "content_delta",
105+
"data_type": "reasoning_content",
106+
"data": choice.delta.reasoning_content,
107+
}
108+
102109
for tool_call in choice.delta.tool_calls or []:
103110
tool_calls.setdefault(tool_call.index, []).append(tool_call)
104111

src/strands/types/models/openai.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
232232
"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}}
233233
}
234234

235+
if event["data_type"] == "reasoning_content":
236+
return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}}
237+
235238
return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
236239

237240
case "content_stop":

tests/strands/models/test_openai.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,31 +63,46 @@ def test_stream(openai_client, model):
6363
mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
6464
mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
6565
mock_delta_1 = unittest.mock.Mock(
66-
content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1]
66+
reasoning_content="<think>",
67+
content=None,
68+
tool_calls=None,
69+
)
70+
mock_delta_2 = unittest.mock.Mock(
71+
reasoning_content="\nOkey, the user just</think>",
72+
content=None,
73+
tool_calls=None,
74+
)
75+
mock_delta_3 = unittest.mock.Mock(
76+
content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1], reasoning_content=None
6777
)
6878

6979
mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
7080
mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
71-
mock_delta_2 = unittest.mock.Mock(
72-
content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2]
81+
mock_delta_4 = unittest.mock.Mock(
82+
content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2], reasoning_content=None
7383
)
7484

75-
mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None)
85+
mock_delta_5 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None)
7686

7787
mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
7888
mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
79-
mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)])
80-
mock_event_4 = unittest.mock.Mock()
89+
mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_3)])
90+
mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_4)])
91+
mock_event_5 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_5)])
92+
mock_event_6 = unittest.mock.Mock()
8193

82-
openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
94+
openai_client.chat.completions.create.return_value = iter(
95+
[mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5, mock_event_6]
96+
)
8397

8498
request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
8599
response = model.stream(request)
86-
87100
tru_events = list(response)
88101
exp_events = [
89102
{"chunk_type": "message_start"},
90103
{"chunk_type": "content_start", "data_type": "text"},
104+
{"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "<think>"},
105+
{"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "\nOkey, the user just</think>"},
91106
{"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
92107
{"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
93108
{"chunk_type": "content_stop", "data_type": "text"},
@@ -100,15 +115,15 @@ def test_stream(openai_client, model):
100115
{"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2},
101116
{"chunk_type": "content_stop", "data_type": "tool"},
102117
{"chunk_type": "message_stop", "data": "tool_calls"},
103-
{"chunk_type": "metadata", "data": mock_event_4.usage},
118+
{"chunk_type": "metadata", "data": mock_event_6.usage},
104119
]
105120

106121
assert tru_events == exp_events
107122
openai_client.chat.completions.create.assert_called_once_with(**request)
108123

109124

110125
def test_stream_empty(openai_client, model):
111-
mock_delta = unittest.mock.Mock(content=None, tool_calls=None)
126+
mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None)
112127
mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
113128

114129
mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])

0 commit comments

Comments
 (0)