33 changes: 22 additions & 11 deletions py/src/braintrust/oai.py
@@ -350,7 +350,12 @@ def _postprocess_streaming_results(cls, all_results: list[dict[str, Any]]) -> dict[str, Any]:


class ResponseWrapper:
def __init__(self, create_fn: Callable[..., Any] | None, acreate_fn: Callable[..., Any] | None, name: str = "openai.responses.create"):
def __init__(
self,
create_fn: Callable[..., Any] | None,
acreate_fn: Callable[..., Any] | None,
name: str = "openai.responses.create",
):
self.create_fn = create_fn
self.acreate_fn = acreate_fn
self.name = name
@@ -359,9 +364,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any:
params = self._parse_params(kwargs)
stream = kwargs.get("stream", False)

span = start_span(
**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)
)
span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params))
should_end = True

try:
@@ -373,6 +376,7 @@ def create(self, *args: Any, **kwargs: Any) -> Any:
else:
raw_response = create_response
if stream:

def gen():
try:
first = True
@@ -410,9 +414,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any:
params = self._parse_params(kwargs)
stream = kwargs.get("stream", False)

span = start_span(
**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params)
)
span = start_span(**merge_dicts(dict(name=self.name, span_attributes={"type": SpanTypeAttribute.LLM}), params))
should_end = True

try:
@@ -424,6 +426,7 @@ async def acreate(self, *args: Any, **kwargs: Any) -> Any:
else:
raw_response = create_response
if stream:

async def gen():
try:
first = True
@@ -506,7 +509,12 @@ def _postprocess_streaming_results(cls, all_results: list[Any]) -> dict[str, Any]:

for result in all_results:
usage = getattr(result, "usage", None)
if not usage and hasattr(result, "type") and result.type == "response.completed" and hasattr(result, "response"):
if (
not usage
and hasattr(result, "type")
and result.type == "response.completed"
and hasattr(result, "response")
):
# Handle summaries from completed response if present
if hasattr(result.response, "output") and result.response.output:
for output_item in result.response.output:
@@ -795,7 +803,9 @@ def create(self, *args: Any, **kwargs: Any) -> Any:
return ResponseWrapper(self.__responses.with_raw_response.create, None).create(*args, **kwargs)

def parse(self, *args: Any, **kwargs: Any) -> Any:
return ResponseWrapper(self.__responses.with_raw_response.parse, None, "openai.responses.parse").create(*args, **kwargs)
return ResponseWrapper(self.__responses.with_raw_response.parse, None, "openai.responses.parse").create(
*args, **kwargs
)


class AsyncResponsesV1Wrapper(NamedWrapper):
@@ -808,7 +818,9 @@ async def create(self, *args: Any, **kwargs: Any) -> Any:
return AsyncResponseWrapper(response)

async def parse(self, *args: Any, **kwargs: Any) -> Any:
response = await ResponseWrapper(None, self.__responses.with_raw_response.parse, "openai.responses.parse").acreate(*args, **kwargs)
response = await ResponseWrapper(
None, self.__responses.with_raw_response.parse, "openai.responses.parse"
).acreate(*args, **kwargs)
return AsyncResponseWrapper(response)


@@ -938,7 +950,6 @@ def _parse_metrics_from_usage(usage: Any) -> dict[str, Any]:
return metrics



def prettify_params(params: dict[str, Any]) -> dict[str, Any]:
# Filter out NOT_GIVEN parameters
# https://linear.app/braintrustdata/issue/BRA-2467
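Beyond the reformatting, the substantive change in oai.py is the new `name` argument on `ResponseWrapper`: `parse` calls are now traced under "openai.responses.parse" instead of the generic create name. A minimal standalone sketch of that plumbing, with `start_span` stubbed out (the real tracer lives in braintrust internals) and a hypothetical `ResponseWrapperSketch` in place of the full wrapper:

from typing import Any, Callable


def start_span(name: str, **kwargs: Any) -> str:
    # Stand-in for braintrust.start_span; the real call also merges span
    # attributes and parsed request params via merge_dicts.
    print(f"span opened: {name}")
    return name


class ResponseWrapperSketch:
    # Simplified stand-in: only the span-name plumbing, no streaming or error handling.
    def __init__(self, create_fn: Callable[..., Any], name: str = "openai.responses.create"):
        self.create_fn = create_fn
        self.name = name

    def create(self, *args: Any, **kwargs: Any) -> Any:
        start_span(name=self.name)  # parse() wrappers pass "openai.responses.parse"
        return self.create_fn(*args, **kwargs)


ResponseWrapperSketch(lambda: None).create()  # span opened: openai.responses.create
ResponseWrapperSketch(lambda: None, name="openai.responses.parse").create()  # span opened: openai.responses.parse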
24 changes: 22 additions & 2 deletions py/src/braintrust/parameters.py
@@ -34,6 +34,23 @@ def _pydantic_to_json_schema(model: Any) -> dict[str, Any]:
raise ValueError(f"Cannot convert {model} to JSON schema - not a pydantic model")


def _extract_pydantic_fields(
schema: dict[str, dict[str, Any]],
) -> tuple[Any, Any] | tuple[dict[str, Any], dict[str, Any]]:
"""Extract pydantic fields default and description metadata"""
    if len(schema) == 1 and "value" in schema:
        # A single "value" field collapses to scalar default/description values.
        field_metadata = schema["value"]
        return (field_metadata.get("default"), field_metadata.get("description"))

    flatten_defaults = {}
    flatten_descriptions = {}
    for field_name, field_metadata in schema.items():
        flatten_defaults[field_name] = field_metadata.get("default")
        flatten_descriptions[field_name] = field_metadata.get("description")
    return (flatten_defaults, flatten_descriptions)


def validate_parameters(
parameters: dict[str, Any],
parameter_schema: EvalParameters,
@@ -143,10 +160,13 @@ def parameters_to_json_schema(parameters: EvalParameters) -> dict[str, Any]:
else:
# Pydantic model
try:
pydantic_schema = _pydantic_to_json_schema(schema)
model_defaults, model_descriptions = _extract_pydantic_fields(pydantic_schema.get("properties", {}))
result[name] = {
"type": "data",
"schema": _pydantic_to_json_schema(schema),
# TODO: Extract default and description from pydantic model
"schema": pydantic_schema,
"default": model_defaults,
"description": model_descriptions,
}
except ValueError:
# Not a pydantic model, skip
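Net effect of the parameters.py change: for a real pydantic model, `parameters_to_json_schema` now fills in `default` and `description` (previously left as a TODO) alongside the emitted JSON schema. An end-to-end sketch, assuming pydantic v2 and this branch installed; `SystemPrompt` here is a real model standing in for the MagicMock fixtures used by the tests below:

from pydantic import BaseModel, Field

from braintrust.parameters import parameters_to_json_schema


class SystemPrompt(BaseModel):
    value: str = Field(
        default="You are a helpful assistant.",
        description="System prompt for the model",
    )


schema = parameters_to_json_schema({"system_prompt": SystemPrompt})
# Single-field models collapse to scalar metadata via _extract_pydantic_fields:
assert schema["system_prompt"]["default"] == "You are a helpful assistant."
assert schema["system_prompt"]["description"] == "System prompt for the model"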
158 changes: 158 additions & 0 deletions py/src/braintrust/test_pydantic_parameters.py
@@ -0,0 +1,158 @@
from unittest.mock import MagicMock

import pytest
from braintrust.parameters import _extract_pydantic_fields, parameters_to_json_schema

DEFAULT_VALUE = "You are a helpful assistant."
DEFAULT_DESCRIPTION = "System prompt for the model"
SCHEMA_TITLE = "SystemPrompt"
PARAM_NAME = "system_prompt"


def test_extract_single_field_with_default_and_description():
schema = {
"value": {"type": "string", "default": DEFAULT_VALUE, "description": DEFAULT_DESCRIPTION},
}
defaults, descriptions = _extract_pydantic_fields(schema)
assert defaults == DEFAULT_VALUE
assert descriptions == DEFAULT_DESCRIPTION


def test_extract_single_field_missing_default_and_description():
schema = {
"value": {"type": "string"},
}
defaults, descriptions = _extract_pydantic_fields(schema)
assert defaults is None
assert descriptions is None


def test_extract_multi_field():
schema = {
"temperature": {"type": "number", "default": 0.7, "description": "Sampling temperature"},
"max_tokens": {"type": "integer", "default": 1024, "description": "Maximum tokens to generate"},
}
defaults, descriptions = _extract_pydantic_fields(schema)
assert defaults == {"temperature": 0.7, "max_tokens": 1024}
assert descriptions == {"temperature": "Sampling temperature", "max_tokens": "Maximum tokens to generate"}


def test_extract_multi_field_partial_metadata():
schema = {
"temperature": {"type": "number", "default": 0.7},
"max_tokens": {"type": "integer", "description": "Maximum tokens to generate"},
}
defaults, descriptions = _extract_pydantic_fields(schema)
assert defaults == {"temperature": 0.7, "max_tokens": None}
assert descriptions == {"temperature": None, "max_tokens": "Maximum tokens to generate"}


def test_extract_empty_schema():
defaults, descriptions = _extract_pydantic_fields({})
assert defaults == {}
assert descriptions == {}


@pytest.fixture
def v2_model():
def _make(default=DEFAULT_VALUE, description=DEFAULT_DESCRIPTION):
model = MagicMock()
model.model_json_schema.return_value = {
"title": SCHEMA_TITLE,
"type": "object",
"properties": {"value": {"type": "string", "default": default, "description": description}},
}
del model.get
return model

return _make


@pytest.fixture
def v1_model():
def _make(default=DEFAULT_VALUE):
model = MagicMock()
del model.model_json_schema
model.schema.return_value = {
"title": SCHEMA_TITLE,
"type": "object",
"properties": {"value": {"type": "string", "default": default}},
}
del model.get
return model

return _make


@pytest.fixture
def v2_multi_field_model():
def _make():
model = MagicMock()
model.model_json_schema.return_value = {
"title": "ModelConfig",
"type": "object",
"properties": {
"temperature": {"type": "number", "default": 0.7, "description": "Sampling temperature"},
"max_tokens": {"type": "integer", "default": 1024, "description": "Maximum tokens to generate"},
},
}
del model.get
return model

return _make


@pytest.fixture
def v1_multi_field_model():
def _make():
model = MagicMock()
del model.model_json_schema
model.schema.return_value = {
"title": "ModelConfig",
"type": "object",
"properties": {
"temperature": {"type": "number", "default": 0.7},
"max_tokens": {"type": "integer", "default": 1024},
},
}
del model.get
return model

return _make


def test_pydantic_v2_model(v2_model):
schema = parameters_to_json_schema({PARAM_NAME: v2_model()})

assert schema[PARAM_NAME]["type"] == "data"
assert schema[PARAM_NAME]["default"] == DEFAULT_VALUE
assert schema[PARAM_NAME]["description"] == DEFAULT_DESCRIPTION


def test_pydantic_v1_model(v1_model):
schema = parameters_to_json_schema({PARAM_NAME: v1_model()})

assert schema[PARAM_NAME]["type"] == "data"
assert schema[PARAM_NAME]["default"] == DEFAULT_VALUE
assert schema[PARAM_NAME]["description"] is None


def test_pydantic_v2_multi_field_model(v2_multi_field_model):
schema = parameters_to_json_schema({PARAM_NAME: v2_multi_field_model()})

assert schema[PARAM_NAME]["type"] == "data"
assert schema[PARAM_NAME]["schema"]["title"] == "ModelConfig"
assert schema[PARAM_NAME]["default"] == {"temperature": 0.7, "max_tokens": 1024}
assert schema[PARAM_NAME]["description"] == {
"temperature": "Sampling temperature",
"max_tokens": "Maximum tokens to generate",
}


def test_pydantic_v1_multi_field_model(v1_multi_field_model):
schema = parameters_to_json_schema({PARAM_NAME: v1_multi_field_model()})

assert schema[PARAM_NAME]["type"] == "data"
assert schema[PARAM_NAME]["schema"]["title"] == "ModelConfig"
assert schema[PARAM_NAME]["default"] == {"temperature": 0.7, "max_tokens": 1024}
assert schema[PARAM_NAME]["description"] == {"temperature": None, "max_tokens": None}
2 changes: 2 additions & 0 deletions py/src/braintrust/wrappers/test_openai.py
@@ -378,6 +378,7 @@ def __init__(self, id="test_id", type="message"):
# No spans should be generated from this unit test
assert not memory_logger.pop()


@pytest.mark.vcr
def test_openai_embeddings(memory_logger):
assert not memory_logger.pop()
@@ -1935,6 +1936,7 @@ def test_auto_instrument_openai(self):
"""Test auto_instrument patches OpenAI, creates spans, and uninstrument works."""
verify_autoinstrument_script("test_auto_openai.py")


class TestZAICompatibleOpenAI:
"""Tests for validating some ZAI compatibility with OpenAI wrapper."""
