diff --git a/integrations/openai-agents-js/src/index.ts b/integrations/openai-agents-js/src/index.ts index 95a18edd5..4a6a565d3 100644 --- a/integrations/openai-agents-js/src/index.ts +++ b/integrations/openai-agents-js/src/index.ts @@ -382,7 +382,8 @@ export class OpenAIAgentsTraceProcessor { if (!data.metrics.completion_tokens && usage.completionTokens) data.metrics.completion_tokens = usage.completionTokens; if (usage.input_tokens_details?.cached_tokens != null) - data.metrics.prompt_cached_tokens = usage.input_tokens_details.cached_tokens; + data.metrics.prompt_cached_tokens = + usage.input_tokens_details.cached_tokens; } return data; diff --git a/integrations/openai-agents-js/src/openai-agents-integration.test.ts b/integrations/openai-agents-js/src/openai-agents-integration.test.ts index 618da167b..75c954942 100644 --- a/integrations/openai-agents-js/src/openai-agents-integration.test.ts +++ b/integrations/openai-agents-js/src/openai-agents-integration.test.ts @@ -908,7 +908,7 @@ describe( output_tokens: 50, total_tokens: 150, input_tokens_details: { - cached_tokens: 80, // check for this later + cached_tokens: 80, // check for this later }, }, }, @@ -934,9 +934,17 @@ describe( const metrics = (responseSpanLog as any).metrics; assert.ok(metrics, "Response span should have metrics"); assert.equal(metrics.prompt_tokens, 100, "Should have prompt_tokens"); - assert.equal(metrics.completion_tokens, 50, "Should have completion_tokens"); + assert.equal( + metrics.completion_tokens, + 50, + "Should have completion_tokens", + ); assert.equal(metrics.tokens, 150, "Should have total tokens"); - assert.equal(metrics.prompt_cached_tokens, 80, "Should extract cached_tokens to prompt_cached_tokens"); + assert.equal( + metrics.prompt_cached_tokens, + 80, + "Should extract cached_tokens to prompt_cached_tokens", + ); }); test("Response span handles zero cached tokens correctly", async () => { @@ -965,7 +973,7 @@ describe( input_tokens: 100, output_tokens: 50, 
input_tokens_details: { - cached_tokens: 0, // Zero is a valid value + cached_tokens: 0, // Zero is a valid value }, }, }, @@ -977,7 +985,9 @@ describe( await processor.onSpanEnd(responseSpan); const spans = await backgroundLogger.drain(); - const responseSpanLog = spans.find((s: any) => s.span_attributes?.type === "llm"); + const responseSpanLog = spans.find( + (s: any) => s.span_attributes?.type === "llm", + ); const metrics = (responseSpanLog as any).metrics; // Zero should be logged, not skipped @@ -1024,11 +1034,16 @@ describe( await processor.onSpanEnd(responseSpan); const spans = await backgroundLogger.drain(); - const responseSpanLog = spans.find((s: any) => s.span_attributes?.type === "llm"); + const responseSpanLog = spans.find( + (s: any) => s.span_attributes?.type === "llm", + ); const metrics = (responseSpanLog as any).metrics; // Should not have prompt_cached_tokens if not present in usage - assert.isUndefined(metrics.prompt_cached_tokens, "Should not add prompt_cached_tokens if not in usage"); + assert.isUndefined( + metrics.prompt_cached_tokens, + "Should not add prompt_cached_tokens if not in usage", + ); }); test("Generation span extracts cached tokens from usage", async () => { @@ -1060,7 +1075,7 @@ describe( output_tokens: 75, total_tokens: 275, input_tokens_details: { - cached_tokens: 150, // Test Generation span extraction + cached_tokens: 150, // Test Generation span extraction }, }, }, @@ -1080,8 +1095,16 @@ describe( const metrics = (generationSpanLog as any).metrics; assert.ok(metrics, "Generation span should have metrics"); assert.equal(metrics.prompt_tokens, 200, "Should have prompt_tokens"); - assert.equal(metrics.completion_tokens, 75, "Should have completion_tokens"); - assert.equal(metrics.prompt_cached_tokens, 150, "Should extract cached_tokens from Generation span"); + assert.equal( + metrics.completion_tokens, + 75, + "Should have completion_tokens", + ); + assert.equal( + metrics.prompt_cached_tokens, + 150, + "Should extract 
cached_tokens from Generation span", + ); }); }, ); diff --git a/py/examples/evals/eval_example.py b/py/examples/evals/eval_example.py index 9e7651747..1d605a080 100644 --- a/py/examples/evals/eval_example.py +++ b/py/examples/evals/eval_example.py @@ -1,12 +1,64 @@ +import json + from braintrust import Eval NUM_EXAMPLES = 10 -def exact_match_scorer(input, output, expected): - if expected is None: - return 0.0 - return 1.0 if output == expected else 0.0 +async def exact_match_scorer(input, output, expected, trace=None): + """Async scorer that prints trace spans.""" + score = 0.0 + if expected is not None: + score = 1.0 if output == expected else 0.0 + + if trace: + print("\n" + "="*80) + print(f"🔍 TRACE INFO for input: {input}") + print("="*80) + + # Print trace configuration + config = trace.get_configuration() + print(f"\n📋 Configuration:") + print(f" Object Type: {config.get('objectType')}") + print(f" Object ID: {config.get('objectId')}") + print(f" Root Span: {config.get('rootSpanId')}") + + # Fetch and print spans + try: + spans = await trace.get_spans() + print(f"\n✨ Found {len(spans)} spans:") + print("-"*80) + + for i, span in enumerate(spans, 1): + print(f"\n Span {i}:") + print(f" ID: {span.span_id}") + span_type = span.span_attributes.get('type', 'N/A') if span.span_attributes else 'N/A' + span_name = span.span_attributes.get('name', 'N/A') if span.span_attributes else 'N/A' + print(f" Type: {span_type}") + print(f" Name: {span_name}") + + if span.input: + input_str = json.dumps(span.input) + if len(input_str) > 100: + input_str = input_str[:100] + "..." + print(f" Input: {input_str}") + if span.output: + output_str = json.dumps(span.output) + if len(output_str) > 100: + output_str = output_str[:100] + "..." 
+ print(f" Output: {output_str}") + if span.metadata: + print(f" Metadata: {list(span.metadata.keys())}") + + print("\n" + "="*80 + "\n") + except Exception as e: + print(f"\n⚠️ Error fetching spans: {e}") + import traceback + traceback.print_exc() + else: + print(f"⚠️ No trace available for input: {input}") + + return score def data_fn(): diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index a53ec62e7..43ba78d23 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -1280,6 +1280,29 @@ async def _run_evaluator_internal( filters: list[Filter], stream: Callable[[SSEProgressEvent], None] | None = None, state: BraintrustState | None = None, +): + # Start span cache for this eval (it's disabled by default to avoid temp files outside of evals) + if state is None: + from braintrust.logger import _internal_get_global_state + + state = _internal_get_global_state() + + state.span_cache.start() + try: + return await _run_evaluator_internal_impl(experiment, evaluator, position, filters, stream, state) + finally: + # Clean up disk-based span cache after eval completes and stop caching + state.span_cache.dispose() + state.span_cache.stop() + + +async def _run_evaluator_internal_impl( + experiment, + evaluator: Evaluator, + position: int | None, + filters: list[Filter], + stream: Callable[[SSEProgressEvent], None] | None = None, + state: BraintrustState | None = None, ): event_loop = asyncio.get_event_loop() @@ -1290,11 +1313,13 @@ async def await_or_run_scorer(root_span, scorer, name, **kwargs): {**parent_propagated}, {"span_attributes": {"purpose": "scorer"}}, ) + # Strip trace from logged input - it's internal plumbing that shouldn't appear in spans + logged_input = {k: v for k, v in kwargs.items() if k != "trace"} with root_span.start_span( name=name, span_attributes={"type": SpanTypeAttribute.SCORE, "purpose": "scorer"}, propagated_event=merged_propagated, - input=dict(**kwargs), + input=logged_input, ) as span: 
score = scorer if hasattr(scorer, "eval_async"): @@ -1415,6 +1440,77 @@ def report_progress(event: TaskProgressEvent): tags = hooks.tags if hooks.tags else None root_span.log(output=output, metadata=metadata, tags=tags) + # Create trace object for scorers + from braintrust.trace import LocalTrace + + async def ensure_spans_flushed(): + # Flush native Braintrust spans + if experiment: + await asyncio.get_event_loop().run_in_executor( + None, lambda: experiment.state.flush() + ) + elif state: + await asyncio.get_event_loop().run_in_executor(None, lambda: state.flush()) + else: + from braintrust.logger import flush as flush_logger + + await asyncio.get_event_loop().run_in_executor(None, flush_logger) + + # Also flush OTEL spans if registered + if state: + await state.flush_otel() + + experiment_id = None + if experiment: + try: + experiment_id = experiment.id + except: + experiment_id = None + + trace = None + if state or experiment: + # Get the state to use + trace_state = state + if not trace_state and experiment: + trace_state = experiment.state + if not trace_state: + # Fall back to global state + from braintrust.logger import _internal_get_global_state + + trace_state = _internal_get_global_state() + + # Access root_span_id from the concrete SpanImpl instance + # The Span interface doesn't expose this but SpanImpl has it + root_span_id_value = getattr(root_span, "root_span_id", root_span.id) + + # Check if there's a parent in the context to determine object_type and object_id + from braintrust.span_identifier_v3 import SpanComponentsV3, span_object_type_v3_to_typed_string + + parent_str = trace_state.current_parent.get() + parent_components = None + if parent_str: + try: + parent_components = SpanComponentsV3.from_str(parent_str) + except Exception: + # If parsing fails, parent_components stays None + pass + + # Determine object_type and object_id based on parent or experiment + if parent_components: + trace_object_type = 
span_object_type_v3_to_typed_string(parent_components.object_type) + trace_object_id = parent_components.object_id or "" + else: + trace_object_type = "experiment" + trace_object_id = experiment_id or "" + + trace = LocalTrace( + object_type=trace_object_type, + object_id=trace_object_id, + root_span_id=root_span_id_value, + ensure_spans_flushed=ensure_spans_flushed, + state=trace_state, + ) + score_promises = [ asyncio.create_task( await_or_run_scorer( @@ -1426,6 +1522,7 @@ def report_progress(event: TaskProgressEvent): "expected": datum.expected, "metadata": metadata, "output": output, + "trace": trace, }, ) ) diff --git a/py/src/braintrust/functions/invoke.py b/py/src/braintrust/functions/invoke.py index 3aecc3a73..5c566c3f9 100644 --- a/py/src/braintrust/functions/invoke.py +++ b/py/src/braintrust/functions/invoke.py @@ -3,7 +3,7 @@ from sseclient import SSEClient from .._generated_types import FunctionTypeEnum -from ..logger import Exportable, get_span_parent_object, login, proxy_conn +from ..logger import Exportable, _internal_get_global_state, get_span_parent_object, login, proxy_conn from ..util import response_raise_for_status from .constants import INVOKE_API_VERSION from .stream import BraintrustInvokeError, BraintrustStream @@ -243,6 +243,8 @@ def init_function(project_name: str, slug: str, version: str | None = None): :param version: Optional version of the function to use. Defaults to latest. :return: A function that can be used as a task or scorer. 
""" + # Disable span cache since remote function spans won't be in the local cache + _internal_get_global_state().span_cache.disable() def f(*args: Any, **kwargs: Any) -> Any: if len(args) > 0: diff --git a/py/src/braintrust/functions/test_invoke.py b/py/src/braintrust/functions/test_invoke.py new file mode 100644 index 000000000..c38e2e105 --- /dev/null +++ b/py/src/braintrust/functions/test_invoke.py @@ -0,0 +1,61 @@ +"""Tests for the invoke module, particularly init_function.""" + + +from braintrust.functions.invoke import init_function +from braintrust.logger import _internal_get_global_state, _internal_reset_global_state + + +class TestInitFunction: + """Tests for init_function.""" + + def setup_method(self): + """Reset state before each test.""" + _internal_reset_global_state() + + def teardown_method(self): + """Clean up after each test.""" + _internal_reset_global_state() + + def test_init_function_disables_span_cache(self): + """Test that init_function disables the span cache.""" + state = _internal_get_global_state() + + # Cache should be disabled by default (it's only enabled during evals) + assert state.span_cache.disabled is True + + # Enable the cache (simulating what happens during eval) + state.span_cache.start() + assert state.span_cache.disabled is False + + # Call init_function + f = init_function("test-project", "test-function") + + # Cache should now be disabled (init_function explicitly disables it) + assert state.span_cache.disabled is True + assert f.__name__ == "init_function-test-project-test-function-latest" + + def test_init_function_with_version(self): + """Test that init_function creates a function with the correct name including version.""" + f = init_function("my-project", "my-scorer", version="v1") + assert f.__name__ == "init_function-my-project-my-scorer-v1" + + def test_init_function_without_version_uses_latest(self): + """Test that init_function uses 'latest' in name when version not specified.""" + f = 
init_function("my-project", "my-scorer") + assert f.__name__ == "init_function-my-project-my-scorer-latest" + + def test_init_function_permanently_disables_cache(self): + """Test that init_function permanently disables the cache (can't be re-enabled).""" + state = _internal_get_global_state() + + # Enable the cache + state.span_cache.start() + assert state.span_cache.disabled is False + + # Call init_function + init_function("test-project", "test-function") + assert state.span_cache.disabled is True + + # Try to start again - should still be disabled because of explicit disable + state.span_cache.start() + assert state.span_cache.disabled is True diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 682298a5a..a717dc5d4 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -401,6 +401,11 @@ def default_get_api_conn(): ), ) + from braintrust.span_cache import SpanCache + + self.span_cache = SpanCache() + self._otel_flush_callback: Any | None = None + def reset_login_info(self): self.app_url: str | None = None self.app_public_url: str | None = None @@ -457,6 +462,21 @@ def context_manager(self): return self._context_manager + def register_otel_flush(self, callback: Any) -> None: + """ + Register an OTEL flush callback. This is called by the OTEL integration + when it initializes a span processor/exporter. + """ + self._otel_flush_callback = callback + + async def flush_otel(self) -> None: + """ + Flush OTEL spans if a callback is registered. + Called during ensure_spans_flushed to ensure OTEL spans are visible in BTQL. 
+ """ + if self._otel_flush_callback: + await self._otel_flush_callback() + def copy_state(self, other: "BraintrustState"): """Copy login information from another BraintrustState instance.""" self.__dict__.update({ @@ -1777,6 +1797,25 @@ def login( _state.login(app_url=app_url, api_key=api_key, org_name=org_name, force_login=force_login) +def register_otel_flush(callback: Any) -> None: + """ + Register a callback to flush OTEL spans. This is called by the OTEL integration + when it initializes a span processor/exporter. + + When ensure_spans_flushed is called (e.g., before a BTQL query in scorers), + this callback will be invoked to ensure OTEL spans are flushed to the server. + + Also disables the span cache, since OTEL spans aren't in the local cache + and we need BTQL to see the complete span tree (both native + OTEL spans). + + :param callback: The async callback function to flush OTEL spans. + """ + global _state + _state.register_otel_flush(callback) + # Disable span cache since OTEL spans aren't in the local cache + _state.span_cache.disable() + + def login_to_state( app_url: str | None = None, api_key: str | None = None, @@ -3847,6 +3886,21 @@ def log_internal(self, event: dict[str, Any] | None = None, internal_data: dict[ if serializable_partial_record.get("metrics", {}).get("end") is not None: self._logged_end_time = serializable_partial_record["metrics"]["end"] + # Write to local span cache for scorer access + # Only cache experiment spans - regular logs don't need caching + if self.parent_object_type == SpanObjectTypeV3.EXPERIMENT: + from braintrust.span_cache import CachedSpan + + cached_span = CachedSpan( + span_id=self.span_id, + input=serializable_partial_record.get("input"), + output=serializable_partial_record.get("output"), + metadata=serializable_partial_record.get("metadata"), + span_parents=self.span_parents, + span_attributes=serializable_partial_record.get("span_attributes"), + ) + self.state.span_cache.queue_write(self.root_span_id, 
self.span_id, cached_span) + def compute_record() -> dict[str, Any]: exporter = _get_exporter() return dict( diff --git a/py/src/braintrust/span_cache.py b/py/src/braintrust/span_cache.py new file mode 100644 index 000000000..17148cdec --- /dev/null +++ b/py/src/braintrust/span_cache.py @@ -0,0 +1,337 @@ +""" +SpanCache provides a disk-based cache for span data, allowing +scorers to read spans without making server round-trips when possible. + +Spans are stored on disk to minimize memory usage during evaluations. +The cache file is automatically cleaned up when dispose() is called. +""" + +import atexit +import json +import os +import tempfile +import uuid +from typing import Any, Optional + +from braintrust.util import merge_dicts + +# Global registry of active span caches for process exit cleanup +_active_caches: set["SpanCache"] = set() +_exit_handlers_registered = False + + +class CachedSpan: + """Cached span data structure.""" + + def __init__( + self, + span_id: str, + input: Optional[Any] = None, + output: Optional[Any] = None, + metadata: Optional[dict[str, Any]] = None, + span_parents: Optional[list[str]] = None, + span_attributes: Optional[dict[str, Any]] = None, + ): + self.span_id = span_id + self.input = input + self.output = output + self.metadata = metadata + self.span_parents = span_parents + self.span_attributes = span_attributes + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + result = {"span_id": self.span_id} + if self.input is not None: + result["input"] = self.input + if self.output is not None: + result["output"] = self.output + if self.metadata is not None: + result["metadata"] = self.metadata + if self.span_parents is not None: + result["span_parents"] = self.span_parents + if self.span_attributes is not None: + result["span_attributes"] = self.span_attributes + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "CachedSpan": + """Create from dictionary.""" + return cls( + 
span_id=data["span_id"], + input=data.get("input"), + output=data.get("output"), + metadata=data.get("metadata"), + span_parents=data.get("span_parents"), + span_attributes=data.get("span_attributes"), + ) + + +class DiskSpanRecord: + """Record structure for disk storage.""" + + def __init__(self, root_span_id: str, span_id: str, data: CachedSpan): + self.root_span_id = root_span_id + self.span_id = span_id + self.data = data + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "rootSpanId": self.root_span_id, + "spanId": self.span_id, + "data": self.data.to_dict(), + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DiskSpanRecord": + """Create from dictionary.""" + return cls( + root_span_id=data["rootSpanId"], + span_id=data["spanId"], + data=CachedSpan.from_dict(data["data"]), + ) + + +class SpanCache: + """ + Disk-based cache for span data, keyed by rootSpanId. + + This cache writes spans to a temporary file to minimize memory usage. + It uses append-only writes and reads the full file when querying. + """ + + def __init__(self, disabled: bool = False): + self._cache_file_path: Optional[str] = None + self._initialized = False + # Tracks whether the cache was explicitly disabled (via constructor or disable()) + self._explicitly_disabled = disabled + # Tracks whether the cache has been enabled (for evals only) + self._enabled = False + # Reference count of active evals using this cache + self._active_eval_count = 0 + # Small in-memory index tracking which rootSpanIds have data + self._root_span_index: set[str] = set() + # Buffer for pending writes + self._write_buffer: list[DiskSpanRecord] = [] + + def disable(self) -> None: + """ + Disable the cache at runtime. This is called automatically when + OTEL is registered, since OTEL spans won't be in the cache. + """ + self._explicitly_disabled = True + + def start(self) -> None: + """ + Start caching spans for use during evaluations. 
+ This only starts caching if the cache wasn't permanently disabled. + Called by Eval() to turn on caching for the duration of the eval. + Uses reference counting to support parallel evals. + """ + if not self._explicitly_disabled: + self._enabled = True + self._active_eval_count += 1 + + def stop(self) -> None: + """ + Stop caching spans and return to the default disabled state. + Unlike disable(), this allows start() to work again for future evals. + Called after an eval completes to return to the default state. + Uses reference counting - only disables when all evals are complete. + """ + self._active_eval_count -= 1 + if self._active_eval_count <= 0: + self._active_eval_count = 0 + self._enabled = False + + @property + def disabled(self) -> bool: + """Check if cache is disabled.""" + return self._explicitly_disabled or not self._enabled + + def _ensure_initialized(self) -> None: + """Initialize the cache file if not already done.""" + if self.disabled or self._initialized: + return + + try: + # Create temporary file + unique_id = f"{int(os.times().elapsed * 1000000)}-{uuid.uuid4().hex[:8]}" + self._cache_file_path = os.path.join(tempfile.gettempdir(), f"braintrust-span-cache-{unique_id}.jsonl") + + # Create the file + with open(self._cache_file_path, "w") as f: + pass + + self._initialized = True + self._register_exit_handler() + except Exception: + # Silently fail if filesystem is unavailable - cache is best-effort + # This can happen if temp directory is not writable or disk is full + self._explicitly_disabled = True + return + + def _register_exit_handler(self) -> None: + """Register a handler to clean up the temp file on process exit.""" + global _exit_handlers_registered + _active_caches.add(self) + + if not _exit_handlers_registered: + _exit_handlers_registered = True + + def cleanup_all_caches(): + """Clean up all active caches.""" + for cache in _active_caches: + if cache._cache_file_path and os.path.exists(cache._cache_file_path): + try: + 
os.unlink(cache._cache_file_path) + except Exception: + # Ignore cleanup errors - file might not exist or already deleted + pass + + atexit.register(cleanup_all_caches) + + def queue_write(self, root_span_id: str, span_id: str, data: CachedSpan) -> None: + """ + Write a span to the cache. + In Python, we write synchronously (no async queue like in TS). + """ + if self.disabled: + return + + self._ensure_initialized() + + record = DiskSpanRecord(root_span_id, span_id, data) + self._write_buffer.append(record) + self._root_span_index.add(root_span_id) + + # Write to disk immediately (simplified compared to TS async version) + self._flush_write_buffer() + + def _flush_write_buffer(self) -> None: + """Flush the write buffer to disk.""" + if not self._write_buffer or not self._cache_file_path: + return + + try: + with open(self._cache_file_path, "a") as f: + for record in self._write_buffer: + f.write(json.dumps(record.to_dict()) + "\n") + self._write_buffer.clear() + except Exception: + # Silently fail if write fails - cache is best-effort + # This can happen if disk is full or file permissions changed + pass + + def get_by_root_span_id(self, root_span_id: str) -> Optional[list[CachedSpan]]: + """ + Get all cached spans for a given rootSpanId. + + This reads the file and merges all records for the given rootSpanId. 
+ + Args: + root_span_id: The root span ID to look up + + Returns: + List of cached spans, or None if not in cache + """ + if self.disabled: + return None + + # Quick check using in-memory index + if root_span_id not in self._root_span_index: + return None + + # Accumulate spans by spanId, merging updates + span_map: dict[str, dict[str, Any]] = {} + + # Read from disk if initialized + if self._initialized and self._cache_file_path and os.path.exists(self._cache_file_path): + try: + with open(self._cache_file_path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + record_dict = json.loads(line) + record = DiskSpanRecord.from_dict(record_dict) + if record.root_span_id != root_span_id: + continue + + if record.span_id in span_map: + merge_dicts(span_map[record.span_id], record.data.to_dict()) + else: + span_map[record.span_id] = record.data.to_dict() + except Exception: + # Skip malformed lines - may occur if file was corrupted or truncated + pass + except Exception: + # Continue to check buffer even if disk read fails + # This can happen if file was deleted or permissions changed + pass + + # Also check the in-memory write buffer for unflushed data + for record in self._write_buffer: + if record.root_span_id != root_span_id: + continue + if record.span_id in span_map: + merge_dicts(span_map[record.span_id], record.data.to_dict()) + else: + span_map[record.span_id] = record.data.to_dict() + + if not span_map: + return None + + return [CachedSpan.from_dict(data) for data in span_map.values()] + + def has(self, root_span_id: str) -> bool: + """Check if a rootSpanId has cached data.""" + if self.disabled: + return False + return root_span_id in self._root_span_index + + def clear(self, root_span_id: str) -> None: + """ + Clear all cached spans for a given rootSpanId. + Note: This only removes from the index. The data remains in the file + but will be ignored on reads. 
+ """ + self._root_span_index.discard(root_span_id) + + def clear_all(self) -> None: + """Clear all cached data and remove the cache file.""" + self._root_span_index.clear() + self.dispose() + + @property + def size(self) -> int: + """Get the number of root spans currently tracked.""" + return len(self._root_span_index) + + def dispose(self) -> None: + """ + Clean up the cache file. Call this when the eval is complete. + Only performs cleanup when all active evals have completed (refcount = 0). + """ + # Only dispose if no active evals are using this cache + if self._active_eval_count > 0: + return + + # Remove from global registry + _active_caches.discard(self) + + # Clear pending writes + self._write_buffer.clear() + + if self._cache_file_path and os.path.exists(self._cache_file_path): + try: + os.unlink(self._cache_file_path) + except Exception: + # Ignore cleanup errors - file might not exist or already deleted + pass + self._cache_file_path = None + + self._initialized = False + self._root_span_index.clear() diff --git a/py/src/braintrust/span_identifier_v3.py b/py/src/braintrust/span_identifier_v3.py index ea850dcbc..d86903153 100644 --- a/py/src/braintrust/span_identifier_v3.py +++ b/py/src/braintrust/span_identifier_v3.py @@ -38,6 +38,27 @@ def __str__(self): }[self] +def span_object_type_v3_to_typed_string( + object_type: SpanObjectTypeV3, +) -> str: + """Convert SpanObjectTypeV3 enum to typed string literal. 
+ + Args: + object_type: The SpanObjectTypeV3 enum value + + Returns: + One of "experiment", "project_logs", or "playground_logs" + """ + if object_type == SpanObjectTypeV3.EXPERIMENT: + return "experiment" + elif object_type == SpanObjectTypeV3.PROJECT_LOGS: + return "project_logs" + elif object_type == SpanObjectTypeV3.PLAYGROUND_LOGS: + return "playground_logs" + else: + raise ValueError(f"Unknown SpanObjectTypeV3: {object_type}") + + class InternalSpanComponentUUIDFields(Enum): OBJECT_ID = 1 ROW_ID = 2 diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 634cf1261..afac5c199 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -2461,6 +2461,95 @@ def test_logger_export_respects_otel_compat_enabled(): assert version == 4, f"Expected V4 encoding (version=4), got version={version}" +def test_register_otel_flush_callback(): + """Test that register_otel_flush registers a callback correctly.""" + import asyncio + + from braintrust import register_otel_flush + from braintrust.logger import _internal_get_global_state + from braintrust.test_helpers import init_test_logger + + init_test_logger(__name__) + state = _internal_get_global_state() + + # Track if callback was invoked + callback_invoked = False + + async def mock_flush(): + nonlocal callback_invoked + callback_invoked = True + + # Register the callback + register_otel_flush(mock_flush) + + # Calling flush_otel should invoke the registered callback + asyncio.run(state.flush_otel()) + + assert callback_invoked is True + + +def test_register_otel_flush_disables_span_cache(): + """Test that register_otel_flush disables the span cache.""" + from braintrust import register_otel_flush + from braintrust.logger import _internal_get_global_state + from braintrust.test_helpers import init_test_logger + + init_test_logger(__name__) + state = _internal_get_global_state() + + # Enable the cache (simulating what happens during eval) + state.span_cache.start() 
+ assert state.span_cache.disabled is False + + async def mock_flush(): + pass + + # Register OTEL flush + register_otel_flush(mock_flush) + + # Cache should now be disabled + assert state.span_cache.disabled is True + + +def test_flush_otel_noop_when_no_callback(): + """Test that flush_otel is a no-op when no callback is registered.""" + import asyncio + + from braintrust.logger import _internal_get_global_state + from braintrust.test_helpers import init_test_logger + + init_test_logger(__name__) + state = _internal_get_global_state() + + # Should not throw even with no callback registered + asyncio.run(state.flush_otel()) + + +def test_register_otel_flush_permanently_disables_cache(): + """Test that register_otel_flush permanently disables the cache.""" + from braintrust import register_otel_flush + from braintrust.logger import _internal_get_global_state + from braintrust.test_helpers import init_test_logger + + init_test_logger(__name__) + state = _internal_get_global_state() + + # Enable the cache + state.span_cache.start() + assert state.span_cache.disabled is False + + async def mock_flush(): + pass + + # Register OTEL flush + register_otel_flush(mock_flush) + assert state.span_cache.disabled is True + + # Try to start again - should still be disabled because of explicit disable + state.span_cache.start() + assert state.span_cache.disabled is True + + class TestJSONAttachment(TestCase): def test_create_attachment_from_json_data(self): """Test creating an attachment from JSON data.""" diff --git a/py/src/braintrust/test_span_cache.py b/py/src/braintrust/test_span_cache.py new file mode 100644 index 000000000..fc0b6c7ef --- /dev/null +++ b/py/src/braintrust/test_span_cache.py @@ -0,0 +1,344 @@ +"""Tests for SpanCache (disk-based cache).""" + + +from braintrust.span_cache import CachedSpan, SpanCache + + +def test_span_cache_write_and_read(): + """Test storing and retrieving spans by rootSpanId.""" + cache = SpanCache() + cache.start() # Start for testing 
(cache is disabled by default) + + root_span_id = "root-123" + span1 = CachedSpan( + span_id="span-1", + input={"text": "hello"}, + output={"response": "world"}, + ) + span2 = CachedSpan( + span_id="span-2", + input={"text": "foo"}, + output={"response": "bar"}, + ) + + cache.queue_write(root_span_id, span1.span_id, span1) + cache.queue_write(root_span_id, span2.span_id, span2) + + spans = cache.get_by_root_span_id(root_span_id) + assert spans is not None + assert len(spans) == 2 + + span_ids = {s.span_id for s in spans} + assert "span-1" in span_ids + assert "span-2" in span_ids + + cache.stop() + cache.dispose() + + +def test_span_cache_return_none_for_unknown(): + """Test that unknown rootSpanId returns None.""" + cache = SpanCache() + cache.start() + + spans = cache.get_by_root_span_id("nonexistent") + assert spans is None + + cache.stop() + cache.dispose() + + +def test_span_cache_merge_on_duplicate_writes(): + """Test that subsequent writes to same spanId merge data.""" + cache = SpanCache() + cache.start() + + root_span_id = "root-123" + span_id = "span-1" + + cache.queue_write( + root_span_id, + span_id, + CachedSpan(span_id=span_id, input={"text": "hello"}), + ) + + cache.queue_write( + root_span_id, + span_id, + CachedSpan(span_id=span_id, output={"response": "world"}), + ) + + spans = cache.get_by_root_span_id(root_span_id) + assert spans is not None + assert len(spans) == 1 + assert spans[0].span_id == span_id + assert spans[0].input == {"text": "hello"} + assert spans[0].output == {"response": "world"} + + cache.stop() + cache.dispose() + + +def test_span_cache_merge_metadata(): + """Test that metadata objects are merged.""" + cache = SpanCache() + cache.start() + + root_span_id = "root-123" + span_id = "span-1" + + cache.queue_write( + root_span_id, + span_id, + CachedSpan(span_id=span_id, metadata={"key1": "value1"}), + ) + + cache.queue_write( + root_span_id, + span_id, + CachedSpan(span_id=span_id, metadata={"key2": "value2"}), + ) + + spans = 
cache.get_by_root_span_id(root_span_id) + assert spans is not None + assert spans[0].metadata == {"key1": "value1", "key2": "value2"} + + cache.stop() + cache.dispose() + + +def test_span_cache_has(): + """Test the has() method.""" + cache = SpanCache() + cache.start() + + cache.queue_write("root-123", "span-1", CachedSpan(span_id="span-1")) + assert cache.has("root-123") is True + assert cache.has("nonexistent") is False + + cache.stop() + cache.dispose() + + +def test_span_cache_clear(): + """Test clearing spans for a specific rootSpanId.""" + cache = SpanCache() + cache.start() + + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + cache.queue_write("root-2", "span-2", CachedSpan(span_id="span-2")) + + cache.clear("root-1") + + assert cache.has("root-1") is False + assert cache.has("root-2") is True + + cache.stop() + cache.dispose() + + +def test_span_cache_clear_all(): + """Test clearing all cached spans.""" + cache = SpanCache() + cache.start() + + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + cache.queue_write("root-2", "span-2", CachedSpan(span_id="span-2")) + + cache.clear_all() + + assert cache.size == 0 + + cache.stop() + cache.dispose() + + +def test_span_cache_size(): + """Test the size property.""" + cache = SpanCache() + cache.start() + + assert cache.size == 0 + + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + assert cache.size == 1 + + cache.queue_write("root-1", "span-2", CachedSpan(span_id="span-2")) # Same root + assert cache.size == 1 + + cache.queue_write("root-2", "span-3", CachedSpan(span_id="span-3")) # Different root + assert cache.size == 2 + + cache.stop() + cache.dispose() + + +def test_span_cache_dispose(): + """Test that dispose cleans up and allows reuse.""" + cache = SpanCache() + cache.start() + + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + assert cache.size == 1 + + # Stop first to decrement refcount, then dispose + cache.stop() + 
cache.dispose() + + assert cache.size == 0 + assert cache.has("root-1") is False + + # Should be able to write again after dispose (if we start again) + cache.start() + cache.queue_write("root-2", "span-2", CachedSpan(span_id="span-2")) + assert cache.size == 1 + + cache.stop() + cache.dispose() + + +def test_span_cache_disable(): + """Test that disable() prevents writes.""" + cache = SpanCache() + cache.start() + + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + assert cache.size == 1 + + cache.disable() + + # Writes after disable should be no-ops + cache.queue_write("root-2", "span-2", CachedSpan(span_id="span-2")) + assert cache.size == 1 # Still 1, not 2 + + cache.stop() + cache.dispose() + + +def test_span_cache_disabled_getter(): + """Test the disabled property.""" + # Cache is disabled by default until start() is called + cache = SpanCache() + assert cache.disabled is True + + cache.start() + assert cache.disabled is False + + cache.disable() + assert cache.disabled is True + + cache.dispose() + + +def test_span_cache_disabled_from_constructor(): + """Test that cache can be disabled via constructor.""" + cache = SpanCache(disabled=True) + assert cache.disabled is True + + # Writes should be no-ops + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + assert cache.size == 0 + assert cache.get_by_root_span_id("root-1") is None + + cache.dispose() + + +def test_span_cache_start_stop_lifecycle(): + """Test that stop() allows start() to work again.""" + cache = SpanCache() + + # Initially disabled by default + assert cache.disabled is True + + # Start for first "eval" + cache.start() + assert cache.disabled is False + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + assert cache.size == 1 + + # Stop after first "eval" + cache.stop() + cache.dispose() + assert cache.disabled is True + + # Start for second "eval" - should work! 
+ cache.start() + assert cache.disabled is False + cache.queue_write("root-2", "span-2", CachedSpan(span_id="span-2")) + assert cache.size == 1 + + cache.stop() + cache.dispose() + + +def test_span_cache_disable_prevents_start(): + """Test that disable() prevents start() from working.""" + cache = SpanCache() + + # Simulate disable being called + cache.disable() + assert cache.disabled is True + + # start() should be a no-op after disable() + cache.start() + assert cache.disabled is True + + # Writes should still be no-ops + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + assert cache.size == 0 + + cache.dispose() + + +def test_span_cache_parallel_eval_refcount(): + """Test reference counting for parallel evals.""" + cache = SpanCache() + + # Simulate two evals starting + cache.start() # Eval 1 + assert cache.disabled is False + + cache.start() # Eval 2 + assert cache.disabled is False + + # Write data from both evals + cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + cache.queue_write("root-2", "span-2", CachedSpan(span_id="span-2")) + assert cache.size == 2 + + # Eval 1 finishes first + cache.dispose() # Should NOT dispose (refcount = 2) + cache.stop() # Decrements to 1 + + # Cache should still be enabled and data intact + assert cache.disabled is False + assert cache.size == 2 + assert cache.get_by_root_span_id("root-1") is not None + assert cache.get_by_root_span_id("root-2") is not None + + # Eval 2 finishes + cache.dispose() # Should NOT dispose yet (refcount = 1) + cache.stop() # Decrements to 0, disables cache + + # Now cache should be disabled + assert cache.disabled is True + + # Final dispose should now work + cache.dispose() # NOW it disposes (refcount = 0) + assert cache.size == 0 + + +def test_span_cache_refcount_underflow(): + """Test that refcount handles underflow gracefully.""" + cache = SpanCache() + + # Call stop without start + cache.stop() + + # Should work normally after + cache.start() + 
cache.queue_write("root-1", "span-1", CachedSpan(span_id="span-1")) + assert cache.size == 1 + + cache.stop() + cache.dispose() diff --git a/py/src/braintrust/test_trace.py b/py/src/braintrust/test_trace.py new file mode 100644 index 000000000..e2de657d6 --- /dev/null +++ b/py/src/braintrust/test_trace.py @@ -0,0 +1,267 @@ +"""Tests for Trace functionality.""" + +import pytest +from braintrust.trace import CachedSpanFetcher, SpanData + + +# Helper to create mock spans +def make_span(span_id: str, span_type: str, **extra) -> SpanData: + return SpanData( + span_id=span_id, + input={"text": f"input-{span_id}"}, + output={"text": f"output-{span_id}"}, + span_attributes={"type": span_type}, + **extra, + ) + + +class TestCachedSpanFetcher: + """Test CachedSpanFetcher caching behavior.""" + + @pytest.mark.asyncio + async def test_fetch_all_spans_without_filter(self): + """Test fetching all spans when no filter specified.""" + mock_spans = [ + make_span("span-1", "llm"), + make_span("span-2", "function"), + make_span("span-3", "llm"), + ] + + call_count = 0 + + async def fetch_fn(span_type): + nonlocal call_count + call_count += 1 + return mock_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + result = await fetcher.get_spans() + + assert call_count == 1 + assert len(result) == 3 + assert {s.span_id for s in result} == {"span-1", "span-2", "span-3"} + + @pytest.mark.asyncio + async def test_fetch_specific_span_types(self): + """Test fetching specific span types when filter specified.""" + llm_spans = [make_span("span-1", "llm"), make_span("span-2", "llm")] + + call_count = 0 + + async def fetch_fn(span_type): + nonlocal call_count + call_count += 1 + assert span_type == ["llm"] + return llm_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + result = await fetcher.get_spans(span_type=["llm"]) + + assert call_count == 1 + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_return_cached_spans_after_fetching_all(self): + """Test that cached 
spans are returned without re-fetching after fetching all.""" + mock_spans = [ + make_span("span-1", "llm"), + make_span("span-2", "function"), + ] + + call_count = 0 + + async def fetch_fn(span_type): + nonlocal call_count + call_count += 1 + return mock_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + + # First call - fetches + await fetcher.get_spans() + assert call_count == 1 + + # Second call - should use cache + result = await fetcher.get_spans() + assert call_count == 1 # Still 1 + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_return_cached_spans_for_previously_fetched_types(self): + """Test that previously fetched types are returned from cache.""" + llm_spans = [make_span("span-1", "llm"), make_span("span-2", "llm")] + + call_count = 0 + + async def fetch_fn(span_type): + nonlocal call_count + call_count += 1 + return llm_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + + # First call - fetches llm spans + await fetcher.get_spans(span_type=["llm"]) + assert call_count == 1 + + # Second call for same type - should use cache + result = await fetcher.get_spans(span_type=["llm"]) + assert call_count == 1 # Still 1 + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_only_fetch_missing_span_types(self): + """Test that only missing span types are fetched.""" + llm_spans = [make_span("span-1", "llm")] + function_spans = [make_span("span-2", "function")] + + call_count = 0 + + async def fetch_fn(span_type): + nonlocal call_count + call_count += 1 + if span_type == ["llm"]: + return llm_spans + elif span_type == ["function"]: + return function_spans + return [] + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + + # First call - fetches llm spans + await fetcher.get_spans(span_type=["llm"]) + assert call_count == 1 + + # Second call for both types - should only fetch function + result = await fetcher.get_spans(span_type=["llm", "function"]) + assert call_count == 2 + assert len(result) == 2 + + 
@pytest.mark.asyncio + async def test_no_refetch_after_fetching_all_spans(self): + """Test that no re-fetching occurs after fetching all spans.""" + all_spans = [ + make_span("span-1", "llm"), + make_span("span-2", "function"), + make_span("span-3", "tool"), + ] + + call_count = 0 + + async def fetch_fn(span_type): + nonlocal call_count + call_count += 1 + return all_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + + # Fetch all spans + await fetcher.get_spans() + assert call_count == 1 + + # Subsequent filtered calls should use cache + llm_result = await fetcher.get_spans(span_type=["llm"]) + assert call_count == 1 # Still 1 + assert len(llm_result) == 1 + assert llm_result[0].span_id == "span-1" + + function_result = await fetcher.get_spans(span_type=["function"]) + assert call_count == 1 # Still 1 + assert len(function_result) == 1 + assert function_result[0].span_id == "span-2" + + @pytest.mark.asyncio + async def test_filter_by_multiple_span_types_from_cache(self): + """Test filtering by multiple span types from cache.""" + all_spans = [ + make_span("span-1", "llm"), + make_span("span-2", "function"), + make_span("span-3", "tool"), + make_span("span-4", "llm"), + ] + + async def fetch_fn(span_type): + return all_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + + # Fetch all first + await fetcher.get_spans() + + # Filter for llm and tool + result = await fetcher.get_spans(span_type=["llm", "tool"]) + assert len(result) == 3 + assert {s.span_id for s in result} == {"span-1", "span-3", "span-4"} + + @pytest.mark.asyncio + async def test_return_empty_for_nonexistent_span_type(self): + """Test that empty array is returned for non-existent span type.""" + all_spans = [make_span("span-1", "llm")] + + async def fetch_fn(span_type): + return all_spans + + fetcher = CachedSpanFetcher(fetch_fn=fetch_fn) + + # Fetch all first + await fetcher.get_spans() + + # Query for non-existent type + result = await fetcher.get_spans(span_type=["nonexistent"]) + 
assert len(result) == 0
+
+    @pytest.mark.asyncio
+    async def test_handle_spans_with_no_type(self):
+        """Test handling spans without type (empty string type)."""
+        spans = [
+            make_span("span-1", "llm"),
+            SpanData(span_id="span-2", input={}, span_attributes={}),  # No type
+            SpanData(span_id="span-3", input={}),  # No span_attributes
+        ]
+
+        async def fetch_fn(span_type):
+            return spans
+
+        fetcher = CachedSpanFetcher(fetch_fn=fetch_fn)
+
+        # Fetch all
+        result = await fetcher.get_spans()
+        assert len(result) == 3
+
+        # Spans without type go into "" bucket
+        no_type_result = await fetcher.get_spans(span_type=[""])
+        assert len(no_type_result) == 2
+
+    @pytest.mark.asyncio
+    async def test_handle_empty_results(self):
+        """Test handling empty results."""
+
+        async def fetch_fn(span_type):
+            return []
+
+        fetcher = CachedSpanFetcher(fetch_fn=fetch_fn)
+
+        result = await fetcher.get_spans()
+        assert len(result) == 0
+
+        # Should still mark as fetched
+        await fetcher.get_spans(span_type=["llm"])
+        # No additional assertions, just making sure it doesn't crash
+
+    @pytest.mark.asyncio
+    async def test_handle_empty_span_type_array(self):
+        """Test that an empty span_type list is handled the same as None."""
+        mock_spans = [make_span("span-1", "llm")]
+
+        call_args = []
+
+        async def fetch_fn(span_type):
+            call_args.append(span_type)
+            return mock_spans
+
+        fetcher = CachedSpanFetcher(fetch_fn=fetch_fn)
+
+        result = await fetcher.get_spans(span_type=[])
+
+        assert call_args[0] is None or call_args[0] == []
+        assert len(result) == 1
diff --git a/py/src/braintrust/trace.py b/py/src/braintrust/trace.py
new file mode 100644
index 000000000..57344ee16
--- /dev/null
+++ b/py/src/braintrust/trace.py
@@ -0,0 +1,385 @@
+"""
+Trace objects for accessing spans in evaluations.

+This module provides the LocalTrace class which allows scorers to access
+spans from the current evaluation task without making server round-trips
+""" + +import asyncio +from typing import Any, Awaitable, Callable, Optional, Protocol + +from braintrust.logger import BraintrustState, ObjectFetcher + + +class SpanData: + """Span data returned by get_spans().""" + + def __init__( + self, + input: Optional[Any] = None, + output: Optional[Any] = None, + metadata: Optional[dict[str, Any]] = None, + span_id: Optional[str] = None, + span_parents: Optional[list[str]] = None, + span_attributes: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + self.input = input + self.output = output + self.metadata = metadata + self.span_id = span_id + self.span_parents = span_parents + self.span_attributes = span_attributes + # Store any additional fields + for key, value in kwargs.items(): + setattr(self, key, value) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SpanData": + """Create SpanData from a dictionary.""" + return cls(**data) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + result = {} + for key, value in self.__dict__.items(): + if value is not None: + result[key] = value + return result + + +class SpanFetcher(ObjectFetcher[dict[str, Any]]): + """ + Fetcher for spans by root_span_id, using the ObjectFetcher pattern. + Handles pagination automatically via cursor-based iteration. 
+ """ + + def __init__( + self, + object_type: str, # Literal["experiment", "project_logs", "playground_logs"] + object_id: str, + root_span_id: str, + state: BraintrustState, + span_type_filter: Optional[list[str]] = None, + ): + # Build the filter expression for root_span_id and optionally span_attributes.type + filter_expr = self._build_filter(root_span_id, span_type_filter) + + super().__init__( + object_type=object_type, + _internal_btql={"filter": filter_expr}, + ) + self._object_id = object_id + self._state = state + + @staticmethod + def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = None) -> dict[str, Any]: + """Build BTQL filter expression.""" + children = [ + # Base filter: root_span_id = 'value' + { + "op": "eq", + "left": {"op": "ident", "name": ["root_span_id"]}, + "right": {"op": "literal", "value": root_span_id}, + }, + # Exclude span_attributes.purpose = 'scorer' + { + "op": "or", + "children": [ + { + "op": "isnull", + "expr": {"op": "ident", "name": ["span_attributes", "purpose"]}, + }, + { + "op": "ne", + "left": {"op": "ident", "name": ["span_attributes", "purpose"]}, + "right": {"op": "literal", "value": "scorer"}, + }, + ], + }, + ] + + # If span type filter specified, add it + if span_type_filter and len(span_type_filter) > 0: + children.append( + { + "op": "in", + "left": {"op": "ident", "name": ["span_attributes", "type"]}, + "right": {"op": "literal", "value": span_type_filter}, + } + ) + + return {"op": "and", "children": children} + + @property + def id(self) -> str: + return self._object_id + + def _get_state(self) -> BraintrustState: + return self._state + + +SpanFetchFn = Callable[[Optional[list[str]]], Awaitable[list[SpanData]]] + + +class CachedSpanFetcher: + """ + Cached span fetcher that handles fetching and caching spans by type. 
+ + Caching strategy: + - Cache spans by span type (dict[spanType, list[SpanData]]) + - Track if all spans have been fetched (all_fetched flag) + - When filtering by spanType, only fetch types not already in cache + """ + + def __init__( + self, + object_type: Optional[str] = None, # Literal["experiment", "project_logs", "playground_logs"] + object_id: Optional[str] = None, + root_span_id: Optional[str] = None, + get_state: Optional[Callable[[], Awaitable[BraintrustState]]] = None, + fetch_fn: Optional[SpanFetchFn] = None, + ): + self._span_cache: dict[str, list[SpanData]] = {} + self._all_fetched = False + + if fetch_fn is not None: + # Direct fetch function injection (for testing) + self._fetch_fn = fetch_fn + else: + # Standard constructor with SpanFetcher + if object_type is None or object_id is None or root_span_id is None or get_state is None: + raise ValueError("Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state") + + async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]: + state = await get_state() + fetcher = SpanFetcher( + object_type=object_type, + object_id=object_id, + root_span_id=root_span_id, + state=state, + span_type_filter=span_type, + ) + rows = list(fetcher.fetch()) + # Filter out scorer spans + filtered = [ + row + for row in rows + if not ( + isinstance(row.get("span_attributes"), dict) + and row.get("span_attributes", {}).get("purpose") == "scorer" + ) + ] + return [ + SpanData( + input=row.get("input"), + output=row.get("output"), + metadata=row.get("metadata"), + span_id=row.get("span_id"), + span_parents=row.get("span_parents"), + span_attributes=row.get("span_attributes"), + id=row.get("id"), + _xact_id=row.get("_xact_id"), + _pagination_key=row.get("_pagination_key"), + root_span_id=row.get("root_span_id"), + ) + for row in filtered + ] + + self._fetch_fn = _fetch_fn + + async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + """ + Get spans, using cache 
when possible. + + Args: + span_type: Optional list of span types to filter by + + Returns: + List of matching spans + """ + # If we've fetched all spans, just filter from cache + if self._all_fetched: + return self._get_from_cache(span_type) + + # If no filter requested, fetch everything + if not span_type or len(span_type) == 0: + await self._fetch_spans(None) + self._all_fetched = True + return self._get_from_cache(None) + + # Find which spanTypes we don't have in cache yet + missing_types = [t for t in span_type if t not in self._span_cache] + + # If all requested types are cached, return from cache + if not missing_types: + return self._get_from_cache(span_type) + + # Fetch only the missing types + await self._fetch_spans(missing_types) + return self._get_from_cache(span_type) + + async def _fetch_spans(self, span_type: Optional[list[str]]) -> None: + """Fetch spans from the server.""" + spans = await self._fetch_fn(span_type) + + for span in spans: + span_attrs = span.span_attributes or {} + span_type_str = span_attrs.get("type", "") + if span_type_str not in self._span_cache: + self._span_cache[span_type_str] = [] + self._span_cache[span_type_str].append(span) + + def _get_from_cache(self, span_type: Optional[list[str]]) -> list[SpanData]: + """Get spans from cache, optionally filtering by type.""" + if not span_type or len(span_type) == 0: + # Return all spans + result = [] + for spans in self._span_cache.values(): + result.extend(spans) + return result + + # Return only requested types + result = [] + for type_str in span_type: + if type_str in self._span_cache: + result.extend(self._span_cache[type_str]) + return result + + +class Trace(Protocol): + """ + Interface for trace objects that can be used by scorers. + Both the SDK's LocalTrace class and the API wrapper's WrapperTrace implement this. + """ + + def get_configuration(self) -> dict[str, str]: + """Get the trace configuration (object_type, object_id, root_span_id).""" + ... 
+ + async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + """ + Fetch all spans for this root span. + + Args: + span_type: Optional list of span types to filter by + + Returns: + List of matching spans + """ + ... + + +class LocalTrace(dict): + """ + SDK implementation of Trace that uses local span cache and falls back to BTQL. + Carries identifying information about the evaluation so scorers can perform + richer logging or side effects. + + Inherits from dict so that it serializes to {"trace_ref": {...}} when passed + to json.dumps(). This allows LocalTrace to be transparently serialized when + passed through invoke() or other JSON-serializing code paths. + """ + + def __init__( + self, + object_type: str, # Literal["experiment", "project_logs", "playground_logs"] + object_id: str, + root_span_id: str, + ensure_spans_flushed: Optional[Callable[[], Awaitable[None]]], + state: BraintrustState, + ): + # Initialize dict with trace_ref for JSON serialization + super().__init__({ + "trace_ref": { + "object_type": object_type, + "object_id": object_id, + "root_span_id": root_span_id, + } + }) + + self._object_type = object_type + self._object_id = object_id + self._root_span_id = root_span_id + self._ensure_spans_flushed = ensure_spans_flushed + self._state = state + self._spans_flushed = False + self._spans_flush_promise: Optional[asyncio.Task[None]] = None + + async def get_state() -> BraintrustState: + await self._ensure_spans_ready() + # Ensure state is logged in + await asyncio.get_event_loop().run_in_executor(None, lambda: state.login()) + return state + + self._cached_fetcher = CachedSpanFetcher( + object_type=object_type, + object_id=object_id, + root_span_id=root_span_id, + get_state=get_state, + ) + + def get_configuration(self) -> dict[str, str]: + """Get the trace configuration.""" + return { + "object_type": self._object_type, + "object_id": self._object_id, + "root_span_id": self._root_span_id, + } + + async def 
get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + """ + Fetch all rows for this root span from its parent object (experiment or project logs). + First checks the local span cache for recently logged spans, then falls + back to CachedSpanFetcher which handles BTQL fetching and caching. + + Args: + span_type: Optional list of span types to filter by + + Returns: + List of matching spans + """ + # Try local span cache first (for recently logged spans not yet flushed) + cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id) + if cached_spans and len(cached_spans) > 0: + # Filter by purpose + spans = [span for span in cached_spans if not (span.span_attributes or {}).get("purpose") == "scorer"] + + # Filter by span type if requested + if span_type and len(span_type) > 0: + spans = [span for span in spans if (span.span_attributes or {}).get("type", "") in span_type] + + # Convert to SpanData + return [ + SpanData( + input=span.input, + output=span.output, + metadata=span.metadata, + span_id=span.span_id, + span_parents=span.span_parents, + span_attributes=span.span_attributes, + ) + for span in spans + ] + + # Fall back to CachedSpanFetcher for BTQL fetching with caching + return await self._cached_fetcher.get_spans(span_type) + + async def _ensure_spans_ready(self) -> None: + """Ensure spans are flushed before fetching.""" + if self._spans_flushed or not self._ensure_spans_flushed: + return + + if self._spans_flush_promise is None: + + async def flush_and_mark(): + try: + await self._ensure_spans_flushed() + self._spans_flushed = True + except Exception as err: + self._spans_flush_promise = None + raise err + + self._spans_flush_promise = asyncio.create_task(flush_and_mark()) + + await self._spans_flush_promise