Skip to content

Commit 2e4160d

Browse files
committed
Fix evaluation schema strictness
1 parent b18bf91 commit 2e4160d

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

packages/core/src/services/evaluationsV2/llm/shared.test.ts

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { LogSources, Providers } from '@latitude-data/constants'
22
import { ChainError, RunErrorCodes } from '@latitude-data/constants/errors'
3-
import { beforeEach, describe, expect, it, vi } from 'vitest'
3+
import { Chain } from 'promptl-ai'
4+
import { beforeEach, describe, expect, it, MockInstance, vi } from 'vitest'
45
import {
56
EvaluationType,
67
EvaluationV2,
@@ -20,6 +21,10 @@ import { buildPrompt, promptSchema } from './binary'
2021
import { runPrompt } from './shared'
2122

2223
describe('runPrompt', () => {
24+
let mocks: {
25+
runChain: MockInstance
26+
}
27+
2328
let workspace: Workspace
2429
let commit: Commit
2530
let provider: ProviderApiKey
@@ -99,26 +104,28 @@ describe('runPrompt', () => {
99104
workspace: workspace,
100105
})
101106

102-
vi.spyOn(chains, 'runChain').mockImplementation((args) => {
103-
;(args.chain as any)._completed = true
107+
mocks = {
108+
runChain: vi.spyOn(chains, 'runChain').mockImplementation((args) => {
109+
;(args.chain as any)._completed = true
104110

105-
return {
106-
errorableUuid: resultUuid,
107-
error: Promise.resolve(undefined),
108-
response: Promise.resolve({
109-
streamType: 'object',
110-
object: { passed: true, reason: 'reason' },
111-
text: 'text',
112-
usage: {
113-
promptTokens: 10,
114-
completionTokens: 10,
115-
totalTokens: 20,
116-
},
117-
documentLogUuid: providerLog.documentLogUuid,
118-
providerLog: providerLog,
119-
}),
120-
} as any
121-
})
111+
return {
112+
errorableUuid: resultUuid,
113+
error: Promise.resolve(undefined),
114+
response: Promise.resolve({
115+
streamType: 'object',
116+
object: { passed: true, reason: 'reason' },
117+
text: 'text',
118+
usage: {
119+
promptTokens: 10,
120+
completionTokens: 10,
121+
totalTokens: 20,
122+
},
123+
documentLogUuid: providerLog.documentLogUuid,
124+
providerLog: providerLog,
125+
}),
126+
} as any
127+
}),
128+
}
122129
})
123130

124131
it('fails when compiling the prompt', async () => {
@@ -139,10 +146,12 @@ describe('runPrompt', () => {
139146
message: `Unexpected EOF. Expected '}}' but did not find it.`,
140147
}),
141148
)
149+
150+
expect(mocks.runChain).not.toHaveBeenCalled()
142151
})
143152

144153
it('fails when running the prompt', async () => {
145-
vi.spyOn(chains, 'runChain').mockImplementation((args) => {
154+
mocks.runChain = vi.spyOn(chains, 'runChain').mockImplementation((args) => {
146155
;(args.chain as any)._completed = true
147156

148157
return {
@@ -174,10 +183,12 @@ describe('runPrompt', () => {
174183
message: 'Failed!',
175184
}),
176185
)
186+
187+
expect(mocks.runChain).toHaveBeenCalled()
177188
})
178189

179190
it('fails when verdict cannot be parsed from a normal response', async () => {
180-
vi.spyOn(chains, 'runChain').mockImplementation((args) => {
191+
mocks.runChain = vi.spyOn(chains, 'runChain').mockImplementation((args) => {
181192
;(args.chain as any)._completed = true
182193

183194
return {
@@ -215,10 +226,12 @@ describe('runPrompt', () => {
215226
message: 'Evaluation conversation response is not an object',
216227
}),
217228
)
229+
230+
expect(mocks.runChain).toHaveBeenCalled()
218231
})
219232

220233
it('fails when verdict does not match the schema', async () => {
221-
vi.spyOn(chains, 'runChain').mockImplementation((args) => {
234+
mocks.runChain = vi.spyOn(chains, 'runChain').mockImplementation((args) => {
222235
;(args.chain as any)._completed = true
223236

224237
return {
@@ -265,6 +278,8 @@ describe('runPrompt', () => {
265278
]`,
266279
}),
267280
)
281+
282+
expect(mocks.runChain).toHaveBeenCalled()
268283
})
269284

270285
it('succeeds when running a normal prompt', async () => {
@@ -278,6 +293,7 @@ describe('runPrompt', () => {
278293
commit: commit,
279294
workspace: workspace,
280295
})
296+
const chain = mocks.runChain.mock.calls[0][0].chain as Chain
281297

282298
expect(result).toEqual({
283299
response: {
@@ -296,5 +312,8 @@ describe('runPrompt', () => {
296312
},
297313
verdict: { passed: true, reason: 'reason' },
298314
})
315+
expect(mocks.runChain).toHaveBeenCalledOnce()
316+
expect(chain.rawText.includes('structuredOutputs: true')).toBe(true)
317+
expect(chain.rawText.includes('strictJsonSchema: true')).toBe(true)
299318
})
300319
})

packages/core/src/services/evaluationsV2/llm/shared.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ export async function buildLlmEvaluationRunFunction<
7474
if (schema) {
7575
prompt = updatePromptMetadata(prompt, {
7676
schema: z.toJSONSchema(schema, { target: 'openapi-3.0' }),
77+
structuredOutputs: true,
78+
strictJsonSchema: true,
7779
})
7880
}
7981

0 commit comments

Comments
 (0)