Skip to content

Commit 6e6064f

Browse files
authored
iterative structured output (#291)
1 parent 7cfbad3 commit 6e6064f

File tree

15 files changed

+245
-58
lines changed

15 files changed

+245
-58
lines changed

src/strands/agent/agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import random
1717
from concurrent.futures import ThreadPoolExecutor
1818
from threading import Thread
19-
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
19+
from typing import Any, AsyncIterator, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union, cast
2020
from uuid import uuid4
2121

2222
from opentelemetry import trace
@@ -423,7 +423,12 @@ def structured_output(self, output_model: Type[T], prompt: Optional[str] = None)
423423
messages.append({"role": "user", "content": [{"text": prompt}]})
424424

425425
# get the structured output from the model
426-
return self.model.structured_output(output_model, messages, self.callback_handler)
426+
events = self.model.structured_output(output_model, messages)
427+
for event in events:
428+
if "callback" in event:
429+
self.callback_handler(**cast(dict, event["callback"]))
430+
431+
return event["output"]
427432

428433
async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
429434
"""Process a natural language prompt and yield events as an async iterator.

src/strands/models/anthropic.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
import json
88
import logging
99
import mimetypes
10-
from typing import Any, Callable, Iterable, Optional, Type, TypedDict, TypeVar, cast
10+
from typing import Any, Generator, Iterable, Optional, Type, TypedDict, TypeVar, Union, cast
1111

1212
import anthropic
1313
from pydantic import BaseModel
1414
from typing_extensions import Required, Unpack, override
1515

1616
from ..event_loop.streaming import process_stream
17-
from ..handlers.callback_handler import PrintingCallbackHandler
1817
from ..tools import convert_pydantic_to_tool_spec
1918
from ..types.content import ContentBlock, Messages
2019
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
@@ -378,24 +377,24 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
378377

379378
@override
380379
def structured_output(
381-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
382-
) -> T:
380+
self, output_model: Type[T], prompt: Messages
381+
) -> Generator[dict[str, Union[T, Any]], None, None]:
383382
"""Get structured output from the model.
384383
385384
Args:
386385
output_model(Type[BaseModel]): The output model to use for the agent.
387386
prompt(Messages): The prompt messages to use for the agent.
388-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
387+
388+
Yields:
389+
Model events with the last being the structured output.
389390
"""
390-
callback_handler = callback_handler or PrintingCallbackHandler()
391391
tool_spec = convert_pydantic_to_tool_spec(output_model)
392392

393393
response = self.converse(messages=prompt, tool_specs=[tool_spec])
394394
for event in process_stream(response, prompt):
395-
if "callback" in event:
396-
callback_handler(**event["callback"])
397-
else:
398-
stop_reason, messages, _, _ = event["stop"]
395+
yield event
396+
397+
stop_reason, messages, _, _ = event["stop"]
399398

400399
if stop_reason != "tool_use":
401400
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
@@ -413,4 +412,4 @@ def structured_output(
413412
if output_response is None:
414413
raise ValueError("No valid tool use or tool use input was found in the Anthropic response.")
415414

416-
return output_model(**output_response)
415+
yield {"output": output_model(**output_response)}

src/strands/models/bedrock.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
import logging
88
import os
9-
from typing import Any, Callable, Iterable, List, Literal, Optional, Type, TypeVar, cast
9+
from typing import Any, Generator, Iterable, List, Literal, Optional, Type, TypeVar, Union, cast
1010

1111
import boto3
1212
from botocore.config import Config as BotocoreConfig
@@ -15,7 +15,6 @@
1515
from typing_extensions import TypedDict, Unpack, override
1616

1717
from ..event_loop.streaming import process_stream
18-
from ..handlers.callback_handler import PrintingCallbackHandler
1918
from ..tools import convert_pydantic_to_tool_spec
2019
from ..types.content import Messages
2120
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
@@ -521,24 +520,24 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
521520

522521
@override
523522
def structured_output(
524-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
525-
) -> T:
523+
self, output_model: Type[T], prompt: Messages
524+
) -> Generator[dict[str, Union[T, Any]], None, None]:
526525
"""Get structured output from the model.
527526
528527
Args:
529528
output_model(Type[BaseModel]): The output model to use for the agent.
530529
prompt(Messages): The prompt messages to use for the agent.
531-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
530+
531+
Yields:
532+
Model events with the last being the structured output.
532533
"""
533-
callback_handler = callback_handler or PrintingCallbackHandler()
534534
tool_spec = convert_pydantic_to_tool_spec(output_model)
535535

536536
response = self.converse(messages=prompt, tool_specs=[tool_spec])
537537
for event in process_stream(response, prompt):
538-
if "callback" in event:
539-
callback_handler(**event["callback"])
540-
else:
541-
stop_reason, messages, _, _ = event["stop"]
538+
yield event
539+
540+
stop_reason, messages, _, _ = event["stop"]
542541

543542
if stop_reason != "tool_use":
544543
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
@@ -556,4 +555,4 @@ def structured_output(
556555
if output_response is None:
557556
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
558557

559-
return output_model(**output_response)
558+
yield {"output": output_model(**output_response)}

src/strands/models/litellm.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import json
77
import logging
8-
from typing import Any, Callable, Optional, Type, TypedDict, TypeVar, cast
8+
from typing import Any, Generator, Optional, Type, TypedDict, TypeVar, Union, cast
99

1010
import litellm
1111
from litellm.utils import supports_response_schema
@@ -105,15 +105,16 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
105105

106106
@override
107107
def structured_output(
108-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
109-
) -> T:
108+
self, output_model: Type[T], prompt: Messages
109+
) -> Generator[dict[str, Union[T, Any]], None, None]:
110110
"""Get structured output from the model.
111111
112112
Args:
113113
output_model(Type[BaseModel]): The output model to use for the agent.
114114
prompt(Messages): The prompt messages to use for the agent.
115-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
116115
116+
Yields:
117+
Model events with the last being the structured output.
117118
"""
118119
# The LiteLLM `Client` inits with Chat().
119120
# Chat() inits with self.completions
@@ -136,7 +137,8 @@ def structured_output(
136137
# Parse the tool call content as JSON
137138
tool_call_data = json.loads(choice.message.content)
138139
# Instantiate the output model with the parsed data
139-
return output_model(**tool_call_data)
140+
yield {"output": output_model(**tool_call_data)}
141+
return
140142
except (json.JSONDecodeError, TypeError, ValueError) as e:
141143
raise ValueError(f"Failed to parse or load content into model: {e}") from e
142144

src/strands/models/llamaapi.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
import logging
1010
import mimetypes
11-
from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
11+
from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast
1212

1313
import llama_api_client
1414
from llama_api_client import LlamaAPIClient
@@ -390,14 +390,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
390390

391391
@override
392392
def structured_output(
393-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
394-
) -> T:
393+
self, output_model: Type[T], prompt: Messages
394+
) -> Generator[dict[str, Union[T, Any]], None, None]:
395395
"""Get structured output from the model.
396396
397397
Args:
398398
output_model(Type[BaseModel]): The output model to use for the agent.
399399
prompt(Messages): The prompt messages to use for the agent.
400-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
400+
401+
Yields:
402+
Model events with the last being the structured output.
401403
402404
Raises:
403405
NotImplementedError: Structured output is not currently supported for LlamaAPI models.

src/strands/models/ollama.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import json
77
import logging
8-
from typing import Any, Callable, Iterable, Optional, Type, TypeVar, cast
8+
from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union, cast
99

1010
from ollama import Client as OllamaClient
1111
from pydantic import BaseModel
@@ -316,14 +316,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
316316

317317
@override
318318
def structured_output(
319-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
320-
) -> T:
319+
self, output_model: Type[T], prompt: Messages
320+
) -> Generator[dict[str, Union[T, Any]], None, None]:
321321
"""Get structured output from the model.
322322
323323
Args:
324324
output_model(Type[BaseModel]): The output model to use for the agent.
325325
prompt(Messages): The prompt messages to use for the agent.
326-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
326+
327+
Yields:
328+
Model events with the last being the structured output.
327329
"""
328330
formatted_request = self.format_request(messages=prompt)
329331
formatted_request["format"] = output_model.model_json_schema()
@@ -332,6 +334,6 @@ def structured_output(
332334

333335
try:
334336
content = response.message.content.strip()
335-
return output_model.model_validate_json(content)
337+
yield {"output": output_model.model_validate_json(content)}
336338
except Exception as e:
337339
raise ValueError(f"Failed to parse or load content into model: {e}") from e

src/strands/models/openai.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import logging
7-
from typing import Any, Callable, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, cast
7+
from typing import Any, Generator, Iterable, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
88

99
import openai
1010
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -133,14 +133,16 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
133133

134134
@override
135135
def structured_output(
136-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
137-
) -> T:
136+
self, output_model: Type[T], prompt: Messages
137+
) -> Generator[dict[str, Union[T, Any]], None, None]:
138138
"""Get structured output from the model.
139139
140140
Args:
141141
output_model(Type[BaseModel]): The output model to use for the agent.
142142
prompt(Messages): The prompt messages to use for the agent.
143-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
143+
144+
Yields:
145+
Model events with the last being the structured output.
144146
"""
145147
response: ParsedChatCompletion = self.client.beta.chat.completions.parse( # type: ignore
146148
model=self.get_config()["model_id"],
@@ -159,6 +161,6 @@ def structured_output(
159161
break
160162

161163
if parsed:
162-
return parsed
164+
yield {"output": parsed}
163165
else:
164166
raise ValueError("No valid tool use or tool use input was found in the OpenAI response.")

src/strands/types/models/model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import abc
44
import logging
5-
from typing import Any, Callable, Iterable, Optional, Type, TypeVar
5+
from typing import Any, Generator, Iterable, Optional, Type, TypeVar, Union
66

77
from pydantic import BaseModel
88

@@ -45,17 +45,16 @@ def get_config(self) -> Any:
4545
@abc.abstractmethod
4646
# pragma: no cover
4747
def structured_output(
48-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
49-
) -> T:
48+
self, output_model: Type[T], prompt: Messages
49+
) -> Generator[dict[str, Union[T, Any]], None, None]:
5050
"""Get structured output from the model.
5151
5252
Args:
5353
output_model(Type[BaseModel]): The output model to use for the agent.
5454
prompt(Messages): The prompt messages to use for the agent.
55-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
5655
57-
Returns:
58-
The structured output as a serialized instance of the output model.
56+
Yields:
57+
Model events with the last being the structured output.
5958
6059
Raises:
6160
ValidationException: The response format from the model does not match the output_model

src/strands/types/models/openai.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import json
1212
import logging
1313
import mimetypes
14-
from typing import Any, Callable, Optional, Type, TypeVar, cast
14+
from typing import Any, Generator, Optional, Type, TypeVar, Union, cast
1515

1616
from pydantic import BaseModel
1717
from typing_extensions import override
@@ -295,13 +295,15 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
295295

296296
@override
297297
def structured_output(
298-
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
299-
) -> T:
298+
self, output_model: Type[T], prompt: Messages
299+
) -> Generator[dict[str, Union[T, Any]], None, None]:
300300
"""Get structured output from the model.
301301
302302
Args:
303303
output_model(Type[BaseModel]): The output model to use for the agent.
304304
prompt(Messages): The prompt to use for the agent.
305-
callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
305+
306+
Yields:
307+
Model events with the last being the structured output.
306308
"""
307-
return output_model()
309+
yield {"output": output_model()}

tests/strands/agent/test_agent.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -898,17 +898,15 @@ class User(BaseModel):
898898
def test_agent_method_structured_output(agent):
899899
# Mock the structured_output method on the model
900900
expected_user = User(name="Jane Doe", age=30, email="[email protected]")
901-
agent.model.structured_output = unittest.mock.Mock(return_value=expected_user)
901+
agent.model.structured_output = unittest.mock.Mock(return_value=[{"output": expected_user}])
902902

903903
prompt = "Jane Doe is 30 years old and her email is [email protected]"
904904

905905
result = agent.structured_output(User, prompt)
906906
assert result == expected_user
907907

908908
# Verify the model's structured_output was called with correct arguments
909-
agent.model.structured_output.assert_called_once_with(
910-
User, [{"role": "user", "content": [{"text": prompt}]}], agent.callback_handler
911-
)
909+
agent.model.structured_output.assert_called_once_with(User, [{"role": "user", "content": [{"text": prompt}]}])
912910

913911

914912
@pytest.mark.asyncio

0 commit comments

Comments (0)