-
Notifications
You must be signed in to change notification settings - Fork 53
Python trace scoring candidate #1278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9d8f149
95c21c8
4036939
d02704a
942b49c
64e1070
e734a6e
355cf55
fda9d29
fcd8f91
f0f93af
b13ee06
c0035f6
77b26a9
59ee9b8
4c1bc79
2eeeb46
30e5217
f67f9e8
7807091
bc236ef
9e864e2
efea2cb
b5d117b
a34741d
59f6c7b
b1dc350
f67aba5
dcc7dd7
7cf3a5a
0300f07
469fa20
5295cd4
b5b3511
c314138
8eb6d49
3ff1c5d
b252963
cd7d04e
3b4653a
56c5bf9
38dcd00
087e2a2
ddccd08
5ca164f
4c789a0
ca8b861
0fc8400
c5e43fc
863b5e6
927df14
7992b74
e721910
98ae957
db9ad2a
763dab6
3fca8aa
6324d80
53e6380
ebf9524
91dcaa1
bc1f473
f7a8beb
9cec760
d2226f7
d00c30c
a53b21a
d0c375c
9eb327e
c27c538
25322d2
62e8955
c53ff50
e83fcb1
e43c61e
f2ae38f
c505f2a
26e0658
61ba617
28cfbe8
1314573
67abadc
a6d1dd4
499e6fc
c630564
d74224f
5c18fe4
e71bf48
b99f266
4245dae
c62769e
4a935ca
5d5ab97
a7b13f8
8bb623d
7d3cad9
c97bf75
a815eb5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1280,6 +1280,29 @@ async def _run_evaluator_internal( | |
| filters: list[Filter], | ||
| stream: Callable[[SSEProgressEvent], None] | None = None, | ||
| state: BraintrustState | None = None, | ||
| ): | ||
| # Start span cache for this eval (it's disabled by default to avoid temp files outside of evals) | ||
| if state is None: | ||
| from braintrust.logger import _internal_get_global_state | ||
|
|
||
| state = _internal_get_global_state() | ||
|
|
||
| state.span_cache.start() | ||
| try: | ||
| return await _run_evaluator_internal_impl(experiment, evaluator, position, filters, stream, state) | ||
| finally: | ||
| # Clean up disk-based span cache after eval completes and stop caching | ||
| state.span_cache.dispose() | ||
| state.span_cache.stop() | ||
|
|
||
|
|
||
| async def _run_evaluator_internal_impl( | ||
| experiment, | ||
| evaluator: Evaluator, | ||
| position: int | None, | ||
| filters: list[Filter], | ||
| stream: Callable[[SSEProgressEvent], None] | None = None, | ||
| state: BraintrustState | None = None, | ||
| ): | ||
| event_loop = asyncio.get_event_loop() | ||
|
|
||
|
|
@@ -1290,11 +1313,13 @@ async def await_or_run_scorer(root_span, scorer, name, **kwargs): | |
| {**parent_propagated}, | ||
| {"span_attributes": {"purpose": "scorer"}}, | ||
| ) | ||
| # Strip trace from logged input - it's internal plumbing that shouldn't appear in spans | ||
| logged_input = {k: v for k, v in kwargs.items() if k != "trace"} | ||
| with root_span.start_span( | ||
| name=name, | ||
| span_attributes={"type": SpanTypeAttribute.SCORE, "purpose": "scorer"}, | ||
| propagated_event=merged_propagated, | ||
| input=dict(**kwargs), | ||
| input=logged_input, | ||
| ) as span: | ||
| score = scorer | ||
| if hasattr(scorer, "eval_async"): | ||
|
|
@@ -1415,6 +1440,77 @@ def report_progress(event: TaskProgressEvent): | |
| tags = hooks.tags if hooks.tags else None | ||
| root_span.log(output=output, metadata=metadata, tags=tags) | ||
|
|
||
| # Create trace object for scorers | ||
| from braintrust.trace import LocalTrace | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i don't think we have a circ. dep so i'd move this to the top of the file |
||
|
|
||
| async def ensure_spans_flushed(): | ||
| # Flush native Braintrust spans | ||
| if experiment: | ||
| await asyncio.get_event_loop().run_in_executor( | ||
| None, lambda: experiment.state.flush() | ||
| ) | ||
| elif state: | ||
| await asyncio.get_event_loop().run_in_executor(None, lambda: state.flush()) | ||
| else: | ||
| from braintrust.logger import flush as flush_logger | ||
|
|
||
| await asyncio.get_event_loop().run_in_executor(None, flush_logger) | ||
|
|
||
| # Also flush OTEL spans if registered | ||
| if state: | ||
| await state.flush_otel() | ||
|
|
||
| experiment_id = None | ||
| if experiment: | ||
| try: | ||
| experiment_id = experiment.id | ||
| except: | ||
| experiment_id = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #nit |
||
|
|
||
| trace = None | ||
| if state or experiment: | ||
| # Get the state to use | ||
| trace_state = state | ||
| if not trace_state and experiment: | ||
| trace_state = experiment.state | ||
| if not trace_state: | ||
| # Fall back to global state | ||
| from braintrust.logger import _internal_get_global_state | ||
|
|
||
| trace_state = _internal_get_global_state() | ||
|
|
||
| # Access root_span_id from the concrete SpanImpl instance | ||
| # The Span interface doesn't expose this but SpanImpl has it | ||
| root_span_id_value = getattr(root_span, "root_span_id", root_span.id) | ||
|
|
||
| # Check if there's a parent in the context to determine object_type and object_id | ||
| from braintrust.span_identifier_v3 import SpanComponentsV3, span_object_type_v3_to_typed_string | ||
|
|
||
| parent_str = trace_state.current_parent.get() | ||
| parent_components = None | ||
| if parent_str: | ||
| try: | ||
| parent_components = SpanComponentsV3.from_str(parent_str) | ||
| except Exception: | ||
| # If parsing fails, parent_components stays None | ||
| pass | ||
|
|
||
| # Determine object_type and object_id based on parent or experiment | ||
| if parent_components: | ||
| trace_object_type = span_object_type_v3_to_typed_string(parent_components.object_type) | ||
| trace_object_id = parent_components.object_id or "" | ||
| else: | ||
| trace_object_type = "experiment" | ||
| trace_object_id = experiment_id or "" | ||
|
|
||
| trace = LocalTrace( | ||
| object_type=trace_object_type, | ||
| object_id=trace_object_id, | ||
| root_span_id=root_span_id_value, | ||
| ensure_spans_flushed=ensure_spans_flushed, | ||
| state=trace_state, | ||
| ) | ||
|
|
||
| score_promises = [ | ||
| asyncio.create_task( | ||
| await_or_run_scorer( | ||
|
|
@@ -1426,6 +1522,7 @@ def report_progress(event: TaskProgressEvent): | |
| "expected": datum.expected, | ||
| "metadata": metadata, | ||
| "output": output, | ||
| "trace": trace, | ||
| }, | ||
| ) | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |
| from sseclient import SSEClient | ||
|
|
||
| from .._generated_types import FunctionTypeEnum | ||
| from ..logger import Exportable, get_span_parent_object, login, proxy_conn | ||
| from ..logger import Exportable, _internal_get_global_state, get_span_parent_object, login, proxy_conn | ||
| from ..util import response_raise_for_status | ||
| from .constants import INVOKE_API_VERSION | ||
| from .stream import BraintrustInvokeError, BraintrustStream | ||
|
|
@@ -243,6 +243,8 @@ def init_function(project_name: str, slug: str, version: str | None = None): | |
| :param version: Optional version of the function to use. Defaults to latest. | ||
| :return: A function that can be used as a task or scorer. | ||
| """ | ||
| # Disable span cache since remote function spans won't be in the local cache | ||
| _internal_get_global_state().span_cache.disable() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. may be too indirect. if the remote function spans are not in the cache what's the big deal in checking the cache? |
||
|
|
||
| def f(*args: Any, **kwargs: Any) -> Any: | ||
| if len(args) > 0: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| """Tests for the invoke module, particularly init_function.""" | ||
|
|
||
|
|
||
| from braintrust.functions.invoke import init_function | ||
| from braintrust.logger import _internal_get_global_state, _internal_reset_global_state | ||
|
|
||
|
|
||
| class TestInitFunction: | ||
| """Tests for init_function.""" | ||
|
|
||
| def setup_method(self): | ||
| """Reset state before each test.""" | ||
| _internal_reset_global_state() | ||
|
|
||
| def teardown_method(self): | ||
| """Clean up after each test.""" | ||
| _internal_reset_global_state() | ||
|
|
||
| def test_init_function_disables_span_cache(self): | ||
| """Test that init_function disables the span cache.""" | ||
| state = _internal_get_global_state() | ||
|
|
||
| # Cache should be disabled by default (it's only enabled during evals) | ||
| assert state.span_cache.disabled is True | ||
|
|
||
| # Enable the cache (simulating what happens during eval) | ||
| state.span_cache.start() | ||
| assert state.span_cache.disabled is False | ||
|
|
||
| # Call init_function | ||
| f = init_function("test-project", "test-function") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it is a bit odd that loading a function will cause a cache to be disabled. |
||
|
|
||
| # Cache should now be disabled (init_function explicitly disables it) | ||
| assert state.span_cache.disabled is True | ||
| assert f.__name__ == "init_function-test-project-test-function-latest" | ||
|
|
||
| def test_init_function_with_version(self): | ||
| """Test that init_function creates a function with the correct name including version.""" | ||
| f = init_function("my-project", "my-scorer", version="v1") | ||
| assert f.__name__ == "init_function-my-project-my-scorer-v1" | ||
|
|
||
| def test_init_function_without_version_uses_latest(self): | ||
| """Test that init_function uses 'latest' in name when version not specified.""" | ||
| f = init_function("my-project", "my-scorer") | ||
| assert f.__name__ == "init_function-my-project-my-scorer-latest" | ||
|
|
||
| def test_init_function_permanently_disables_cache(self): | ||
| """Test that init_function permanently disables the cache (can't be re-enabled).""" | ||
| state = _internal_get_global_state() | ||
|
|
||
| # Enable the cache | ||
| state.span_cache.start() | ||
| assert state.span_cache.disabled is False | ||
|
|
||
| # Call init_function | ||
| init_function("test-project", "test-function") | ||
| assert state.span_cache.disabled is True | ||
|
|
||
| # Try to start again - should still be disabled because of explicit disable | ||
| state.span_cache.start() | ||
| assert state.span_cache.disabled is True | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as a user why/when would trace ever be none?