Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions py/src/braintrust/wrappers/claude_agent_sdk/_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import dataclasses
import logging
import threading
Expand Down Expand Up @@ -235,6 +236,7 @@ async def wrapped_handler(args: Any) -> Any:
class ToolSpanTracker:
def __init__(self):
self._active_spans: dict[str, _ActiveToolSpan] = {}
self._pending_task_link_tool_use_ids: set[str] = set()

def start_tool_spans(self, message: Any, llm_span_export: str | None) -> None:
if llm_span_export is None or not hasattr(message, "content"):
Expand Down Expand Up @@ -278,6 +280,8 @@ def start_tool_spans(self, message: Any, llm_span_export: str | None) -> None:
display_name=parsed_tool_name.display_name,
input=getattr(block, "input", None),
)
if parsed_tool_name.display_name == "Agent":
self._pending_task_link_tool_use_ids.add(tool_use_id)

def finish_tool_spans(self, message: Any) -> None:
if not hasattr(message, "content"):
Expand All @@ -303,6 +307,16 @@ def cleanup(self, end_time: float | None = None, exclude_tool_use_ids: frozenset
def has_active_spans(self) -> bool:
return bool(self._active_spans)

@property
def pending_task_link_tool_use_ids(self) -> frozenset[str]:
return frozenset(self._pending_task_link_tool_use_ids)

def mark_task_started(self, tool_use_id: Any) -> None:
if tool_use_id is None:
return

self._pending_task_link_tool_use_ids.discard(str(tool_use_id))

def acquire_span_for_handler(self, tool_name: Any, args: Any) -> _ActiveToolSpan | None:
parsed_tool_name = _parse_tool_name(tool_name)
candidate_names = list(dict.fromkeys((parsed_tool_name.raw_name, parsed_tool_name.display_name, str(tool_name))))
Expand All @@ -323,6 +337,7 @@ def acquire_span_for_handler(self, tool_name: Any, args: Any) -> _ActiveToolSpan

def _end_tool_span(self, tool_use_id: str, tool_result_block: Any | None = None, end_time: float | None = None) -> None:
active_tool_span = self._active_spans.pop(tool_use_id, None)
self._pending_task_link_tool_use_ids.discard(tool_use_id)
if active_tool_span is None:
return

Expand Down Expand Up @@ -491,7 +506,9 @@ def process(self, message: Any) -> None:
self._active_task_order.append(task_id)
tool_use_id = getattr(message, "tool_use_id", None)
if tool_use_id is not None:
self._task_span_by_tool_use_id[str(tool_use_id)] = task_span
tool_use_id = str(tool_use_id)
self._task_span_by_tool_use_id[tool_use_id] = task_span
self._tool_tracker.mark_task_started(tool_use_id)
else:
update: dict[str, Any] = {}
metadata = self._metadata(message)
Expand Down Expand Up @@ -693,9 +710,12 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:

if message_type == MessageClassName.ASSISTANT:
if llm_tracker.current_span and tool_tracker.has_active_spans:
active_subagent_tool_use_ids = (
task_event_span_tracker.active_tool_use_ids | tool_tracker.pending_task_link_tool_use_ids
)
tool_tracker.cleanup(
end_time=llm_tracker.get_next_start_time(),
exclude_tool_use_ids=task_event_span_tracker.active_tool_use_ids,
exclude_tool_use_ids=active_subagent_tool_use_ids,
)
llm_parent_export = task_event_span_tracker.parent_export_for_message(
message,
Expand Down Expand Up @@ -744,8 +764,15 @@ async def receive_response(self) -> AsyncGenerator[Any, None]:
task_events.append(_serialize_system_message(message))

yield message
except Exception:
raise
except asyncio.CancelledError:
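# The CancelledError may come from the subprocess transport
# (e.g., anyio internal cleanup when subagents complete) rather
# than a genuine external cancellation. We suppress it here so
# the response stream ends cleanly.
# NOTE(review): the rationale is that a caller who genuinely cancelled
# the task still has a pending cancellation request that fires at its
# next await point. That is how anyio cancel scopes behave, but a plain
# asyncio Task.cancel() raises CancelledError only once, so swallowing
# it here may drop a real cancellation -- confirm which case applies.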
if final_results:
span.log(output=final_results[-1])
else:
if final_results:
span.log(output=final_results[-1])
Expand Down
Loading
Loading