11 changes: 10 additions & 1 deletion py/src/braintrust/framework.py
@@ -83,6 +83,9 @@ class EvalCase(SerializableDataClass, Generic[Input, Output]):
     _xact_id: str | None = None
     created: str | None = None
 
+    # The number of times to run the evaluator for this specific input.
+    trial_count: int | None = None
+
 
 class _EvalCaseDictNoOutput(Generic[Input], TypedDict):
     """
@@ -98,6 +101,7 @@ class _EvalCaseDictNoOutput(Generic[Input], TypedDict):
 
     id: NotRequired[str | None]
     _xact_id: NotRequired[str | None]
+    trial_count: NotRequired[int | None]
 
 
 class _EvalCaseDict(Generic[Input, Output], _EvalCaseDictNoOutput[Input]):
@@ -1654,7 +1658,12 @@ async def with_max_concurrency(coro):
             disable=position is None,
         ) as pbar:
             async for datum in pbar:
-                for trial_index in range(evaluator.trial_count):
+                if isinstance(datum, dict):
+                    datum_trial_count = datum.get("trial_count")
+                else:
+                    datum_trial_count = getattr(datum, "trial_count", None)
+                trial_count = datum_trial_count if datum_trial_count is not None else evaluator.trial_count
+                for trial_index in range(trial_count):
                     tasks.append(asyncio.create_task(with_max_concurrency(run_evaluator_task(datum, trial_index))))
 
     results = []
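The new loop body resolves the effective trial count per datum: a case-level `trial_count` (whether the datum is a plain dict or an `EvalCase`) takes precedence, and the evaluator-wide value is the fallback. A minimal standalone sketch of that precedence, where the helper name `resolve_trial_count` is hypothetical (the PR inlines the logic directly in the loop):

```python
from typing import Any, Optional


def resolve_trial_count(datum: Any, evaluator_trial_count: int) -> int:
    """Per-case trial_count, when present and not None, overrides the
    evaluator-wide default."""
    if isinstance(datum, dict):
        per_case: Optional[int] = datum.get("trial_count")
    else:
        per_case = getattr(datum, "trial_count", None)
    return per_case if per_case is not None else evaluator_trial_count


assert resolve_trial_count({"input": 2, "trial_count": 5}, 2) == 5  # per-case override
assert resolve_trial_count({"input": 1}, 2) == 2                    # falls back to global
```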
125 changes: 125 additions & 0 deletions py/src/braintrust/test_framework.py
@@ -243,6 +243,131 @@ def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
assert sorted(input_2_trials) == [0, 1]


@pytest.mark.asyncio
async def test_per_input_trial_count_overrides_global():
"""Test that per-input trial_count overrides global trial_count."""
trial_data: List[tuple] = [] # (input, trial_index)

def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
trial_data.append((input_value, hooks.trial_index))
return input_value * 2

# Create evaluator with mixed trial counts
evaluator = Evaluator(
project_name="test-project",
eval_name="test-per-input-trial-count",
data=[
EvalCase(input=1, expected=2), # Uses global trial_count (2)
EvalCase(input=2, expected=4, trial_count=5), # Overrides to 5 trials
EvalCase(input=3, expected=6, trial_count=1), # Overrides to 1 trial
],
task=task_with_hooks,
scores=[],
experiment_name=None,
metadata=None,
trial_count=2, # Global default
)

# Run evaluator
result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])

# Should have 8 results total (2 + 5 + 1)
assert len(result.results) == 8
assert len(trial_data) == 8

# Input 1: should use global trial_count (2 trials)
input_1_trials = [trial_idx for inp, trial_idx in trial_data if inp == 1]
assert sorted(input_1_trials) == [0, 1]

# Input 2: should use per-input trial_count (5 trials)
input_2_trials = [trial_idx for inp, trial_idx in trial_data if inp == 2]
assert sorted(input_2_trials) == [0, 1, 2, 3, 4]

# Input 3: should use per-input trial_count (1 trial)
input_3_trials = [trial_idx for inp, trial_idx in trial_data if inp == 3]
assert sorted(input_3_trials) == [0]


@pytest.mark.asyncio
async def test_per_input_trial_count_without_global():
"""Test that per-input trial_count works without global trial_count."""
trial_data: List[tuple] = [] # (input, trial_index)

def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
trial_data.append((input_value, hooks.trial_index))
return input_value * 2

# Create evaluator with per-input trial counts only (no global)
evaluator = Evaluator(
project_name="test-project",
eval_name="test-per-input-trial-count-no-global",
data=[
EvalCase(input=1, expected=2), # No trial_count, defaults to 1
EvalCase(input=2, expected=4, trial_count=3), # Per-input trial_count of 3
],
task=task_with_hooks,
scores=[],
experiment_name=None,
metadata=None,
# No global trial_count specified (defaults to 1)
)

# Run evaluator
result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])

# Should have 4 results total (1 + 3)
assert len(result.results) == 4
assert len(trial_data) == 4

# Input 1: should default to 1 trial
input_1_trials = [trial_idx for inp, trial_idx in trial_data if inp == 1]
assert sorted(input_1_trials) == [0]

# Input 2: should use per-input trial_count (3 trials)
input_2_trials = [trial_idx for inp, trial_idx in trial_data if inp == 2]
assert sorted(input_2_trials) == [0, 1, 2]


@pytest.mark.asyncio
async def test_per_input_trial_count_with_dict_data():
"""Test that per-input trial_count works when data is passed as dicts."""
trial_data: List[tuple] = [] # (input, trial_index)

def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
trial_data.append((input_value, hooks.trial_index))
return input_value * 2

# Create evaluator with dict data (instead of EvalCase)
evaluator = Evaluator(
project_name="test-project",
eval_name="test-per-input-trial-count-dict",
data=[
{"input": 1, "expected": 2}, # Uses global trial_count (2)
{"input": 2, "expected": 4, "trial_count": 4}, # Overrides to 4 trials
],
task=task_with_hooks,
scores=[],
experiment_name=None,
metadata=None,
trial_count=2, # Global default
)

# Run evaluator
result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])

# Should have 6 results total (2 + 4)
assert len(result.results) == 6
assert len(trial_data) == 6

# Input 1: should use global trial_count (2 trials)
input_1_trials = [trial_idx for inp, trial_idx in trial_data if inp == 1]
assert sorted(input_1_trials) == [0, 1]

# Input 2: should use per-input trial_count (4 trials)
input_2_trials = [trial_idx for inp, trial_idx in trial_data if inp == 2]
assert sorted(input_2_trials) == [0, 1, 2, 3]


@pytest.mark.vcr
@pytest.mark.asyncio
async def test_scorer_spans_have_purpose_attribute(with_memory_logger, with_simulate_login):
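End to end, the change lets individual dataset rows request their own number of trials while the evaluator-level `trial_count` stays the default. A hedged usage sketch, assuming the SDK's top-level `Eval` entry point forwards `data`, `task`, `scores`, and `trial_count` the same way the tests pass them to `Evaluator`:

```python
from braintrust import Eval
from braintrust.framework import EvalCase

Eval(
    "my-project",  # hypothetical project name
    data=[
        EvalCase(input=1, expected=2),                 # 2 trials (global default below)
        EvalCase(input=2, expected=4, trial_count=5),  # 5 trials (per-case override)
    ],
    task=lambda x: x * 2,
    scores=[],
    trial_count=2,  # evaluator-wide default
)
```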