Skip to content

Commit dfe5e36

Browse files
committed
feat: Support module level query func
1 parent df192c5 commit dfe5e36

File tree

4 files changed

+1155
-132
lines changed

4 files changed

+1155
-132
lines changed

py/src/braintrust/wrappers/claude_agent_sdk/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919

2020
from braintrust.logger import NOOP_SPAN, current_span, init_logger
2121

22-
from ._wrapper import _create_client_wrapper_class, _create_tool_wrapper_class, _wrap_tool_factory
22+
from ._wrapper import (
23+
_create_client_wrapper_class,
24+
_create_tool_wrapper_class,
25+
_wrap_query_function,
26+
_wrap_tool_factory,
27+
)
2328

2429
logger = logging.getLogger(__name__)
2530

@@ -67,9 +72,11 @@ def setup_claude_agent_sdk(
6772
import claude_agent_sdk
6873

6974
original_client = claude_agent_sdk.ClaudeSDKClient if hasattr(claude_agent_sdk, "ClaudeSDKClient") else None
75+
original_query_fn = claude_agent_sdk.query if hasattr(claude_agent_sdk, "query") else None
7076
original_tool_class = claude_agent_sdk.SdkMcpTool if hasattr(claude_agent_sdk, "SdkMcpTool") else None
7177
original_tool_fn = claude_agent_sdk.tool if hasattr(claude_agent_sdk, "tool") else None
7278

79+
wrapped_client = None
7380
if original_client:
7481
wrapped_client = _create_client_wrapper_class(original_client)
7582
claude_agent_sdk.ClaudeSDKClient = wrapped_client
@@ -79,6 +86,15 @@ def setup_claude_agent_sdk(
7986
if getattr(module, "ClaudeSDKClient", None) is original_client:
8087
setattr(module, "ClaudeSDKClient", wrapped_client)
8188

89+
if original_query_fn and wrapped_client:
90+
wrapped_query_fn = _wrap_query_function(original_query_fn, wrapped_client)
91+
claude_agent_sdk.query = wrapped_query_fn
92+
93+
for module in list(sys.modules.values()):
94+
if module and hasattr(module, "query"):
95+
if getattr(module, "query", None) is original_query_fn:
96+
setattr(module, "query", wrapped_query_fn)
97+
8298
if original_tool_class:
8399
wrapped_tool_class = _create_tool_wrapper_class(original_tool_class)
84100
claude_agent_sdk.SdkMcpTool = wrapped_tool_class

0 commit comments

Comments
 (0)