diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py
index 4aac7dc7d..ef38546f7 100644
--- a/py/src/braintrust/framework.py
+++ b/py/src/braintrust/framework.py
@@ -83,6 +83,9 @@ class EvalCase(SerializableDataClass, Generic[Input, Output]):
     _xact_id: str | None = None
     created: str | None = None
 
+    # The number of times to run the evaluator for this specific input.
+    trial_count: int | None = None
+
 
 class _EvalCaseDictNoOutput(Generic[Input], TypedDict):
     """
@@ -98,6 +101,7 @@
 
     id: NotRequired[str | None]
     _xact_id: NotRequired[str | None]
+    trial_count: NotRequired[int | None]
 
 
 class _EvalCaseDict(Generic[Input, Output], _EvalCaseDictNoOutput[Input]):
@@ -1654,7 +1658,12 @@ async def with_max_concurrency(coro):
         disable=position is None,
     ) as pbar:
         async for datum in pbar:
-            for trial_index in range(evaluator.trial_count):
+            if isinstance(datum, dict):
+                datum_trial_count = datum.get("trial_count")
+            else:
+                datum_trial_count = getattr(datum, "trial_count", None)
+            trial_count = datum_trial_count if datum_trial_count is not None else evaluator.trial_count
+            for trial_index in range(trial_count):
                 tasks.append(asyncio.create_task(with_max_concurrency(run_evaluator_task(datum, trial_index))))
 
     results = []
diff --git a/py/src/braintrust/test_framework.py b/py/src/braintrust/test_framework.py
index 9acf284b0..014317b6d 100644
--- a/py/src/braintrust/test_framework.py
+++ b/py/src/braintrust/test_framework.py
@@ -243,6 +243,131 @@ def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
     assert sorted(input_2_trials) == [0, 1]
 
+
+@pytest.mark.asyncio
+async def test_per_input_trial_count_overrides_global():
+    """Test that per-input trial_count overrides global trial_count."""
+    trial_data: List[tuple] = []  # (input, trial_index)
+
+    def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
+        trial_data.append((input_value, hooks.trial_index))
+        return input_value * 2
+
+    # Create evaluator with mixed trial counts
+    evaluator = Evaluator(
+        project_name="test-project",
+        eval_name="test-per-input-trial-count",
+        data=[
+            EvalCase(input=1, expected=2),  # Uses global trial_count (2)
+            EvalCase(input=2, expected=4, trial_count=5),  # Overrides to 5 trials
+            EvalCase(input=3, expected=6, trial_count=1),  # Overrides to 1 trial
+        ],
+        task=task_with_hooks,
+        scores=[],
+        experiment_name=None,
+        metadata=None,
+        trial_count=2,  # Global default
+    )
+
+    # Run evaluator
+    result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])
+
+    # Should have 8 results total (2 + 5 + 1)
+    assert len(result.results) == 8
+    assert len(trial_data) == 8
+
+    # Input 1: should use global trial_count (2 trials)
+    input_1_trials = [trial_idx for inp, trial_idx in trial_data if inp == 1]
+    assert sorted(input_1_trials) == [0, 1]
+
+    # Input 2: should use per-input trial_count (5 trials)
+    input_2_trials = [trial_idx for inp, trial_idx in trial_data if inp == 2]
+    assert sorted(input_2_trials) == [0, 1, 2, 3, 4]
+
+    # Input 3: should use per-input trial_count (1 trial)
+    input_3_trials = [trial_idx for inp, trial_idx in trial_data if inp == 3]
+    assert sorted(input_3_trials) == [0]
+
+
+@pytest.mark.asyncio
+async def test_per_input_trial_count_without_global():
+    """Test that per-input trial_count works without global trial_count."""
+    trial_data: List[tuple] = []  # (input, trial_index)
+
+    def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
+        trial_data.append((input_value, hooks.trial_index))
+        return input_value * 2
+
+    # Create evaluator with per-input trial counts only (no global)
+    evaluator = Evaluator(
+        project_name="test-project",
+        eval_name="test-per-input-trial-count-no-global",
+        data=[
+            EvalCase(input=1, expected=2),  # No trial_count, defaults to 1
+            EvalCase(input=2, expected=4, trial_count=3),  # Per-input trial_count of 3
+        ],
+        task=task_with_hooks,
+        scores=[],
+        experiment_name=None,
+        metadata=None,
+        # No global trial_count specified (defaults to 1)
+    )
+
+    # Run evaluator
+    result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])
+
+    # Should have 4 results total (1 + 3)
+    assert len(result.results) == 4
+    assert len(trial_data) == 4
+
+    # Input 1: should default to 1 trial
+    input_1_trials = [trial_idx for inp, trial_idx in trial_data if inp == 1]
+    assert sorted(input_1_trials) == [0]
+
+    # Input 2: should use per-input trial_count (3 trials)
+    input_2_trials = [trial_idx for inp, trial_idx in trial_data if inp == 2]
+    assert sorted(input_2_trials) == [0, 1, 2]
+
+
+@pytest.mark.asyncio
+async def test_per_input_trial_count_with_dict_data():
+    """Test that per-input trial_count works when data is passed as dicts."""
+    trial_data: List[tuple] = []  # (input, trial_index)
+
+    def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
+        trial_data.append((input_value, hooks.trial_index))
+        return input_value * 2
+
+    # Create evaluator with dict data (instead of EvalCase)
+    evaluator = Evaluator(
+        project_name="test-project",
+        eval_name="test-per-input-trial-count-dict",
+        data=[
+            {"input": 1, "expected": 2},  # Uses global trial_count (2)
+            {"input": 2, "expected": 4, "trial_count": 4},  # Overrides to 4 trials
+        ],
+        task=task_with_hooks,
+        scores=[],
+        experiment_name=None,
+        metadata=None,
+        trial_count=2,  # Global default
+    )
+
+    # Run evaluator
+    result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])
+
+    # Should have 6 results total (2 + 4)
+    assert len(result.results) == 6
+    assert len(trial_data) == 6
+
+    # Input 1: should use global trial_count (2 trials)
+    input_1_trials = [trial_idx for inp, trial_idx in trial_data if inp == 1]
+    assert sorted(input_1_trials) == [0, 1]
+
+    # Input 2: should use per-input trial_count (4 trials)
+    input_2_trials = [trial_idx for inp, trial_idx in trial_data if inp == 2]
+    assert sorted(input_2_trials) == [0, 1, 2, 3]
+
 
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_scorer_spans_have_purpose_attribute(with_memory_logger, with_simulate_login):