Skip to content

Commit 351742d

Browse files
committed
ci: fix issues that surfaced in CI
1 parent 673a675 commit 351742d

File tree

7 files changed

+33
-27
lines changed

7 files changed

+33
-27
lines changed

examples/agent_patterns/human_in_the_loop.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import asyncio
1111
import json
1212

13-
from agents import Agent, Runner, RunState, function_tool
13+
from agents import Agent, Runner, RunState, ToolApprovalItem, function_tool
1414

1515

1616
@function_tool
@@ -101,17 +101,20 @@ async def main():
101101

102102
# Reading state from file (demonstrating deserialization)
103103
print("Loading state from result.json")
104-
with open("result.json", "r") as f:
104+
with open("result.json") as f:
105105
stored_state_json = json.load(f)
106106

107107
state = RunState.from_json(agent, stored_state_json)
108108

109109
# Process each interruption
110110
for interruption in result.interruptions:
111-
print(f"\nTool call details:")
111+
if not isinstance(interruption, ToolApprovalItem):
112+
continue
113+
114+
print("\nTool call details:")
112115
print(f" Agent: {interruption.agent.name}")
113-
print(f" Tool: {interruption.raw_item.name}") # type: ignore
114-
print(f" Arguments: {interruption.raw_item.arguments}") # type: ignore
116+
print(f" Tool: {interruption.raw_item.name}")
117+
print(f" Arguments: {interruption.raw_item.arguments}")
115118

116119
confirmed = await confirm("\nDo you approve this tool call?")
117120

examples/agent_patterns/human_in_the_loop_stream.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import asyncio
1212

13-
from agents import Agent, Runner, function_tool
13+
from agents import Agent, Runner, ToolApprovalItem, function_tool
1414

1515

1616
async def _needs_temperature_approval(_ctx, params, _call_id) -> bool:
@@ -89,10 +89,13 @@ async def main():
8989
state = result.to_state()
9090

9191
for interruption in result.interruptions:
92-
print(f"\nTool call details:")
92+
if not isinstance(interruption, ToolApprovalItem):
93+
continue
94+
95+
print("\nTool call details:")
9396
print(f" Agent: {interruption.agent.name}")
94-
print(f" Tool: {interruption.raw_item.name}") # type: ignore
95-
print(f" Arguments: {interruption.raw_item.arguments}") # type: ignore
97+
print(f" Tool: {interruption.raw_item.name}")
98+
print(f" Arguments: {interruption.raw_item.arguments}")
9699

97100
confirmed = await confirm("\nDo you approve this tool call?")
98101

src/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def enable_verbose_stdout_logging():
246246
"RunItem",
247247
"HandoffCallItem",
248248
"HandoffOutputItem",
249+
"ToolApprovalItem",
249250
"ToolCallItem",
250251
"ToolCallOutputItem",
251252
"ReasoningItem",
@@ -262,6 +263,7 @@ def enable_verbose_stdout_logging():
262263
"RunResult",
263264
"RunResultStreaming",
264265
"RunConfig",
266+
"RunState",
265267
"RawResponsesStreamEvent",
266268
"RunItemStreamEvent",
267269
"AgentUpdatedStreamEvent",

src/agents/run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import warnings
88
from dataclasses import dataclass, field
9-
from typing import Any, Callable, Generic, cast, get_args
9+
from typing import Any, Callable, Generic, Union, cast, get_args
1010

