diff --git a/js/src/framework.test.ts b/js/src/framework.test.ts index 5cdc2137f..a6650a05e 100644 --- a/js/src/framework.test.ts +++ b/js/src/framework.test.ts @@ -528,6 +528,99 @@ test("trialIndex with multiple inputs", async () => { expect(input2Trials).toEqual([0, 1]); }); +test("per-input trialCount overrides global trialCount", async () => { + const trialData: Array<{ input: number; trialIndex: number }> = []; + + const { results } = await runEvaluator( + null, + { + projectName: "proj", + evalName: "eval", + data: [ + { input: 1, expected: 2 }, + { input: 2, expected: 4, trialCount: 5 }, + { input: 3, expected: 6, trialCount: 1 }, + ], + task: async (input: number, { trialIndex }) => { + trialData.push({ input, trialIndex }); + return input * 2; + }, + scores: [], + trialCount: 2, + }, + new NoopProgressReporter(), + [], + undefined, + undefined, + true, + ); + + expect(results).toHaveLength(8); + expect(trialData).toHaveLength(8); + + // Input 1: has no per-input trialCount, so the global trialCount applies (2 trials) + const input1Trials = trialData + .filter((d) => d.input === 1) + .map((d) => d.trialIndex) + .sort(); + expect(input1Trials).toEqual([0, 1]); + + // Input 2: per-input trialCount of 5 takes precedence over the global value (5 trials) + const input2Trials = trialData + .filter((d) => d.input === 2) + .map((d) => d.trialIndex) + .sort(); + expect(input2Trials).toEqual([0, 1, 2, 3, 4]); + + // Input 3: per-input trialCount of 1 takes precedence even though it is lower than the global value (1 trial) + const input3Trials = trialData + .filter((d) => d.input === 3) + .map((d) => d.trialIndex) + .sort(); + expect(input3Trials).toEqual([0]); +}); + +test("per-input trialCount works without global trialCount", async () => { + const trialData: Array<{ input: number; trialIndex: number }> = []; + + const { results } = await runEvaluator( + null, + { + projectName: "proj", + evalName: "eval", + data: [ + { input: 1, expected: 2 }, + { input: 2, expected: 4, trialCount: 3 }, + ], + task: async (input: number, { trialIndex }) => { + trialData.push({ input, trialIndex }); + 
return input * 2; + }, + scores: [], + }, + new NoopProgressReporter(), + [], + undefined, + undefined, + true, + ); + + expect(results).toHaveLength(4); + expect(trialData).toHaveLength(4); + + const input1Trials = trialData + .filter((d) => d.input === 1) + .map((d) => d.trialIndex) + .sort(); + expect(input1Trials).toEqual([0]); + + const input2Trials = trialData + .filter((d) => d.input === 2) + .map((d) => d.trialIndex) + .sort(); + expect(input2Trials).toEqual([0, 1, 2]); +}); + test("Eval with noSendLogs: true runs locally without creating experiment", async () => { const memoryLogger = _exportsForTestingOnly.useTestBackgroundLogger(); diff --git a/js/src/framework.ts b/js/src/framework.ts index f0a8258c2..f225ac4dc 100644 --- a/js/src/framework.ts +++ b/js/src/framework.ts @@ -1037,7 +1037,7 @@ async function runEvaluatorInternal( objectId: parentComponents?.data.object_id ?? (experimentIdPromise - ? (await experimentIdPromise) ?? "" + ? ((await experimentIdPromise) ?? "") : ""), rootSpanId: rootSpan.rootSpanId, ensureSpansFlushed, @@ -1314,7 +1314,7 @@ async function runEvaluatorInternal( if (!filters.every((f) => evaluateFilter(datum, f))) { continue; } - const trialCount = evaluator.trialCount ?? 1; + const trialCount = datum.trialCount ?? evaluator.trialCount ?? 1; for (let trialIndex = 0; trialIndex < trialCount; trialIndex++) { if (cancelled) { break; diff --git a/js/src/logger.ts b/js/src/logger.ts index fecc7cf7a..cb1518453 100644 --- a/js/src/logger.ts +++ b/js/src/logger.ts @@ -5586,6 +5586,8 @@ export type EvalCase = { created?: string | null; // This field is used to help re-run a particular experiment row. upsert_id?: string; + // The number of times to run the evaluator for this specific input. When set, it takes precedence over the evaluator-level trialCount; when neither is set, the input runs once. + trialCount?: number; } & (Expected extends void ? object : { expected: Expected }) & (Metadata extends void ? object : { metadata: Metadata });