1111
from openai.types.responses import (
1212
ResponseCompletedEvent,
@@ -537,7 +537,7 @@ async def run(
537537
context = run_state._context.context
538538
else:
539539
# Keep original user input separate from session-prepared input
540-
raw_input = cast(str | list[TResponseInputItem], input)
540+
raw_input = cast(Union[str, list[TResponseInputItem]], input)
541541
original_user_input = raw_input
542542
prepared_input = await self._prepare_input_with_session(
543543
raw_input, session, run_config.session_input_callback
@@ -901,7 +901,7 @@ def run_streamed(
901901
# Use context wrapper from RunState
902902
context_wrapper = cast(RunContextWrapper[TContext], run_state._context)
903903
else:
904-
input_for_result = cast(str | list[TResponseInputItem], input)
904+
input_for_result = cast(Union[str, list[TResponseInputItem]], input)
905905
context_wrapper = RunContextWrapper(context=context) # type: ignore
906906

907907
streamed_result = RunResultStreaming(

src/agents/run_state.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,17 @@
88

99
from typing_extensions import TypeVar
1010

11+
from ._run_impl import NextStepInterruption
1112
from .exceptions import UserError
13+
from .items import ToolApprovalItem
1214
from .logger import logger
1315
from .run_context import RunContextWrapper
1416
from .usage import Usage
1517

1618
if TYPE_CHECKING:
17-
from ._run_impl import NextStepInterruption
1819
from .agent import Agent
1920
from .guardrail import InputGuardrailResult, OutputGuardrailResult
20-
from .items import ModelResponse, RunItem, ToolApprovalItem
21+
from .items import ModelResponse, RunItem
2122

2223
TContext = TypeVar("TContext", default=Any)
2324
TAgent = TypeVar("TAgent", bound="Agent[Any]", default="Agent[Any]")
@@ -105,8 +106,6 @@ def get_interruptions(self) -> list[RunItem]:
105106
Returns:
106107
List of tool approval items awaiting approval, or empty list if no interruptions.
107108
"""
108-
from ._run_impl import NextStepInterruption
109-
110109
if self._current_step is None or not isinstance(self._current_step, NextStepInterruption):
111110
return []
112111
return self._current_step.interruptions
@@ -235,8 +234,6 @@ def to_json(self) -> dict[str, Any]:
235234

236235
def _serialize_current_step(self) -> dict[str, Any] | None:
237236
"""Serialize the current step if it's an interruption."""
238-
from ._run_impl import NextStepInterruption
239-
240237
if self._current_step is None or not isinstance(self._current_step, NextStepInterruption):
241238
return None
242239

@@ -245,10 +242,15 @@ def _serialize_current_step(self) -> dict[str, Any] | None:
245242
"interruptions": [
246243
{
247244
"type": "tool_approval_item",
248-
"rawItem": item.raw_item.model_dump(exclude_unset=True),
245+
"rawItem": (
246+
item.raw_item.model_dump(exclude_unset=True)
247+
if hasattr(item.raw_item, "model_dump")
248+
else item.raw_item
249+
),
249250
"agent": {"name": item.agent.name},
250251
}
251252
for item in self._current_step.interruptions
253+
if isinstance(item, ToolApprovalItem)
252254
],
253255
}
254256

@@ -366,10 +368,7 @@ def from_string(initial_agent: Agent[Any], state_string: str) -> RunState[Any, A
366368
if current_step_data and current_step_data.get("type") == "next_step_interruption":
367369
from openai.types.responses import ResponseFunctionToolCall
368370

369-
from ._run_impl import NextStepInterruption
370-
from .items import ToolApprovalItem
371-
372-
interruptions = []
371+
interruptions: list[RunItem] = []
373372
for item_data in current_step_data.get("interruptions", []):
374373
agent_name = item_data["agent"]["name"]
375374
agent = agent_map.get(agent_name)
@@ -458,10 +457,7 @@ def from_json(
458457
if current_step_data and current_step_data.get("type") == "next_step_interruption":
459458
from openai.types.responses import ResponseFunctionToolCall
460459

461-
from ._run_impl import NextStepInterruption
462-
from .items import ToolApprovalItem
463-
464-
interruptions = []
460+
interruptions: list[RunItem] = []
465461
for item_data in current_step_data.get("interruptions", []):
466462
agent_name = item_data["agent"]["name"]
467463
agent = agent_map.get(agent_name)

tests/extensions/memory/test_advanced_sqlite_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def create_mock_run_result(
7474
tool_output_guardrail_results=[],
7575
context_wrapper=context_wrapper,
7676
_last_agent=agent,
77+
interruptions=[],
7778
)
7879

7980

tests/test_result_cast.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def create_run_result(final_output: Any) -> RunResult:
1818
tool_output_guardrail_results=[],
1919
_last_agent=Agent(name="test"),
2020
context_wrapper=RunContextWrapper(context=None),
21+
interruptions=[],
2122
)
2223

2324

0 commit comments

Comments
 (0)