diff --git a/index.ts b/index.ts index e6d808b..c1d865f 100644 --- a/index.ts +++ b/index.ts @@ -3,7 +3,11 @@ import chalk from "chalk"; const isProd = process.env.NODE_ENV === "production"; const services = [ - // { name: "api", cmd: ["bun", "run", ...(isProd ? [] : ["--hot"]), "packages/api_platform/src/index.ts"], color: chalk.blue }, + { + name: "api", + cmd: ["bun", "run", ...(isProd ? [] : ["--hot"]), "packages/api_platform/src/index.ts"], + color: chalk.blue, + }, { name: "backend", cmd: ["bun", "run", ...(isProd ? [] : ["--hot"]), "packages/backend/src/index.ts"], diff --git a/packages/api_platform/src/index.ts b/packages/api_platform/src/index.ts index 8edb866..bfa8cfb 100644 --- a/packages/api_platform/src/index.ts +++ b/packages/api_platform/src/index.ts @@ -1,6 +1,7 @@ import { HttpMiddleware, HttpServer } from "@effect/platform"; import { BunContext, BunHttpServer, BunRuntime } from "@effect/platform-bun"; import { Effect, flow, Layer } from "effect"; +import { fromEnv as LedgerServiceLive } from "ledger"; import { ResolverServiceLive } from "resolver"; import { VaultServiceLive } from "vault"; @@ -29,6 +30,7 @@ const AllServices = Layer.mergeAll( AppConfigLive, VaultServiceLive, ResolverServiceLive, + LedgerServiceLive, BunContext.layer, ); const AllServicesAndHttpServer = Layer.mergeAll(AllServices, HttpServerLayer); diff --git a/packages/api_platform/src/middlewares.ts b/packages/api_platform/src/middlewares.ts index cd21a1c..735643b 100644 --- a/packages/api_platform/src/middlewares.ts +++ b/packages/api_platform/src/middlewares.ts @@ -1,4 +1,4 @@ -import type { ProviderModelPair } from "resolver"; +import type { VerifyApiKeyResult } from "common"; import { HttpMiddleware, HttpServerRequest } from "@effect/platform"; import { BadRequest, Unauthorized } from "@effect/platform/HttpApiError"; @@ -24,15 +24,6 @@ class ApiKeyVerificationError extends Data.TaggedError("ApiKeyVerificationError" message?: string; }> {} -interface VerifyResponse { - valid: boolean; - userId?: string; - providers?: string[]; - fallbackProviderModelPair?: ProviderModelPair; - analysisTarget?: string; - error?: string; -} - const verifyApiKey = (backendUrl: string, apiKey: string) => Effect.tryPromise({ try: async () => { @@ -44,7 +35,7 @@ const verifyApiKey = (backendUrl: string, apiKey: string) => if (!response.ok) { throw new Error(`Backend returned ${response.status}`); } - return (await response.json()) as VerifyResponse; + return (await response.json()) as VerifyApiKeyResult; }, catch: (error) => new ApiKeyVerificationError({ diff --git a/packages/api_platform/src/routes/v1/responses.ts b/packages/api_platform/src/routes/v1/responses.ts index 31733ee..98661b9 100644 --- a/packages/api_platform/src/routes/v1/responses.ts +++ b/packages/api_platform/src/routes/v1/responses.ts @@ -103,18 +103,11 @@ export const responsesRouter = HttpRouter.empty.pipe( Effect.gen(function* () { const createResponseBody = yield* HttpServerRequest.schemaBodyJson(CreateResponseBodySchema); yield* validateCreateResponseBody(createResponseBody); - const { userId, userProviders, fallbackProviderModelPair, analysisTarget } = - yield* RequestContext; + const params = yield* RequestContext; if (createResponseBody.stream === true) { const db = yield* DatabaseService; - const sseStream = yield* ResponsesService.createStream( - createResponseBody, - userId, - userProviders, - fallbackProviderModelPair, - analysisTarget, - ); + const sseStream = yield* ResponsesService.createStream(createResponseBody, params); return HttpServerResponse.stream( sseStream.pipe(Stream.provideService(DatabaseService, db)), { @@ -128,19 +121,13 @@ export const responsesRouter = HttpRouter.empty.pipe( ); } - const responsesObject = yield* ResponsesService.create( - createResponseBody, - userId, - userProviders, - fallbackProviderModelPair, - analysisTarget, - ); + const responsesObject = yield* ResponsesService.create(createResponseBody, params); return yield* HttpServerResponse.json(responsesObject); }).pipe( Effect.catchTags({ - RequestValidationError: (err) => + RequestValidationError: (err: RequestValidationError) => HttpServerResponse.json({ error: { message: err.message } }, { status: 400 }), - ResponseServiceError: (err) => + ResponseServiceError: (err: ResponsesService.ResponseServiceError) => HttpServerResponse.json( { error: { message: err.message ?? "Internal Server Error" } }, { status: 500, headers: { "x-enfinyte-error": `${err.name}: ${err}` } }, diff --git a/packages/api_platform/src/services/ai/__tests__/stream-events.test.ts b/packages/api_platform/src/services/ai/__tests__/stream-events.test.ts new file mode 100644 index 0000000..7bcb3c7 --- /dev/null +++ b/packages/api_platform/src/services/ai/__tests__/stream-events.test.ts @@ -0,0 +1,573 @@ +import { describe, it, expect } from "bun:test"; +import { streamToEvents } from "../stream-events"; +import type { StreamingEvent } from "common"; + +type MockPart = { type: string; [key: string]: unknown }; + +async function* mockTextStreamParts(parts: MockPart[]): AsyncIterable { + for (const part of parts) { + yield part; + } +} + +async function collectEvents( + stream: AsyncIterable, +): Promise { + const collected: StreamingEvent[] = []; + for await (const batch of stream) { + collected.push(...batch); + } + return collected; +} + +describe("streamToEvents", () => { + describe("text message lifecycle", () => { + it("produces correct events for text-start → text-delta → text-end", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "Hello" }, + { type: "text-delta", text: " world" }, + { type: "text-end" }, + { type: "finish", finishReason: "stop" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-test-1", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(7); + expect(collected[0]?.type).toBe("response.output_item.added"); + expect(collected[1]?.type).toBe("response.content_part.added"); + expect(collected[2]?.type).toBe("response.output_text.delta"); + expect(collected[3]?.type).toBe("response.output_text.delta"); + expect(collected[4]?.type).toBe("response.output_text.done"); + expect(collected[5]?.type).toBe("response.content_part.done"); + expect(collected[6]?.type).toBe("response.output_item.done"); + + const addedItem = (collected[0] as StreamingEvent & { item: Record }).item; + expect(addedItem.id).toBeString(); + expect(addedItem.status).toBe("in_progress"); + expect(addedItem.role).toBe("assistant"); + expect(addedItem.content).toEqual([]); + expect(addedItem).not.toHaveProperty("type"); + + const doneItem = (collected[6] as StreamingEvent & { item: Record }).item; + expect(doneItem.id).toBeString(); + expect(doneItem.status).toBe("completed"); + expect(doneItem.role).toBe("assistant"); + expect((doneItem.content as Array<{ type: string; text: string }>)[0]?.type).toBe("output_text"); + expect((doneItem.content as Array<{ type: string; text: string }>)[0]?.text).toBe("Hello world"); + expect(doneItem).not.toHaveProperty("type"); + }); + + it("accumulates text deltas and includes full text in done events", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "foo" }, + { type: "text-delta", text: "bar" }, + { type: "text-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-test-2", + ); + const collected = await collectEvents(events); + + const textDone = collected.find((e) => e.type === "response.output_text.done"); + expect(textDone).toBeDefined(); + expect((textDone as { text: string }).text).toBe("foobar"); + + const contentDone = collected.find((e) => e.type === "response.content_part.done"); + expect(contentDone).toBeDefined(); + expect((contentDone as { part: { text: string } }).part.text).toBe("foobar"); + }); + + it("sets output_index = 0 and content_index = 0 for first text item", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "x" }, + { type: "text-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-idx", + ); + const collected = await collectEvents(events); + + const added = collected[0] as StreamingEvent & { output_index: number }; + expect(added.output_index).toBe(0); + + const delta = collected[2] as StreamingEvent & { + output_index: number; + content_index: number; + }; + expect(delta.output_index).toBe(0); + expect(delta.content_index).toBe(0); + }); + + it("assigns consistent item_id to all events within an item", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "hi" }, + { type: "text-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-id", + ); + const collected = await collectEvents(events); + + const addedEvent = collected[0] as StreamingEvent & { item: { id: string; [key: string]: unknown } }; + const itemId = addedEvent.item.id; + expect(itemId).toBeTruthy(); + expect(addedEvent.item).toHaveProperty("status", "in_progress"); + expect(addedEvent.item).toHaveProperty("role", "assistant"); + + for (const event of collected) { + if ("item_id" in event) { + expect((event as { item_id: string }).item_id).toBe(itemId); + } + if ("item" in event) { + expect((event as { item: { id: string } }).item.id).toBe(itemId); + } + } + }); + }); + + describe("function call lifecycle", () => { + it("produces correct events for tool-input-start → tool-input-delta → tool-input-end", async () => { + const stream = mockTextStreamParts([ + { type: "tool-input-start", toolCallId: "call_abc", toolName: "get_weather" }, + { type: "tool-input-delta", argsTextDelta: '{"loc' }, + { type: "tool-input-delta", argsTextDelta: 'ation":"NYC"}' }, + { type: "tool-input-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-tool-1", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(7); + expect(collected[0]?.type).toBe("response.output_item.added"); + expect(collected[1]?.type).toBe("response.content_part.added"); + expect(collected[2]?.type).toBe("response.function_call_arguments.delta"); + expect(collected[3]?.type).toBe("response.function_call_arguments.delta"); + expect(collected[4]?.type).toBe("response.function_call_arguments.done"); + expect(collected[5]?.type).toBe("response.content_part.done"); + expect(collected[6]?.type).toBe("response.output_item.done"); + + const addedItem = (collected[0] as StreamingEvent & { item: Record }).item; + expect(addedItem.id).toBeString(); + expect(addedItem.status).toBe("in_progress"); + expect(addedItem.call_id).toBe("call_abc"); + expect(addedItem.name).toBe("get_weather"); + expect(addedItem.arguments).toBe(""); + expect(addedItem).not.toHaveProperty("type"); + + const doneItem = (collected[6] as StreamingEvent & { item: Record }).item; + expect(doneItem.id).toBeString(); + expect(doneItem.status).toBe("completed"); + expect(doneItem.call_id).toBe("call_abc"); + expect(doneItem.name).toBe("get_weather"); + expect(doneItem.arguments).toBe('{"location":"NYC"}'); + expect(doneItem).not.toHaveProperty("type"); + }); + + it("accumulates arguments and includes full args in done event", async () => { + const stream = mockTextStreamParts([ + { type: "tool-input-start", toolCallId: "call_1", toolName: "fn" }, + { type: "tool-input-delta", argsTextDelta: '{"a":' }, + { type: "tool-input-delta", argsTextDelta: "1}" }, + { type: "tool-input-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-tool-acc", + ); + const collected = await collectEvents(events); + + const argsDone = collected.find( + (e) => e.type === "response.function_call_arguments.done", + ); + expect(argsDone).toBeDefined(); + expect((argsDone as { arguments: string }).arguments).toBe('{"a":1}'); + }); + + it("stores function call in accumulated output items", async () => { + const stream = mockTextStreamParts([ + { type: "tool-input-start", toolCallId: "call_xyz", toolName: "search" }, + { type: "tool-input-delta", argsTextDelta: "{}" }, + { type: "tool-input-end" }, + ]); + + const { events, getAccumulatedState } = streamToEvents( + stream as never, + "resp-tool-state", + ); + await collectEvents(events); + + const state = getAccumulatedState(); + expect(state.outputItems).toHaveLength(1); + expect(state.outputItems[0]?.type).toBe("function_call"); + + const fc = state.outputItems[0] as { + type: string; + call_id: string; + name: string; + arguments: string; + }; + expect(fc.call_id).toBe("call_xyz"); + expect(fc.name).toBe("search"); + expect(fc.arguments).toBe("{}"); + }); + }); + + describe("reasoning lifecycle", () => { + it("produces correct events for reasoning-start → reasoning-delta → reasoning-end", async () => { + const stream = mockTextStreamParts([ + { type: "reasoning-start" }, + { type: "reasoning-delta", text: "Let me " }, + { type: "reasoning-delta", text: "think..." }, + { type: "reasoning-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-reason-1", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(7); + expect(collected[0]?.type).toBe("response.output_item.added"); + expect(collected[1]?.type).toBe("response.content_part.added"); + expect(collected[2]?.type).toBe("response.reasoning.delta"); + expect(collected[3]?.type).toBe("response.reasoning.delta"); + expect(collected[4]?.type).toBe("response.reasoning.done"); + expect(collected[5]?.type).toBe("response.content_part.done"); + expect(collected[6]?.type).toBe("response.output_item.done"); + + const addedItem = (collected[0] as StreamingEvent & { item: Record }).item; + expect(addedItem.id).toBeString(); + expect(addedItem.summary).toEqual([]); + expect(addedItem).not.toHaveProperty("type"); + expect(addedItem).not.toHaveProperty("status"); + + const doneItem = (collected[6] as StreamingEvent & { item: Record }).item; + expect(doneItem.id).toBeString(); + expect(doneItem.summary).toEqual([]); + expect((doneItem.content as Array<{ type: string; text: string }>)[0]?.type).toBe("reasoning"); + expect((doneItem.content as Array<{ type: string; text: string }>)[0]?.text).toBe("Let me think..."); + expect(doneItem).not.toHaveProperty("type"); + expect(doneItem).not.toHaveProperty("status"); + }); + + it("accumulates reasoning text in done event", async () => { + const stream = mockTextStreamParts([ + { type: "reasoning-start" }, + { type: "reasoning-delta", text: "Step 1. " }, + { type: "reasoning-delta", text: "Step 2." }, + { type: "reasoning-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-reason-acc", + ); + const collected = await collectEvents(events); + + const reasoningDone = collected.find((e) => e.type === "response.reasoning.done"); + expect(reasoningDone).toBeDefined(); + expect((reasoningDone as { text: string }).text).toBe("Step 1. Step 2."); + }); + + it("stores reasoning in accumulated output items", async () => { + const stream = mockTextStreamParts([ + { type: "reasoning-start" }, + { type: "reasoning-delta", text: "hmm" }, + { type: "reasoning-end" }, + ]); + + const { events, getAccumulatedState } = streamToEvents( + stream as never, + "resp-reason-state", + ); + await collectEvents(events); + + const state = getAccumulatedState(); + expect(state.outputItems).toHaveLength(1); + expect(state.outputItems[0]?.type).toBe("reasoning"); + }); + }); + + describe("multiple output items", () => { + it("increments output_index for each new item", async () => { + const stream = mockTextStreamParts([ + { type: "reasoning-start" }, + { type: "reasoning-delta", text: "think" }, + { type: "reasoning-end" }, + { type: "text-start" }, + { type: "text-delta", text: "Answer" }, + { type: "text-end" }, + { type: "tool-input-start", toolCallId: "call_1", toolName: "fn" }, + { type: "tool-input-delta", argsTextDelta: "{}" }, + { type: "tool-input-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-multi", + ); + const collected = await collectEvents(events); + + const addedEvents = collected.filter( + (e) => e.type === "response.output_item.added", + ) as Array; + + expect(addedEvents).toHaveLength(3); + expect(addedEvents[0]?.output_index).toBe(0); + expect(addedEvents[1]?.output_index).toBe(1); + expect(addedEvents[2]?.output_index).toBe(2); + }); + + it("accumulates all output items in state", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "Hello" }, + { type: "text-end" }, + { type: "tool-input-start", toolCallId: "call_2", toolName: "lookup" }, + { type: "tool-input-delta", argsTextDelta: '{"q":"x"}' }, + { type: "tool-input-end" }, + ]); + + const { events, getAccumulatedState } = streamToEvents( + stream as never, + "resp-multi-state", + ); + await collectEvents(events); + + const state = getAccumulatedState(); + expect(state.outputItems).toHaveLength(2); + expect(state.outputItems[0]?.type).toBe("message"); + expect(state.outputItems[1]?.type).toBe("function_call"); + }); + }); + + describe("sequence numbers", () => { + it("assigns monotonically increasing sequence numbers starting at 0", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "A" }, + { type: "text-delta", text: "B" }, + { type: "text-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-seq", + ); + const collected = await collectEvents(events); + + const seqNumbers = collected.map((e) => e.sequence_number); + expect(seqNumbers[0]).toBe(0); + + for (let i = 1; i < seqNumbers.length; i++) { + const prev = seqNumbers[i - 1]; + const curr = seqNumbers[i]; + expect(curr).toBe((prev ?? 0) + 1); + } + }); + + it("continues sequence numbers across multiple items", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "hi" }, + { type: "text-end" }, + { type: "reasoning-start" }, + { type: "reasoning-delta", text: "hmm" }, + { type: "reasoning-end" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-seq-multi", + ); + const collected = await collectEvents(events); + + const seqNumbers = collected.map((e) => e.sequence_number); + for (let i = 1; i < seqNumbers.length; i++) { + const prev = seqNumbers[i - 1]; + const curr = seqNumbers[i]; + expect(curr).toBe((prev ?? 0) + 1); + } + }); + }); + + describe("error handling", () => { + it("produces error event from error part with object error", async () => { + const stream = mockTextStreamParts([ + { + type: "error", + error: { + type: "api_error", + code: "rate_limit_exceeded", + message: "Too many requests", + param: null, + }, + }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-err", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(1); + const errorEvent = collected[0] as StreamingEvent & { + error: { type: string; code: string | null; message: string; param: string | null }; + }; + expect(errorEvent.type).toBe("error"); + expect(errorEvent.error.type).toBe("api_error"); + expect(errorEvent.error.code).toBe("rate_limit_exceeded"); + expect(errorEvent.error.message).toBe("Too many requests"); + expect(errorEvent.error.param).toBeNull(); + }); + + it("handles string error value", async () => { + const stream = mockTextStreamParts([ + { type: "error", error: "Something went wrong" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-err-str", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(1); + const errorEvent = collected[0] as StreamingEvent & { + error: { type: string; message: string }; + }; + expect(errorEvent.type).toBe("error"); + expect(errorEvent.error.type).toBe("error"); + expect(errorEvent.error.message).toBe("Something went wrong"); + }); + + it("handles error with name field as type fallback", async () => { + const stream = mockTextStreamParts([ + { + type: "error", + error: { name: "TimeoutError", message: "Request timed out" }, + }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-err-name", + ); + const collected = await collectEvents(events); + + const errorEvent = collected[0] as StreamingEvent & { + error: { type: string; message: string }; + }; + expect(errorEvent.error.type).toBe("TimeoutError"); + expect(errorEvent.error.message).toBe("Request timed out"); + }); + + it("handles null error value with defaults", async () => { + const stream = mockTextStreamParts([ + { type: "error", error: null }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-err-null", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(1); + const errorEvent = collected[0] as StreamingEvent & { + error: { type: string; message: string }; + }; + expect(errorEvent.error.type).toBe("error"); + expect(errorEvent.error.message).toBe("Unknown error"); + }); + }); + + describe("getAccumulatedState", () => { + it("returns correct state after full text lifecycle", async () => { + const stream = mockTextStreamParts([ + { type: "text-start" }, + { type: "text-delta", text: "hello" }, + { type: "text-end" }, + ]); + + const { events, getAccumulatedState } = streamToEvents( + stream as never, + "resp-state", + ); + await collectEvents(events); + + const state = getAccumulatedState(); + expect(state.outputItems).toHaveLength(1); + expect(state.currentItemId).toBeNull(); + expect(state.currentItemType).toBeNull(); + expect(state.outputIndex).toBe(0); + expect(state.sequenceNumber).toBe(6); + }); + + it("returns empty state for empty stream", async () => { + const stream = mockTextStreamParts([]); + + const { events, getAccumulatedState } = streamToEvents( + stream as never, + "resp-empty", + ); + await collectEvents(events); + + const state = getAccumulatedState(); + expect(state.outputItems).toHaveLength(0); + expect(state.sequenceNumber).toBe(0); + expect(state.outputIndex).toBe(-1); + expect(state.currentItemId).toBeNull(); + }); + }); + + describe("edge cases", () => { + it("ignores finish part (no events emitted)", async () => { + const stream = mockTextStreamParts([ + { type: "finish", finishReason: "stop" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-finish", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(0); + }); + + it("ignores unknown part types", async () => { + const stream = mockTextStreamParts([ + { type: "some-unknown-type" }, + ]); + + const { events } = streamToEvents( + stream as never, + "resp-unknown", + ); + const collected = await collectEvents(events); + + expect(collected).toHaveLength(0); + }); + }); +}); diff --git a/packages/api_platform/src/services/ai/adapters.ts b/packages/api_platform/src/services/ai/adapters.ts new file mode 100644 index 0000000..79dda49 --- /dev/null +++ b/packages/api_platform/src/services/ai/adapters.ts @@ -0,0 +1,482 @@ +import { Effect } from "effect"; +import type { CreateResponseBody, CreateResponseBodyInputItem } from "common"; +import type { + FilePart, + ImagePart, + ModelMessage, + SystemModelMessage, + TextPart, + ToolResultPart, +} from "ai"; +import { jsonSchema, Output } from "ai"; +import type { ToolSet } from "ai"; +import { AIServiceError } from "."; +import { isNotNullable } from "effect/Predicate"; +import { detectMimeTypeFromBase64EncodedString, detectMimeTypeFromURL } from "../../utils"; + +export const inputToMessages = ( + createResponseBody: CreateResponseBody, +) => + Effect.gen(function* () { + const { input, instructions } = createResponseBody; + + if (!input) { + return yield* new AIServiceError({ + message: "Input field is required for message-based models.", + }); + } + + const instructionsAsSystemMessage: SystemModelMessage[] = instructions + ? [ + { + role: "system", + content: instructions, + }, + ] + : []; + + if (typeof input === "string") { + return yield* Effect.succeed([ + ...instructionsAsSystemMessage, + { + role: "user", + content: input, + }, + ] satisfies ModelMessage[] as ModelMessage[]); + } + + const inputItemsAsModelMessage = yield* Effect.all(input.map(convertInputItemToModelMessage)); + + return yield* Effect.succeed( + [...instructionsAsSystemMessage, ...inputItemsAsModelMessage.flat()].filter( + isNotNullable, + ) satisfies ModelMessage[] as ModelMessage[], + ); + }); + +const convertInputItemToModelMessage = ( + createResponseBodyInputItem: CreateResponseBodyInputItem, +): Effect.Effect => + Effect.gen(function* () { + switch (createResponseBodyInputItem.type) { + case "message": { + const role = createResponseBodyInputItem.role; + const content = createResponseBodyInputItem.content; + + switch (role) { + case "system": + case "developer": { + const providerOptions = + role === "developer" + ? { + providerOptions: { + openai: { systemMessageMode: "developer" }, + }, + } + : {}; + + if (typeof content === "string") + return [ + { + role: "system", + content, + ...providerOptions, + }, + ] satisfies ModelMessage[]; + else + return content + .filter((contentItem) => contentItem.type === "input_text") + .map((contentItem) => ({ + role: "system", + content: contentItem.text, + ...providerOptions, + })) satisfies ModelMessage[]; + } + case "user": { + if (typeof content === "string") { + return [ + { + role: "user", + content, + }, + ] satisfies ModelMessage[]; + } else { + const parts = yield* Effect.all( + content.map((contentItem) => + Effect.gen(function* () { + switch (contentItem.type) { + case "input_text": + return { + type: "text", + text: contentItem.text, + } satisfies TextPart; + case "input_image": + if (!contentItem.image_url) return; + return { + type: "image", + image: new URL(contentItem.image_url), + providerOptions: { + openai: { + imageDetail: contentItem.detail, + }, + }, + } satisfies ImagePart; + case "input_file": { + if (!contentItem.file_data && !contentItem.file_url) return; + + const mediaType = contentItem.file_url + ? yield* detectMimeTypeFromURL(contentItem.file_url) + : contentItem.file_data + ? yield* detectMimeTypeFromBase64EncodedString(contentItem.file_data) + : "application/octet-stream"; + + return { + type: "file", + ...(contentItem.filename ? { filename: contentItem.filename } : {}), + data: contentItem.file_url + ? new URL(contentItem.file_url) + : contentItem.file_data + ? contentItem.file_data + : "<<<<<>>>>>", + mediaType, + } satisfies FilePart; + } + } + }), + ), + ); + return [ + { + role: "user", + content: parts.filter(isNotNullable), + }, + ] satisfies ModelMessage[]; + } + } + case "assistant": { + if (typeof content === "string") { + return [ + { + role: "assistant", + content, + }, + ] satisfies ModelMessage[]; + } else { + const parts = yield* Effect.all( + content.map((contentItem) => + Effect.gen(function* () { + switch (contentItem.type) { + case "input_text": + case "output_text": + return { + type: "text", + text: contentItem.text, + } satisfies TextPart; + case "input_image": + if (!contentItem.image_url) return; + return { + type: "file", + data: new URL(contentItem.image_url), + mediaType: "image/png", + providerOptions: { + openai: { + imageDetail: contentItem.detail, + }, + }, + } satisfies FilePart; + case "input_file": { + if (!contentItem.file_data && !contentItem.file_url) return; + + const mediaType = contentItem.file_url + ? yield* detectMimeTypeFromURL(contentItem.file_url) + : contentItem.file_data + ? yield* detectMimeTypeFromBase64EncodedString(contentItem.file_data) + : "application/octet-stream"; + + return { + type: "file", + ...(contentItem.filename ? { filename: contentItem.filename } : {}), + data: contentItem.file_url + ? new URL(contentItem.file_url) + : contentItem.file_data + ? contentItem.file_data + : "<<<<<>>>>>", + mediaType, + } satisfies FilePart; + } + } + }), + ), + ); + + return [ + { + role: "assistant", + content: parts.filter(isNotNullable), + }, + ] satisfies ModelMessage[]; + } + } + } + } + case "reasoning": { + return [ + { + role: "assistant", + content: createResponseBodyInputItem.summary.map((summaryItem) => ({ + type: "reasoning", + text: summaryItem.text, + })), + providerOptions: { + openai: { + itemId: createResponseBodyInputItem.id, + reasoningEncryptedContent: createResponseBodyInputItem.encrypted_content, + }, + }, + }, + ] satisfies ModelMessage[]; + } + + case "function_call": { + const parsedArguments = yield* Effect.try({ + try: () => JSON.parse(createResponseBodyInputItem.arguments), + catch: (error) => + new AIServiceError({ + message: `Failed to parse function_call arguments: ${error}`, + cause: error, + }), + }); + + return [ + { + role: "assistant", + content: [ + { + type: "tool-call", + toolCallId: createResponseBodyInputItem.call_id, + toolName: createResponseBodyInputItem.name, + input: parsedArguments, + providerExecuted: createResponseBodyInputItem.status === "completed", + }, + ], + providerOptions: { + openai: { + itemId: createResponseBodyInputItem.id, + }, + }, + }, + ] satisfies ModelMessage[]; + } + + case "function_call_output": { + return [ + { + role: "tool", + content: [ + { + type: "tool-result", + toolCallId: createResponseBodyInputItem.call_id, + toolName: crypto.randomUUID(), + output: yield* (() => + Effect.gen(function* () { + const output = createResponseBodyInputItem.output; + + if (typeof output === "string") { + return { + type: "text", + value: output, + } satisfies ToolResultPart["output"]; + } + + return { + type: "content", + value: (yield* Effect.all( + output.map((outputItem) => + Effect.gen(function* () { + switch (outputItem.type) { + case "input_text": + return { + type: "text" as const, + text: outputItem.text, + }; + case "input_image": { + if (!outputItem.image_url) return; + return { + type: "image-url" as const, + url: outputItem.image_url, + }; + } + case "input_file": { + if (outputItem.file_data) + return { + type: "file-data" as const, + data: outputItem.file_data, + mediaType: yield* detectMimeTypeFromBase64EncodedString( + outputItem.file_data, + ), + filename: outputItem.filename ?? crypto.randomUUID(), + }; + else if (outputItem.file_url) + return { + type: "file-url" as const, + url: outputItem.file_url, + }; + else { + return; + } + } + case "input_video": { + return { + type: "file-url" as const, + url: outputItem.video_url, + }; + } + } + }), + ), + )).filter(isNotNullable), + } satisfies ToolResultPart["output"]; + }))(), + }, + ], + }, + ] satisfies ModelMessage[]; + } + default: { + return yield* Effect.fail( + new AIServiceError({ + message: `Unsupported input item type: ${(createResponseBodyInputItem as { type: string }).type}`, + }), + ); + } + } + }); + +const EFFORT_TO_BUDGET_TOKENS: Record = { + low: 1024, + medium: 4096, + high: 10000, + xhigh: 32000, +}; + +const EFFORT_TO_NOVA_REASONING_EFFORT: Record = { + low: "low", + medium: "medium", + high: "high", + xhigh: "max", +}; + +export const reasoningToProviderOptions = ( + reasoning: CreateResponseBody["reasoning"], + bedrockModelId?: string, + hasStructuredOutput?: boolean, +) => { + if (!reasoning) return undefined; + + const effort = reasoning.effort; + const summary = reasoning.summary; + + if (!effort && !summary) return undefined; + + const openaiOptions = { + ...(effort && effort !== "none" + ? { reasoningEffort: effort === "xhigh" ? "high" : effort } + : {}), + ...(summary ? { reasoningSummary: summary } : {}), + }; + + const anthropicThinkingConfig = + !effort || effort === "none" + ? ({ type: "disabled" } as const) + : ({ type: "enabled", budgetTokens: EFFORT_TO_BUDGET_TOKENS[effort] ?? 4096 } as const); + + const bedrockReasoningConfig = (() => { + if (!effort || effort === "none") return undefined; + if (hasStructuredOutput) return undefined; + + const isAnthropicModel = bedrockModelId?.includes("anthropic") ?? false; + const isAmazonModel = + bedrockModelId?.includes("amazon") ?? false; + + if (isAnthropicModel) { + return { + type: "enabled" as const, + budgetTokens: EFFORT_TO_BUDGET_TOKENS[effort] ?? 4096, + }; + } + if (isAmazonModel) { + return { + type: "enabled" as const, + maxReasoningEffort: EFFORT_TO_NOVA_REASONING_EFFORT[effort] ?? "medium", + }; + } + return undefined; + })(); + + return { + openai: openaiOptions, + anthropic: { thinking: anthropicThinkingConfig }, + ...(bedrockReasoningConfig + ? { bedrock: { reasoningConfig: bedrockReasoningConfig } } + : {}), + }; +}; + +export const toolsToCallSettings = ( + tools: CreateResponseBody["tools"], + toolChoice: CreateResponseBody["tool_choice"], +): ToolSet | undefined => { + if (!tools?.length) return undefined; + + const filteredTools = + toolChoice && + typeof toolChoice !== "string" && + toolChoice.type === "allowed_tools" + ? tools.filter((t) => + toolChoice.tools.some((allowed) => allowed.name === t.name), + ) + : tools; + + if (!filteredTools.length) return undefined; + + return Object.fromEntries( + filteredTools.map((t) => [ + t.name, + { + ...(t.description != null ? { description: t.description } : {}), + inputSchema: jsonSchema( + (t.parameters as Parameters[0]) ?? { + type: "object" as const, + }, + ), + ...(t.strict != null ? { strict: t.strict } : {}), + }, + ]), + ) as ToolSet; +}; + +export const toolChoiceToCallSettings = ( + toolChoice: CreateResponseBody["tool_choice"], +) => { + if (!toolChoice) return undefined; + if (typeof toolChoice === "string") return toolChoice; + if (toolChoice.type === "function") + return { type: "tool" as const, toolName: toolChoice.name }; + return toolChoice.mode ?? ("auto" as const); +}; + +export const textFormatToOutput = ( + text: CreateResponseBody["text"], +) => { + const format = text?.format; + if (!format || format.type === "text") return undefined; + return Output.object({ + schema: jsonSchema( + (format.schema as Parameters[0]) ?? { + type: "object" as const, + }, + ), + ...(format.name != null ? { name: format.name } : {}), + ...(format.description != null ? { description: format.description } : {}), + }); +}; diff --git a/packages/api_platform/src/services/ai/consts.ts b/packages/api_platform/src/services/ai/consts.ts index 1938ce0..4bc01db 100644 --- a/packages/api_platform/src/services/ai/consts.ts +++ b/packages/api_platform/src/services/ai/consts.ts @@ -11,7 +11,3 @@ export const DEFAULT_STORE: ResponseResource["store"] = true; export const DEFAULT_BACKGROUND: ResponseResource["background"] = false; export const DEFAULT_SERVICE_TIER: ResponseResource["service_tier"] = "default"; -export const MOCK_INPUT_TOKENS = 1000; -export const MOCK_OUTPUT_TOKENS = 100; -export const MOCK_REASONING_TOKENS = 0; -export const MOCK_CACHED_TOKENS = 0; diff --git a/packages/api_platform/src/services/ai/error-to-resource.ts b/packages/api_platform/src/services/ai/error-to-resource.ts new file mode 100644 index 0000000..7f10f02 --- /dev/null +++ b/packages/api_platform/src/services/ai/error-to-resource.ts @@ -0,0 +1,34 @@ +import type { APICallError } from "ai"; +import { Effect } from "effect"; +import type { CreateResponseBody, ResponseResource } from "common"; +import type { ResolvedResponse } from "common"; + +import { buildBaseResponse } from "./response-defaults"; + +export const errorToResponseResource = ({ + result, + createResponseBody, + createdAt, + resolvedModelAndProvider, +}: { + result: APICallError; + createResponseBody: CreateResponseBody; + createdAt: number; + resolvedModelAndProvider: ResolvedResponse; +}): Effect.Effect => + Effect.succeed({ + object: "response", + id: crypto.randomUUID(), + created_at: createdAt, + completed_at: Date.now(), + status: "error", + incomplete_details: null, + ...buildBaseResponse(createResponseBody, resolvedModelAndProvider), + output: [], + reasoning: null, + error: { + code: String(result.statusCode ?? 500), + message: result.message, + }, + usage: null, + } satisfies ResponseResource); diff --git a/packages/api_platform/src/services/ai/field-resolvers.ts b/packages/api_platform/src/services/ai/field-resolvers.ts new file mode 100644 index 0000000..e794b23 --- /dev/null +++ b/packages/api_platform/src/services/ai/field-resolvers.ts @@ -0,0 +1,40 @@ +import type { CreateResponseBody, ResponseResource } from "common"; + +export const resolveToolChoice = ( + tc: CreateResponseBody["tool_choice"], +): ResponseResource["tool_choice"] => { + if (!tc) return "none"; + if (typeof tc === "string") return tc; + if (tc.type === "function") return tc; + return { ...tc, mode: tc.mode ?? "auto" }; +}; + +export const resolveTools = (tools: CreateResponseBody["tools"]): ResponseResource["tools"] => + tools?.map((t) => ({ + type: "function" as const, + name: t.name, + description: t.description ?? null, + parameters: t.parameters ?? null, + strict: t.strict ?? null, + })) ?? []; + +export const resolveReasoning = ( + reasoning: CreateResponseBody["reasoning"], +): ResponseResource["reasoning"] => + reasoning ? { effort: reasoning.effort ?? null, summary: reasoning.summary ?? null } : null; + +export const resolveTextFormat = ( + text: CreateResponseBody["text"], +): ResponseResource["text"] => { + const format = text?.format; + if (!format || format.type === "text") return { format: { type: "text" as const } }; + return { + format: { + type: "json_schema" as const, + name: format.name ?? "json_schema", + description: format.description ?? null, + schema: format.schema ?? null, + strict: format.strict ?? false, + }, + }; +}; diff --git a/packages/api_platform/src/services/ai/index.ts b/packages/api_platform/src/services/ai/index.ts index 22bd96d..4d6550c 100644 --- a/packages/api_platform/src/services/ai/index.ts +++ b/packages/api_platform/src/services/ai/index.ts @@ -1,22 +1,26 @@ -import type { CreateResponseBody, ResolvedResponse, Providers } from "common"; -import type { ProviderModelPair } from "resolver/src/types"; +import type { CreateResponseBody, ResolvedResponse, ResponseResource, Providers } from "common"; +import type { Transaction } from "ledger"; import { AISDKError, type TextStreamPart, type ToolSet } from "ai"; import { APICallError, generateText, streamText } from "ai"; import { Effect, Data, Either } from "effect"; +import { LedgerService } from "ledger"; +import { ResolverService } from "resolver"; + +import type { RequestParams } from "../request-context"; import * as CredentialsService from "../credentials"; import * as pmrService from "../pmr"; -import { buildLanguageModelFromResolvedModelAndProvider } from "./buildLanguageModelFromResolvedModelAndProvider"; -import { convertAISdkGenerateTextResultToResponseResource } from "./convertAISdkGenerateTextResultToResponseResource"; -import { convertAPICallErrorToResponseResource } from "./convertAPICallErrorToResponseResource"; +import { buildLanguageModel } from "./model-factory"; +import { resultToResponseResource } from "./result-to-resource"; +import { errorToResponseResource } from "./error-to-resource"; import { - convertCreateResponseBodyInputFieldToCallSettingsMessages, - convertCreateResponseBodyToolsToCallSettingsTools, - convertCreateResponseBodyToolChoiceToCallSettingsToolChoice, - convertCreateResponseBodyTextFormatToCallSettingsOutput, - convertCreateResponseBodyReasoningToProviderOptions, -} from "./responseFieldsToAISDKGenerateTextCallSettingsAdapters"; + inputToMessages, + toolsToCallSettings, + toolChoiceToCallSettings, + textFormatToOutput, + reasoningToProviderOptions, +} from "./adapters"; export class AIServiceError extends Data.TaggedError("AIServiceError")<{ cause?: unknown; @@ -25,131 +29,212 @@ export class AIServiceError extends Data.TaggedError("AIServiceError")<{ type StreamResult = ReturnType; -export const execute = ( - createResponseBody: CreateResponseBody, +const prepareCallOptions = ( userId: string, - userProviders: readonly string[], - fallbackProviderModelPair: ProviderModelPair, - analysisTarget: string, + createResponseBody: CreateResponseBody, + resolvedModelAndProvider: ResolvedResponse, ) => Effect.gen(function* () { - const requestedModel = createResponseBody.model; - if (!requestedModel) { - // XXX: THIS SHOULD BE HANDLED BY ROUTE VALIDATION, BUT JUST IN CASE TO SATISFY TYPESCRIPT - return yield* new AIServiceError({ - message: "`model` field is required or should not be empty", - }); - } - const resolvedModelAndProvidersResult = yield* Effect.either( - pmrService.resolve(createResponseBody, userId, [...userProviders], analysisTarget), + const credentials = yield* CredentialsService.getCredentials( + userId, + resolvedModelAndProvider.provider as Providers, + ).pipe( + Effect.catchTag("CredentialsError", (err) => + Effect.fail(new AIServiceError({ cause: err, message: err.message })), + ), + ); + + const languageModel = yield* buildLanguageModel(resolvedModelAndProvider, credentials); + + const messages = yield* inputToMessages(createResponseBody); + + const tools = toolsToCallSettings( + createResponseBody.tools, + createResponseBody.tool_choice, + ); + + const toolChoice = toolChoiceToCallSettings(createResponseBody.tool_choice); + + const outputFormat = textFormatToOutput(createResponseBody.text); + + const hasStructuredOutput = createResponseBody.text?.format?.type === "json_schema"; + + const providerOptions = reasoningToProviderOptions( + createResponseBody.reasoning, + resolvedModelAndProvider.model, + hasStructuredOutput, ); - if (Either.isLeft(resolvedModelAndProvidersResult)) { - if (fallbackProviderModelPair) { - return yield* callLanguageModel(userId, createResponseBody, fallbackProviderModelPair); + const baseOptions = { + model: languageModel, + messages, + ...(createResponseBody.max_output_tokens && { + maxOutputTokens: createResponseBody.max_output_tokens, + }), + ...(createResponseBody.top_p && { topP: createResponseBody.top_p }), + ...(createResponseBody.temperature && { temperature: createResponseBody.temperature }), + ...(createResponseBody.presence_penalty && { + presencePenalty: createResponseBody.presence_penalty, + }), + ...(createResponseBody.frequency_penalty && { + frequencyPenalty: createResponseBody.frequency_penalty, + }), + }; + + return { baseOptions, tools, toolChoice, outputFormat, providerOptions }; + }); + +const resolveProviders = ( + createResponseBody: CreateResponseBody, + params: RequestParams, +) => + Effect.gen(function* () { + const resolvedResult = yield* Effect.either( + pmrService.resolve( + createResponseBody, + params.userId, + [...params.userProviders], + params.analysisTarget, + ), + ); + + if (Either.isLeft(resolvedResult)) { + if (params.fallbackProviderModelPair) { + return { + pairs: [ + { + ...params.fallbackProviderModelPair, + category: null as string | null, + }, + ] as ResolvedResponse[], + resolutionLatencyMs: 0, + }; } return yield* new AIServiceError({ message: "Model resolution failed and no fallback provider is configured", }); } - const resolvedModelAndProviders = resolvedModelAndProvidersResult.right; - - for (const resolvedModelAndProvider of resolvedModelAndProviders) { - const result = yield* Effect.either( - callLanguageModel(userId, createResponseBody, resolvedModelAndProvider), - ); + return resolvedResult.right; + }); - if (Either.isRight(result)) { - const response = result.right; - if (response.error === null) { - return response; - } - } +const tryProviders = ( + providers: readonly ResolvedResponse[], + fn: (provider: ResolvedResponse) => Effect.Effect, + fallback: ResolvedResponse | null, +): Effect.Effect => + Effect.gen(function* () { + for (const provider of providers) { + const result = yield* Effect.either(fn(provider)); + if (Either.isRight(result)) return result.right; } - if (fallbackProviderModelPair) { - return yield* callLanguageModel(userId, createResponseBody, fallbackProviderModelPair); + if (fallback) { + return yield* fn(fallback); } return yield* new AIServiceError({ - message: "`model` field is required or should not be empty", + message: "All providers failed and no fallback is available", }); }); -export const executeStream = ( - createResponseBody: CreateResponseBody, - userId: string, - userProviders: readonly string[], - fallbackProviderModelPair: ProviderModelPair, - analysisTarget: string, -) => +export const execute = (body: CreateResponseBody, params: RequestParams) => Effect.gen(function* () { - const requestedModel = createResponseBody.model; - if (!requestedModel) { - // XXX: THIS SHOULD BE HANDLED BY ROUTE VALIDATION, BUT JUST IN CASE TO SATISFY TYPESCRIPT + const ledgerService = yield* LedgerService; + const resolverService = yield* ResolverService; + + if (!body.model) { return yield* new AIServiceError({ message: "`model` field is required or should not be empty", }); } - const resolvedModelAndProvidersResult = yield* Effect.either( - pmrService.resolve(createResponseBody, userId, [...userProviders], analysisTarget), + const { pairs, resolutionLatencyMs } = yield* resolveProviders(body, params); + + const fallback = params.fallbackProviderModelPair + ? { ...params.fallbackProviderModelPair, category: null as string | null } + : null; + + return yield* tryProviders( + pairs, + (resolvedModelAndProvider) => + Effect.gen(function* () { + const llmStartedAt = Date.now(); + const response = yield* callLanguageModel(params.userId, body, resolvedModelAndProvider); + const totalLatencyMs = Date.now() - llmStartedAt; + + const cost = yield* resolverService + .getCostForModel( + `${resolvedModelAndProvider.provider}/${resolvedModelAndProvider.model}`, + ) + .pipe(Effect.catchAll(() => Effect.succeed(null))); + + yield* ledgerService + .insertTransaction( + buildTransaction({ + resolvedModelAndProvider, + resolutionLatencyMs, + userId: params.userId, + isStreaming: false, + response, + cost, + totalLatencyMs, + ttftMs: null, + }), + ) + .pipe(Effect.ignore); + + if (response.error !== null) { + return yield* new AIServiceError({ + message: response.error.message, + }); + } + + return response; + }), + fallback, ); + }); - if (Either.isLeft(resolvedModelAndProvidersResult)) { - if (fallbackProviderModelPair) { - return yield* callLanguageModelStreaming( - userId, - createResponseBody, - fallbackProviderModelPair, - ); - } +export const executeStream = (body: CreateResponseBody, params: RequestParams) => + Effect.gen(function* () { + if (!body.model) { return yield* new AIServiceError({ - message: "Model resolution failed and no fallback provider is configured", + message: "`model` field is required or should not be empty", }); } - const resolvedModelAndProviders = resolvedModelAndProvidersResult.right; + const { pairs, resolutionLatencyMs } = yield* resolveProviders(body, params); - for (const resolvedModelAndProvider of resolvedModelAndProviders) { - const result = yield* Effect.either( + const fallback = params.fallbackProviderModelPair + ? { ...params.fallbackProviderModelPair, category: null as string | null } + : null; + + return yield* tryProviders( + pairs, + (resolvedModelAndProvider) => Effect.gen(function* () { + const llmStartedAt = Date.now(); const callResult = yield* callLanguageModelStreaming( - userId, - createResponseBody, + params.userId, + body, resolvedModelAndProvider, ); const probedFullStream = yield* probeStream(callResult.result); - // Assign fullStream directly instead of spreading, because StreamTextResult - // exposes properties like totalUsage as prototype getters which are lost by spread. Object.defineProperty(callResult.result, "fullStream", { value: probedFullStream }); - return callResult; + const ttftMs = Date.now() - llmStartedAt; + return { ...callResult, resolutionLatencyMs, llmStartedAt, ttftMs }; }), - ); - - if (Either.isRight(result)) { - // const response = result.right; - // if (response.result === null) { - return result.right; - // } - } - } - - if (fallbackProviderModelPair) { - return yield* callLanguageModelStreaming( - userId, - createResponseBody, - fallbackProviderModelPair, - ); - } - - return yield* new AIServiceError({ - message: "`model` field is required or should not be empty", - }); + fallback + ? { + ...fallback, + category: null as string | null, + } + : null, + ); }); const callLanguageModel = ( @@ -158,66 +243,14 @@ const callLanguageModel = ( resolvedModelAndProvider: ResolvedResponse, ) => Effect.gen(function* () { - const credentials = yield* CredentialsService.getCredentials( - userId, - resolvedModelAndProvider.provider as Providers, - ).pipe( - Effect.catchTag("CredentialsError", (err) => - Effect.fail(new AIServiceError({ cause: err, message: err.message })), - ), - ); - - const languageModel = yield* buildLanguageModelFromResolvedModelAndProvider( - resolvedModelAndProvider, - credentials, - ); - - const messages = - yield* convertCreateResponseBodyInputFieldToCallSettingsMessages(createResponseBody); - - const tools = convertCreateResponseBodyToolsToCallSettingsTools( - createResponseBody.tools, - createResponseBody.tool_choice, - ); - - const toolChoice = convertCreateResponseBodyToolChoiceToCallSettingsToolChoice( - createResponseBody.tool_choice, - ); - - const outputFormat = convertCreateResponseBodyTextFormatToCallSettingsOutput( - createResponseBody.text, - ); - - const hasStructuredOutput = createResponseBody.text?.format?.type === "json_schema"; - - const providerOptions = convertCreateResponseBodyReasoningToProviderOptions( - createResponseBody.reasoning, - resolvedModelAndProvider.model, - hasStructuredOutput, - ); - - // NOTE: parallel_tool_calls, max_tool_calls, prompt_cache_key, truncation, top_logProbs - const generateTextOptions = { - model: languageModel, - messages, - ...(createResponseBody.max_output_tokens && { - maxOutputTokens: createResponseBody.max_output_tokens, - }), - ...(createResponseBody.top_p && { topP: createResponseBody.top_p }), - ...(createResponseBody.temperature && { temperature: createResponseBody.temperature }), - ...(createResponseBody.presence_penalty && { - presencePenalty: createResponseBody.presence_penalty, - }), - ...(createResponseBody.frequency_penalty && { - frequencyPenalty: createResponseBody.frequency_penalty, - }), - }; + const { baseOptions, tools, toolChoice, outputFormat, providerOptions } = + yield* prepareCallOptions(userId, createResponseBody, resolvedModelAndProvider); const result = yield* Effect.either( Effect.tryPromise({ try: (abortSignal) => generateText({ - ...generateTextOptions, + ...baseOptions, abortSignal, ...(tools ? { tools } : {}), ...(toolChoice ? { toolChoice } : {}), @@ -241,7 +274,7 @@ const callLanguageModel = ( return yield* errorValue; } - return yield* convertAPICallErrorToResponseResource({ + return yield* errorToResponseResource({ result: errorValue, createResponseBody, createdAt: Date.now(), @@ -249,7 +282,7 @@ const callLanguageModel = ( }); } - return yield* convertAISdkGenerateTextResultToResponseResource({ + return yield* resultToResponseResource({ result: result.right, createResponseBody, createdAt: Date.now(), @@ -263,67 +296,13 @@ const callLanguageModelStreaming = ( resolvedModelAndProvider: ResolvedResponse, ) => Effect.gen(function* () { - const credentials = yield* CredentialsService.getCredentials( - userId, - resolvedModelAndProvider.provider as Providers, - ).pipe( - Effect.catchTag("CredentialsError", (err) => - Effect.fail(new AIServiceError({ cause: err, message: err.message })), - ), - ); - - const languageModel = yield* buildLanguageModelFromResolvedModelAndProvider( - resolvedModelAndProvider, - credentials, - ); - - const messages = - yield* convertCreateResponseBodyInputFieldToCallSettingsMessages(createResponseBody); - - const tools = convertCreateResponseBodyToolsToCallSettingsTools( - createResponseBody.tools, - createResponseBody.tool_choice, - ); - - const toolChoice = convertCreateResponseBodyToolChoiceToCallSettingsToolChoice( - createResponseBody.tool_choice, - ); - - const outputFormat = convertCreateResponseBodyTextFormatToCallSettingsOutput( - createResponseBody.text, - ); - - const hasStructuredOutput = createResponseBody.text?.format?.type === "json_schema"; - - const providerOptions = convertCreateResponseBodyReasoningToProviderOptions( - createResponseBody.reasoning, - resolvedModelAndProvider.model, - hasStructuredOutput, - ); - - void outputFormat; - - // NOTE: parallel_tool_calls, max_tool_calls, prompt_cache_key, truncation, top_logProbs - const generateTextOptions = { - model: languageModel, - messages, - ...(createResponseBody.max_output_tokens && { - maxOutputTokens: createResponseBody.max_output_tokens, - }), - ...(createResponseBody.top_p && { topP: createResponseBody.top_p }), - ...(createResponseBody.temperature && { temperature: createResponseBody.temperature }), - ...(createResponseBody.presence_penalty && { - presencePenalty: createResponseBody.presence_penalty, - }), - ...(createResponseBody.frequency_penalty && { - frequencyPenalty: createResponseBody.frequency_penalty, - }), - }; + const { baseOptions, tools, toolChoice, providerOptions } = + yield* prepareCallOptions(userId, createResponseBody, resolvedModelAndProvider); const stream = yield* Effect.try({ try: () => streamText({ - ...generateTextOptions, + ...baseOptions, ...(tools ? { tools } : {}), ...(toolChoice ? { toolChoice } : {}), ...(providerOptions ? { providerOptions } : {}), @@ -390,3 +369,56 @@ const probeStream = (streamResult: StreamResult) => }, catch: (error) => new AIServiceError({ cause: error, message: "Stream failed on first chunk" }), }); + +export const buildTransaction = (opts: { + resolvedModelAndProvider: ResolvedResponse; + resolutionLatencyMs: number; + userId: string; + isStreaming: boolean; + response: ResponseResource | null; + cost: { input: number; output: number } | null; + totalLatencyMs: number | null; + ttftMs: number | null; +}): Transaction => { + const { + resolvedModelAndProvider, + resolutionLatencyMs, + userId, + isStreaming, + response, + cost, + totalLatencyMs, + ttftMs, + } = opts; + + const usage = response?.usage ?? null; + const inputTokens = usage?.input_tokens ?? null; + const outputTokens = usage?.output_tokens ?? null; + const reasoningTokens = usage?.output_tokens_details?.reasoning_tokens ?? null; + + const httpStatusCode = response?.error + ? parseInt(response.error.code, 10) || null + : response ? 200 : null; + const errorType = response?.error?.message ?? null; + + return { + timestamp: new Date(), + request_id: crypto.randomUUID(), + provider: resolvedModelAndProvider.provider, + model: resolvedModelAndProvider.model, + category: resolvedModelAndProvider.category, + resolution_latency_ms: resolutionLatencyMs, + ttft_ms: ttftMs, + total_latency_ms: totalLatencyMs, + input_tokens: inputTokens, + reasoning_tokens: reasoningTokens, + output_tokens: outputTokens, + input_cost_usd: inputTokens != null && cost ? inputTokens * cost.input : null, + reasoning_cost_usd: reasoningTokens != null && cost ? reasoningTokens * cost.input : null, + output_cost_usd: outputTokens != null && cost ? outputTokens * cost.output : null, + http_status_code: httpStatusCode, + error_type: errorType, + is_streaming: isStreaming, + user_id: userId, + }; +}; diff --git a/packages/api_platform/src/services/ai/messages-to-output.ts b/packages/api_platform/src/services/ai/messages-to-output.ts new file mode 100644 index 0000000..2957b1c --- /dev/null +++ b/packages/api_platform/src/services/ai/messages-to-output.ts @@ -0,0 +1,357 @@ +import type { generateText } from "ai"; +import { Effect } from "effect"; +import type { FunctionCall, FunctionCallOutput, ItemField, Message, ReasoningBody } from "common"; +import { isValidUrl } from "../../utils"; + +export const messagesToOutput = ( + result: Awaited>, +) => { + type FunctionCallOutputContentPart = Exclude[number]; + type ToolResultContentPart = + | { type: "text"; text: string } + | { type: "media"; mediaType: string; data: string } + | { type: "file-url"; url: string } + | { type: "file-data"; mediaType: string; data: string } + | { type: "image-data"; data: string } + | { type: "image-url"; url: string } + | { type: "file-id"; fileId: string } + | { type: "image-file-id"; fileId: string } + | { type: "custom"; providerOptions: unknown }; + + type ToolResultOutput = + | { type: "text"; value: string } + | { type: "execution-denied"; reason?: string } + | { type: "error-json"; value: unknown } + | { type: "json"; value: unknown } + | { type: "error-text"; value: string } + | { type: "content"; value: ToolResultContentPart[] }; + + const convertToolResultContentPart = ( + outputValue: ToolResultContentPart, + index: number, + ): FunctionCallOutputContentPart => { + switch (outputValue.type) { + case "text": + return { + type: "input_text", + text: outputValue.text, + } satisfies FunctionCallOutputContentPart; + case "media": { + if (outputValue.mediaType.startsWith("image/")) { + return { + type: "input_image", + image_url: isValidUrl(outputValue.data) + ? outputValue.data + : `data:image/png;base64,${outputValue.data}`, + detail: "auto", + } satisfies FunctionCallOutputContentPart; + } + return { + type: "input_file", + filename: `file-${index}`, + file_url: isValidUrl(outputValue.data) + ? outputValue.data + : `data:${outputValue.mediaType};base64,${outputValue.data}`, + } satisfies FunctionCallOutputContentPart; + } + case "file-url": + return { + type: "input_file", + filename: `file-${index}`, + file_url: outputValue.url, + } satisfies FunctionCallOutputContentPart; + case "file-data": + return { + type: "input_file", + filename: `file-${index}`, + file_url: `data:${outputValue.mediaType};base64,${outputValue.data}`, + } satisfies FunctionCallOutputContentPart; + case "image-data": + return { + type: "input_image", + image_url: `data:image/png;base64,${outputValue.data}`, + detail: "auto", + } satisfies FunctionCallOutputContentPart; + case "image-url": + return { + type: "input_image", + image_url: outputValue.url, + detail: "auto", + } satisfies FunctionCallOutputContentPart; + case "file-id": + case "image-file-id": + return { + type: "input_text", + text: JSON.stringify(outputValue.fileId), + } satisfies FunctionCallOutputContentPart; + case "custom": + return { + type: "input_text", + text: JSON.stringify(outputValue.providerOptions), + } satisfies FunctionCallOutputContentPart; + } + }; + + const convertToolResultOutput = (output: ToolResultOutput): FunctionCallOutput["output"] => { + switch (output.type) { + case "text": + return output.value; + case "execution-denied": + return `Tool execution denied "${output.reason || "NO REASON PROVIDED"}"`; + case "error-json": + case "json": + return JSON.stringify(output.value); + case "error-text": + return `Tool execution resulted in error: "${output.value}"`; + case "content": { + return output.value.map((outputValue, index) => + convertToolResultContentPart(outputValue, index), + ) satisfies FunctionCallOutput["output"]; + } + } + }; + + const getItemIdFromProviderOptions = (providerOptions: unknown, fallbackId: string) => { + if (!providerOptions || typeof providerOptions !== "object") { + return fallbackId; + } + + const itemEntry = Object.entries(providerOptions).find( + ([_, metadata]) => typeof metadata === "object" && metadata !== null && "itemId" in metadata, + )?.[1] as { itemId?: string } | undefined; + + return itemEntry?.itemId ?? fallbackId; + }; + + const getEncryptedContentFromProviderOptions = ( + providerOptions: unknown, + ): string | undefined => { + if (!providerOptions || typeof providerOptions !== "object") return undefined; + const entry = Object.entries(providerOptions).find( + ([_, metadata]) => + typeof metadata === "object" && + metadata !== null && + "reasoningEncryptedContent" in metadata, + )?.[1] as { reasoningEncryptedContent?: string } | undefined; + return entry?.reasoningEncryptedContent; + }; + + const messagesAsOutput: ItemField[] = result.response.messages.flatMap((message, indx) => { + const messageRole = message.role as Message["role"] | "tool"; + + switch (messageRole) { + case "tool": { + type ToolMessageContentItem = + | { + type: "tool-result"; + toolCallId: string; + output: ToolResultOutput; + providerOptions?: unknown; + } + | { type: "tool-approval-response"; approvalId: string }; + + const content = Array.isArray(message.content) + ? (message.content as ToolMessageContentItem[]) + : []; + + return content.flatMap((c) => { + switch (c.type) { + case "tool-result": { + return [ + { + type: "function_call_output", + id: getItemIdFromProviderOptions(c.providerOptions, c.toolCallId), + status: "completed", + output: convertToolResultOutput(c.output as ToolResultOutput), + call_id: c.toolCallId, + } satisfies FunctionCallOutput, + ]; + } + case "tool-approval-response": { + return [{ type: "function_call", call_id: c.approvalId } as ItemField]; + } + default: + return []; + } + }); + } + case "user": + case "system": + case "developer": + return []; + case "assistant": { + const content = message.content; + if (typeof content === "string") { + return { + type: "message", + id: `message-${indx}`, + status: "completed", + role: "assistant", + content: [ + { + type: "output_text", + text: content, + annotations: [], + logprobs: [], + }, + ], + } satisfies Message; + } + return content.flatMap((contentItem): ItemField[] => { + switch (contentItem.type) { + case "text": { + return [ + { + type: "message", + id: getItemIdFromProviderOptions(contentItem.providerOptions, `message-${indx}`), + status: "completed", + role: "assistant", + content: [ + { + type: "output_text", + text: contentItem.text, + annotations: [], + logprobs: [], + }, + ], + } satisfies Message, + ]; + } + case "reasoning": { + const encryptedContent = getEncryptedContentFromProviderOptions( + contentItem.providerOptions, + ); + return [ + { + type: "reasoning", + id: getItemIdFromProviderOptions( + contentItem.providerOptions, + `reasoning-${indx}`, + ), + summary: [ + { + type: "summary_text", + text: contentItem.text, + }, + ], + ...(encryptedContent ? { encrypted_content: encryptedContent } : {}), + } satisfies ReasoningBody, + ]; + } + case "file": { + if (contentItem.mediaType.startsWith("image/")) { + return [ + { + type: "message", + id: getItemIdFromProviderOptions( + contentItem.providerOptions, + `message-${indx}`, + ), + status: "completed", + role: "assistant", + content: [ + { + type: "input_image", + image_url: + contentItem.data instanceof URL + ? String(contentItem.data) + : `data:${contentItem.mediaType};base64,${contentItem.data}`, + detail: "auto", + }, + ], + } satisfies Message, + ]; + } + + if (contentItem.mediaType.startsWith("video/")) { + return [ + { + type: "message", + id: getItemIdFromProviderOptions( + contentItem.providerOptions, + `message-${indx}`, + ), + status: "completed", + role: "assistant", + content: [ + { + type: "input_video", + video_url: + contentItem.data instanceof URL + ? String(contentItem.data) + : `data:${contentItem.mediaType};base64,${contentItem.data}`, + }, + ], + } satisfies Message, + ]; + } + + return [ + { + type: "message", + id: `message-${indx}`, + status: "completed", + role: "assistant", + content: [ + { + type: "input_file", + file_url: + contentItem.data instanceof URL + ? String(contentItem.data) + : `data:${contentItem.mediaType};base64,${contentItem.data}`, + }, + ], + } satisfies Message, + ]; + } + case "tool-call": { + return [ + { + type: "function_call", + id: getItemIdFromProviderOptions( + contentItem.providerOptions, + contentItem.toolCallId, + ), + status: "completed", + call_id: contentItem.toolCallId, + name: contentItem.toolName, + arguments: + typeof contentItem.input === "string" + ? contentItem.input + : JSON.stringify(contentItem.input), + } satisfies FunctionCall, + ]; + } + case "tool-approval-request": { + return [ + { + type: "function_call", + id: contentItem.approvalId, + status: "in_progress", + call_id: contentItem.toolCallId, + name: `tool-approval-${contentItem.toolCallId}`, + arguments: "", + } satisfies FunctionCall, + ]; + } + case "tool-result": { + return [ + { + type: "function_call_output", + id: contentItem.toolCallId, + status: "completed", + output: convertToolResultOutput(contentItem.output as ToolResultOutput), + call_id: "", + } satisfies FunctionCallOutput, + ]; + } + default: + return []; + } + }); + } + } + }); + + return Effect.succeed(messagesAsOutput); +}; diff --git a/packages/api_platform/src/services/ai/model-factory.ts b/packages/api_platform/src/services/ai/model-factory.ts new file mode 100644 index 0000000..e1e44d3 --- /dev/null +++ b/packages/api_platform/src/services/ai/model-factory.ts @@ -0,0 +1,24 @@ +import { Effect } from "effect"; +import { Providers, SUPPORTED_PROVIDERS, type ProviderCredentials } from "common"; +import type { ResolvedResponse } from "common"; +import { AIServiceError } from "."; +import { getProviderEntry } from "../provider-registry"; + +export const buildLanguageModel = ( + resolved: ResolvedResponse, + credentials: ProviderCredentials, +) => + Effect.gen(function* () { + const { provider, model } = resolved; + + if (!provider || !model || !SUPPORTED_PROVIDERS.includes(provider as Providers)) { + return yield* Effect.fail( + new AIServiceError({ + message: !provider ? "Empty provider resolved" : `Unsupported provider: ${provider}`, + }), + ); + } + + const entry = getProviderEntry(provider as Providers); + return entry.createClient(credentials)(model); + }); diff --git a/packages/api_platform/src/services/ai/response-defaults.ts b/packages/api_platform/src/services/ai/response-defaults.ts new file mode 100644 index 0000000..b5963c3 --- /dev/null +++ b/packages/api_platform/src/services/ai/response-defaults.ts @@ -0,0 +1,50 @@ +import type { CreateResponseBody, ResponseResource } from "common"; + +import { + DEFAULT_BACKGROUND, + DEFAULT_FREQUENCY_PENALTY, + DEFAULT_PARALLEL_TOOL_CALLS, + DEFAULT_PRESENCE_PENALTY, + DEFAULT_SERVICE_TIER, + DEFAULT_STORE, + DEFAULT_TEMPERATURE, + DEFAULT_TOP_LOGPROBS, + DEFAULT_TOP_P, + DEFAULT_TRUNCATION, +} from "./consts"; +import { resolveTools, resolveToolChoice, resolveTextFormat } from "./field-resolvers"; + +/** + * Builds the shared base fields of a ResponseResource from a request body. + * Used by result-to-resource, error-to-resource, and streaming skeleton to avoid + * duplicating 20+ default field assignments. + */ +export const buildBaseResponse = ( + req: CreateResponseBody, + resolved: { provider: string; model: string }, +): Omit => ({ + model: `${resolved.provider}/${resolved.model}`, + previous_response_id: req.previous_response_id ?? null, + instructions: req.instructions ?? null, + text: resolveTextFormat(req.text), + top_logprobs: req.top_logprobs ?? DEFAULT_TOP_LOGPROBS, + reasoning: req.reasoning + ? { effort: req.reasoning.effort ?? null, summary: req.reasoning.summary ?? null } + : null, + tools: resolveTools(req.tools), + tool_choice: resolveToolChoice(req.tool_choice), + truncation: req.truncation ?? DEFAULT_TRUNCATION, + parallel_tool_calls: req.parallel_tool_calls ?? DEFAULT_PARALLEL_TOOL_CALLS, + top_p: req.top_p ?? DEFAULT_TOP_P, + presence_penalty: req.presence_penalty ?? DEFAULT_PRESENCE_PENALTY, + frequency_penalty: req.frequency_penalty ?? DEFAULT_FREQUENCY_PENALTY, + temperature: req.temperature ?? DEFAULT_TEMPERATURE, + max_output_tokens: req.max_output_tokens ?? null, + max_tool_calls: req.max_tool_calls ?? null, + store: req.store ?? DEFAULT_STORE, + background: req.background ?? DEFAULT_BACKGROUND, + service_tier: req.service_tier ?? DEFAULT_SERVICE_TIER, + metadata: req.metadata ?? null, + safety_identifier: req.safety_identifier ?? null, + prompt_cache_key: req.prompt_cache_key ?? null, +}); diff --git a/packages/api_platform/src/services/ai/result-to-resource.ts b/packages/api_platform/src/services/ai/result-to-resource.ts new file mode 100644 index 0000000..9a40f26 --- /dev/null +++ b/packages/api_platform/src/services/ai/result-to-resource.ts @@ -0,0 +1,43 @@ +import { Effect } from "effect"; +import type { CreateResponseBody, ResponseResource } from "common"; +import type { generateText } from "ai"; +import type { ResolvedResponse } from "common"; + +import { messagesToOutput } from "./messages-to-output"; +import { buildBaseResponse } from "./response-defaults"; + +export const resultToResponseResource = ({ + result, + createdAt, + resolvedModelAndProvider, + createResponseBody, +}: { + result: Awaited>; + createdAt: number; + resolvedModelAndProvider: ResolvedResponse; + createResponseBody: CreateResponseBody; +}) => + Effect.gen(function* () { + return { + object: "response", + id: crypto.randomUUID(), + created_at: createdAt, + completed_at: Date.now(), + status: "completed", + incomplete_details: null, + ...buildBaseResponse(createResponseBody, resolvedModelAndProvider), + output: yield* messagesToOutput(result), + error: null, + usage: { + input_tokens: result.totalUsage.inputTokens ?? 0, + output_tokens: result.totalUsage.outputTokens ?? 0, + input_tokens_details: { + cached_tokens: result.totalUsage.inputTokenDetails?.cacheWriteTokens ?? 0, + }, + output_tokens_details: { + reasoning_tokens: result.totalUsage.outputTokenDetails?.reasoningTokens ?? 0, + }, + total_tokens: result.totalUsage.totalTokens ?? 0, + }, + } satisfies ResponseResource; + }); diff --git a/packages/api_platform/src/services/ai/stream-events.ts b/packages/api_platform/src/services/ai/stream-events.ts new file mode 100644 index 0000000..c6d06be --- /dev/null +++ b/packages/api_platform/src/services/ai/stream-events.ts @@ -0,0 +1,445 @@ +import type { TextStreamPart, ToolSet } from "ai"; +import type { FunctionCall, ItemField, Message, ReasoningBody, StreamingEvent } from "common"; + +export type AccumulatedState = { + sequenceNumber: number; + currentItemId: string | null; + outputIndex: number; + contentIndex: number; + accumulatedText: string; + outputItems: ItemField[]; + currentItemType: "text" | "reasoning" | "tool-input" | null; + currentToolCallId: string | null; + currentToolName: string | null; +}; + +type ToolInputPart = { + toolCallId?: string; + toolName?: string; + delta?: string; + text?: string; + argsTextDelta?: string; +}; + +const getTextDelta = (part: ToolInputPart): string => + part.delta ?? part.text ?? part.argsTextDelta ?? ""; + +export const streamToEvents = ( + fullStream: AsyncIterable>, + startingSequenceNumber = 0, +) => { + let sequenceNumber = startingSequenceNumber; + let currentItemId: string | null = null; + let outputIndex = -1; + let contentIndex = 0; + let accumulatedText = ""; + let currentItemType: AccumulatedState["currentItemType"] = null; + let currentToolCallId: string | null = null; + let currentToolName: string | null = null; + const outputItems: ItemField[] = []; + + const nextSequenceNumber = () => { + const current = sequenceNumber; + sequenceNumber += 1; + return current; + }; + + const startNewItem = (type: AccumulatedState["currentItemType"]) => { + currentItemId = crypto.randomUUID(); + outputIndex += 1; + contentIndex = 0; + accumulatedText = ""; + currentItemType = type; + }; + + const finishItem = () => { + currentItemId = null; + currentItemType = null; + currentToolCallId = null; + currentToolName = null; + }; + + const getAccumulatedState = (): AccumulatedState => ({ + sequenceNumber, + currentItemId, + outputIndex, + contentIndex, + accumulatedText, + outputItems, + currentItemType, + currentToolCallId, + currentToolName, + }); + + const events = (async function* () { + for await (const part of fullStream) { + try { + switch (part.type) { + case "text-start": { + startNewItem("text"); + + if (!currentItemId) break; + + const addedEvent: StreamingEvent = { + type: "response.output_item.added", + sequence_number: nextSequenceNumber(), + output_index: outputIndex, + item: { + id: currentItemId, + status: "in_progress", + role: "assistant", + content: [], + }, + }; + + const contentAddedEvent: StreamingEvent = { + type: "response.content_part.added", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: contentIndex, + part: { + type: "output_text", + text: "", + annotations: [], + logprobs: [], + }, + }; + + contentIndex += 1; + yield [addedEvent, contentAddedEvent]; + break; + } + case "text-delta": { + if (!currentItemId || currentItemType !== "text") break; + const delta = (part as ToolInputPart).text ?? ""; + accumulatedText += delta; + const deltaEvent: StreamingEvent = { + type: "response.output_text.delta", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: Math.max(contentIndex - 1, 0), + delta, + logprobs: [], + }; + yield [deltaEvent]; + break; + } + case "text-end": { + if (!currentItemId || currentItemType !== "text") break; + + const textDoneEvent: StreamingEvent = { + type: "response.output_text.done", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: Math.max(contentIndex - 1, 0), + text: accumulatedText, + logprobs: [], + }; + + const contentDoneEvent: StreamingEvent = { + type: "response.content_part.done", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: Math.max(contentIndex - 1, 0), + part: { + type: "output_text", + text: accumulatedText, + annotations: [], + logprobs: [], + }, + }; + + const completedMessageItem = { + id: currentItemId, + status: "completed" as const, + role: "assistant" as const, + content: [ + { + type: "output_text" as const, + text: accumulatedText, + annotations: [] as never[], + logprobs: [] as never[], + }, + ], + }; + + const outputDoneEvent: StreamingEvent = { + type: "response.output_item.done", + sequence_number: nextSequenceNumber(), + output_index: outputIndex, + item: completedMessageItem, + }; + + outputItems.push({ + type: "message", + ...completedMessageItem, + } satisfies Message); + + yield [textDoneEvent, contentDoneEvent, outputDoneEvent]; + finishItem(); + break; + } + case "reasoning-start": { + startNewItem("reasoning"); + + if (!currentItemId) break; + + const addedEvent: StreamingEvent = { + type: "response.output_item.added", + sequence_number: nextSequenceNumber(), + output_index: outputIndex, + item: { + id: currentItemId, + summary: [], + }, + }; + + const contentAddedEvent: StreamingEvent = { + type: "response.content_part.added", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: contentIndex, + part: { + type: "reasoning", + text: "", + }, + }; + + contentIndex += 1; + yield [addedEvent, contentAddedEvent]; + break; + } + case "reasoning-delta": { + if (!currentItemId || currentItemType !== "reasoning") break; + const delta = (part as ToolInputPart).text ?? ""; + accumulatedText += delta; + const deltaEvent: StreamingEvent = { + type: "response.reasoning.delta", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: Math.max(contentIndex - 1, 0), + delta, + }; + yield [deltaEvent]; + break; + } + case "reasoning-end": { + if (!currentItemId || currentItemType !== "reasoning") break; + + const reasoningDoneEvent: StreamingEvent = { + type: "response.reasoning.done", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: Math.max(contentIndex - 1, 0), + text: accumulatedText, + }; + + const contentDoneEvent: StreamingEvent = { + type: "response.content_part.done", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: Math.max(contentIndex - 1, 0), + part: { + type: "reasoning", + text: accumulatedText, + }, + }; + + const completedReasoningItem = { + id: currentItemId, + summary: [] as never[], + content: [{ type: "reasoning" as const, text: accumulatedText }], + }; + + const outputDoneEvent: StreamingEvent = { + type: "response.output_item.done", + sequence_number: nextSequenceNumber(), + output_index: outputIndex, + item: completedReasoningItem, + }; + + outputItems.push({ + type: "reasoning", + ...completedReasoningItem, + } satisfies ReasoningBody); + + yield [reasoningDoneEvent, contentDoneEvent, outputDoneEvent]; + finishItem(); + break; + } + case "tool-input-start": { + startNewItem("tool-input"); + const { toolCallId, toolName } = part as ToolInputPart; + currentToolCallId = toolCallId ?? null; + currentToolName = toolName ?? null; + + if (!currentItemId) break; + + const addedEvent: StreamingEvent = { + type: "response.output_item.added", + sequence_number: nextSequenceNumber(), + output_index: outputIndex, + item: { + id: currentItemId, + call_id: currentToolCallId ?? "", + name: currentToolName ?? "", + arguments: "", + status: "in_progress" as const, + }, + }; + + const contentAddedEvent: StreamingEvent = { + type: "response.content_part.added", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: contentIndex, + part: { + type: "input_text", + text: "", + }, + }; + + contentIndex += 1; + yield [addedEvent, contentAddedEvent]; + break; + } + case "tool-input-delta": { + if (!currentItemId || currentItemType !== "tool-input") break; + const delta = getTextDelta(part as ToolInputPart); + accumulatedText += delta; + const deltaEvent: StreamingEvent = { + type: "response.function_call_arguments.delta", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + delta, + }; + yield [deltaEvent]; + break; + } + case "tool-input-end": { + if (!currentItemId || currentItemType !== "tool-input") break; + + const argumentsDoneEvent: StreamingEvent = { + type: "response.function_call_arguments.done", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + arguments: accumulatedText, + }; + + const contentDoneEvent: StreamingEvent = { + type: "response.content_part.done", + sequence_number: nextSequenceNumber(), + item_id: currentItemId, + output_index: outputIndex, + content_index: Math.max(contentIndex - 1, 0), + part: { + type: "input_text", + text: accumulatedText, + }, + }; + + const completedFunctionCallItem = { + id: currentItemId, + status: "completed" as const, + call_id: currentToolCallId ?? "", + name: currentToolName ?? "", + arguments: accumulatedText, + }; + + const outputDoneEvent: StreamingEvent = { + type: "response.output_item.done", + sequence_number: nextSequenceNumber(), + output_index: outputIndex, + item: completedFunctionCallItem, + }; + + outputItems.push({ + type: "function_call", + ...completedFunctionCallItem, + } satisfies FunctionCall); + + yield [argumentsDoneEvent, contentDoneEvent, outputDoneEvent]; + finishItem(); + break; + } + case "error": { + const errorValue = (part as { error?: unknown }).error; + const errorObject = + typeof errorValue === "object" && errorValue !== null + ? (errorValue as Record) + : undefined; + const errorEvent: StreamingEvent = { + type: "error", + sequence_number: nextSequenceNumber(), + error: { + type: (typeof errorObject?.type === "string" + ? errorObject?.type + : typeof errorObject?.name === "string" + ? errorObject?.name + : "error") as string, + code: typeof errorObject?.code === "string" ? errorObject?.code : null, + message: + typeof errorObject?.message === "string" + ? errorObject?.message + : typeof errorValue === "string" + ? errorValue + : "Unknown error", + param: typeof errorObject?.param === "string" ? errorObject?.param : null, + ...(typeof errorObject?.headers === "object" && errorObject?.headers !== null + ? { headers: errorObject?.headers as Record } + : {}), + }, + }; + yield [errorEvent]; + break; + } + case "finish": { + break; + } + default: { + break; + } + } + } catch (e) { + const errorObject = + typeof e === "object" && e !== null ? (e as Record) : undefined; + const errorEvent: StreamingEvent = { + type: "error", + sequence_number: nextSequenceNumber(), + error: { + type: (typeof errorObject?.type === "string" + ? errorObject?.type + : typeof errorObject?.name === "string" + ? errorObject?.name + : "error") as string, + code: typeof errorObject?.code === "string" ? errorObject?.code : null, + message: + typeof errorObject?.message === "string" + ? errorObject?.message + : typeof e === "string" + ? e + : "Unknown error", + param: typeof errorObject?.param === "string" ? errorObject?.param : null, + ...(typeof errorObject?.headers === "object" && errorObject?.headers !== null + ? { headers: errorObject?.headers as Record } + : {}), + }, + }; + yield [errorEvent]; + } + } + })(); + + return { events, getAccumulatedState }; +}; diff --git a/packages/api_platform/src/services/classification-cache.ts b/packages/api_platform/src/services/classification-cache.ts deleted file mode 100644 index 96711f0..0000000 --- a/packages/api_platform/src/services/classification-cache.ts +++ /dev/null @@ -1,86 +0,0 @@ -import type { ResolvedResponse } from "common"; - -const DEFAULT_MAX_ENTRIES_PER_USER = 100; -const DEFAULT_TTL_MS = 60 * 60 * 1000; // 1 hour - -interface CacheEntry { - readonly result: ResolvedResponse; - readonly createdAt: number; -} - -/** - * Per-user in-memory cache for system prompt classification results. - * - * Keyed by userId -> SHA-256 hash of system prompt text -> ResolvedResponse. - * Used when analysisTarget is "per_system_prompt" to avoid re-classifying - * the same system prompt on every request. - * - * Entries are evicted after TTL or when the per-user limit is reached (LRU-style). - */ -class ClassificationCache { - private readonly cache = new Map>(); - private readonly maxEntriesPerUser: number; - private readonly ttlMs: number; - - constructor( - maxEntriesPerUser: number = DEFAULT_MAX_ENTRIES_PER_USER, - ttlMs: number = DEFAULT_TTL_MS, - ) { - this.maxEntriesPerUser = maxEntriesPerUser; - this.ttlMs = ttlMs; - } - - async get(userId: string, systemPromptText: string): Promise { - const userCache = this.cache.get(userId); - if (!userCache) return undefined; - - const hash = await this.hashText(systemPromptText); - const entry = userCache.get(hash); - if (!entry) return undefined; - - if (Date.now() - entry.createdAt > this.ttlMs) { - userCache.delete(hash); - if (userCache.size === 0) this.cache.delete(userId); - return undefined; - } - - return entry.result; - } - - async set(userId: string, systemPromptText: string, result: ResolvedResponse): Promise { - let userCache = this.cache.get(userId); - if (!userCache) { - userCache = new Map(); - this.cache.set(userId, userCache); - } - - // Evict oldest entry if at capacity - if (userCache.size >= this.maxEntriesPerUser) { - let oldestKey: string | undefined; - let oldestTime = Infinity; - for (const [key, entry] of userCache) { - if (entry.createdAt < oldestTime) { - oldestTime = entry.createdAt; - oldestKey = key; - } - } - if (oldestKey) userCache.delete(oldestKey); - } - - const hash = await this.hashText(systemPromptText); - userCache.set(hash, { result, createdAt: Date.now() }); - } - - private async hashText(text: string): Promise { - const encoder = new TextEncoder(); - const data = encoder.encode(text); - const hashBuffer = await crypto.subtle.digest("SHA-256", data); - const hashArray = new Uint8Array(hashBuffer); - return Array.from(hashArray) - .map((b) => b.toString(16).padStart(2, "0")) - .join(""); - } -} - -/** Singleton classification cache instance for the api_platform process. */ -export const classificationCache = new ClassificationCache(); diff --git a/packages/api_platform/src/services/credentials.ts b/packages/api_platform/src/services/credentials.ts index b750565..aa32592 100644 --- a/packages/api_platform/src/services/credentials.ts +++ b/packages/api_platform/src/services/credentials.ts @@ -1,6 +1,7 @@ import { Effect, Data } from "effect"; import { VaultService } from "vault"; import { type ProviderCredentials, Providers } from "common"; +import { getProviderEntry } from "./provider-registry"; export class CredentialsError extends Data.TaggedError("CredentialsError")<{ cause?: unknown; @@ -23,26 +24,5 @@ export const getCredentials = ( ), ); - switch (provider) { - case Providers.AmazonBedrock: - return { - accessKeyId: secrets["accessKeyId"] ?? "", - secretAccessKey: secrets["secretAccessKey"] ?? "", - region: secrets["region"] ?? "", - } satisfies ProviderCredentials as ProviderCredentials; - case Providers.OpenAI: - return { - apiKey: secrets["apiKey"] ?? "", - } satisfies ProviderCredentials as ProviderCredentials; - case Providers.Anthropic: - return { - apiKey: secrets["apiKey"] ?? "", - } satisfies ProviderCredentials as ProviderCredentials; - default: { - const _exhaustiveCheck: never = provider satisfies never; - return yield* Effect.fail( - new CredentialsError({ message: `Unsupported provider: ${provider}` }), - ); - } - } + return getProviderEntry(provider).extractCredentials(secrets); }); diff --git a/packages/api_platform/src/services/pmr.ts b/packages/api_platform/src/services/pmr.ts index d73cd96..cb995af 100644 --- a/packages/api_platform/src/services/pmr.ts +++ b/packages/api_platform/src/services/pmr.ts @@ -16,7 +16,8 @@ export const resolve = ( ) => Effect.gen(function* () { const resolverService = yield* ResolverService; - return yield* resolverService + const startTime = Date.now(); + const pairs = yield* resolverService .resolve(createResponseBody, userId, userProviders, analysisTarget) .pipe( Effect.mapError( @@ -27,4 +28,6 @@ export const resolve = ( }), ), ); + const resolutionLatencyMs = Date.now() - startTime; + return { pairs, resolutionLatencyMs }; }); diff --git a/packages/api_platform/src/services/provider-registry.ts b/packages/api_platform/src/services/provider-registry.ts new file mode 100644 index 0000000..d123989 --- /dev/null +++ b/packages/api_platform/src/services/provider-registry.ts @@ -0,0 +1,35 @@ +import { createAmazonBedrock } from "@ai-sdk/amazon-bedrock"; +import { createOpenAI } from "@ai-sdk/openai"; +import { createAnthropic } from "@ai-sdk/anthropic"; +import { Providers, type ProviderCredentials } from "common"; + +interface ProviderEntry { + createClient: (creds: ProviderCredentials) => (model: string) => ReturnType>; + extractCredentials: (secrets: Record) => ProviderCredentials; +} + +const registry: { [P in Providers]: ProviderEntry

} = { + [Providers.AmazonBedrock]: { + createClient: (creds) => createAmazonBedrock(creds), + extractCredentials: (secrets) => ({ + accessKeyId: secrets["accessKeyId"] ?? "", + secretAccessKey: secrets["secretAccessKey"] ?? "", + region: secrets["region"] ?? "", + }) as ProviderCredentials, + }, + [Providers.OpenAI]: { + createClient: (creds) => createOpenAI(creds), + extractCredentials: (secrets) => ({ + apiKey: secrets["apiKey"] ?? "", + }) as ProviderCredentials, + }, + [Providers.Anthropic]: { + createClient: (creds) => createAnthropic(creds), + extractCredentials: (secrets) => ({ + apiKey: secrets["apiKey"] ?? "", + }) as ProviderCredentials, + }, +}; + +export const getProviderEntry = (provider: T) => + registry[provider] as ProviderEntry; diff --git a/packages/api_platform/src/services/request-context.ts b/packages/api_platform/src/services/request-context.ts index dc8976b..2a9e4fe 100644 --- a/packages/api_platform/src/services/request-context.ts +++ b/packages/api_platform/src/services/request-context.ts @@ -1,8 +1,8 @@ -import type { ProviderModelPair } from "resolver/src/types"; +import type { ProviderModelPair } from "common"; import { Context } from "effect"; -interface RequestContextImpl { +export interface RequestParams { readonly userId: string; readonly userProviders: readonly string[]; readonly fallbackProviderModelPair: ProviderModelPair; @@ -11,5 +11,5 @@ interface RequestContextImpl { export class RequestContext extends Context.Tag("RequestContext")< RequestContext, - RequestContextImpl + RequestParams >() {} diff --git a/packages/api_platform/src/services/responses/index.ts b/packages/api_platform/src/services/responses/index.ts index b17df48..50aa0db 100644 --- a/packages/api_platform/src/services/responses/index.ts +++ b/packages/api_platform/src/services/responses/index.ts @@ -1,27 +1,15 @@ import type { CreateResponseBody, ResponseResource, StreamingEvent } from "common"; -import type { ProviderModelPair } from "resolver/src/types"; import { Effect, Data, Stream } from "effect"; +import { LedgerService } from "ledger"; +import { ResolverService } from "resolver"; + +import type { RequestParams } from "../request-context"; import * as AIService from "../ai"; -import { - DEFAULT_BACKGROUND, - DEFAULT_FREQUENCY_PENALTY, - DEFAULT_PARALLEL_TOOL_CALLS, - DEFAULT_PRESENCE_PENALTY, - DEFAULT_SERVICE_TIER, - DEFAULT_STORE, - DEFAULT_TEMPERATURE, - DEFAULT_TOP_LOGPROBS, - DEFAULT_TOP_P, - DEFAULT_TRUNCATION, -} from "../ai/consts"; -import { convertAISdkStreamTextToStreamingEvents } from "../ai/convertAISdkStreamTextToStreamingEvents"; -import { - resolveTools, - resolveToolChoice, - resolveTextFormat, -} from "../ai/createResponseBodyFieldsToResponseResourceFieldsResolvers"; +import { buildTransaction } from "../ai"; +import { streamToEvents } from "../ai/stream-events"; +import { buildBaseResponse } from "../ai/response-defaults"; import { encodeSSEEvent, encodeSSEDone, encodeSSEToUint8Array } from "../ai/sse"; import * as DatabaseService from "../database/postgres"; @@ -30,94 +18,59 @@ export class ResponseServiceError extends Data.TaggedError("ResponseServiceError message?: string; }> {} -export const create = ( +const buildSkeletonResponse = ( req: CreateResponseBody, - userId: string, - userProviders: readonly string[], - fallbackProviderModelPair: ProviderModelPair, - analysisTarget: string, -) => + resolvedModelAndProvider: { provider: string; model: string }, + responseId: string, + createdAt: number, +): ResponseResource => ({ + object: "response", + id: responseId, + created_at: createdAt, + completed_at: null, + status: "in_progress", + incomplete_details: null, + ...buildBaseResponse(req, resolvedModelAndProvider), + output: [], + error: null, + usage: null, +}); + +export const create = (req: CreateResponseBody, params: RequestParams) => Effect.gen(function* () { - const responseResource = yield* AIService.execute( - req, - userId, - userProviders, - fallbackProviderModelPair, - analysisTarget, - ); + const responseResource = yield* AIService.execute(req, params); yield* persistResponseResourceInDatabase(responseResource); return responseResource; }).pipe( Effect.catchTags({ - AIServiceError: (err) => - Effect.fail(new ResponseServiceError({ cause: err, message: err.message })), - DatabaseServiceError: (err) => + AIServiceError: (err: AIService.AIServiceError) => Effect.fail(new ResponseServiceError({ cause: err, message: err.message })), + DatabaseServiceError: (err: { message?: string }) => + Effect.fail(new ResponseServiceError({ cause: err, message: err.message ?? "Database error" })), }), ); -export const createStream = ( - req: CreateResponseBody, - userId: string, - userProviders: readonly string[], - fallbackProviderModelPair: ProviderModelPair, - analysisTarget: string, -) => +export const createStream = (req: CreateResponseBody, params: RequestParams) => Effect.gen(function* () { - const { result, resolvedModelAndProvider } = yield* AIService.executeStream( - req, - userId, - userProviders, - fallbackProviderModelPair, - analysisTarget, - ); + const ledgerService = yield* LedgerService; + const resolverService = yield* ResolverService; + const { result, resolvedModelAndProvider, resolutionLatencyMs, llmStartedAt, ttftMs } = + yield* AIService.executeStream(req, params); const responseId = crypto.randomUUID(); const createdAt = Date.now(); - const skeletonResponse: ResponseResource = { - object: "response", - id: responseId, - created_at: createdAt, - completed_at: null, - status: "in_progress", - incomplete_details: null, - model: `${resolvedModelAndProvider.provider}/${resolvedModelAndProvider.model}`, - previous_response_id: req.previous_response_id ?? null, - instructions: req.instructions ?? null, - output: [], - text: resolveTextFormat(req.text), - top_logprobs: req.top_logprobs ?? DEFAULT_TOP_LOGPROBS, - reasoning: req.reasoning - ? { effort: req.reasoning.effort ?? null, summary: req.reasoning.summary ?? null } - : null, - error: null, - tools: resolveTools(req.tools), - tool_choice: resolveToolChoice(req.tool_choice), - truncation: req.truncation ?? DEFAULT_TRUNCATION, - parallel_tool_calls: req.parallel_tool_calls ?? DEFAULT_PARALLEL_TOOL_CALLS, - top_p: req.top_p ?? DEFAULT_TOP_P, - presence_penalty: req.presence_penalty ?? DEFAULT_PRESENCE_PENALTY, - frequency_penalty: req.frequency_penalty ?? DEFAULT_FREQUENCY_PENALTY, - temperature: req.temperature ?? DEFAULT_TEMPERATURE, - usage: null, - max_output_tokens: req.max_output_tokens ?? null, - max_tool_calls: req.max_tool_calls ?? null, - store: req.store ?? DEFAULT_STORE, - background: req.background ?? DEFAULT_BACKGROUND, - service_tier: req.service_tier ?? DEFAULT_SERVICE_TIER, - metadata: req.metadata ?? null, - safety_identifier: req.safety_identifier ?? null, - prompt_cache_key: req.prompt_cache_key ?? null, - }; - - const { events, getAccumulatedState } = convertAISdkStreamTextToStreamingEvents( - result.fullStream, - 2, + const skeletonResponse = buildSkeletonResponse( + req, + resolvedModelAndProvider, + responseId, + createdAt, ); + const { events, getAccumulatedState } = streamToEvents(result.fullStream, 2); + let finalResponse: ResponseResource = skeletonResponse; const lifecycleStream = Stream.make( @@ -139,7 +92,8 @@ export const createStream = ( const deltaStream = Stream.fromAsyncIterable( events as AsyncIterable, - (e) => new ResponseServiceError({ cause: e, message: "Error during stream processing" }), + (e: unknown) => + new ResponseServiceError({ cause: e, message: "Error during stream processing" }), ).pipe( Stream.flatMap((eventArray) => Stream.fromIterable(eventArray)), Stream.map((event) => encodeSSEToUint8Array(encodeSSEEvent(event.type, event))), @@ -180,7 +134,7 @@ export const createStream = ( encodeSSEToUint8Array(encodeSSEDone()), ]; }, - catch: (e) => + catch: (e: unknown) => new ResponseServiceError({ cause: e, message: "Error building completion events" }), }), ).pipe(Stream.flatMap((arr) => Stream.fromIterable(arr))); @@ -216,6 +170,26 @@ export const createStream = ( if (req.store !== false) { yield* persistResponseResourceInDatabase(finalResponse); } + const cost = yield* resolverService + .getCostForModel( + `${resolvedModelAndProvider.provider}/${resolvedModelAndProvider.model}`, + ) + .pipe(Effect.catchAll(() => Effect.succeed(null))); + const totalLatencyMs = Date.now() - llmStartedAt; + yield* ledgerService + .insertTransaction( + buildTransaction({ + resolvedModelAndProvider, + resolutionLatencyMs, + userId: params.userId, + isStreaming: true, + response: finalResponse, + cost, + totalLatencyMs, + ttftMs, + }), + ) + .pipe(Effect.ignore); }).pipe(Effect.ignore), ), ); @@ -223,7 +197,7 @@ export const createStream = ( return sseStream; }).pipe( Effect.catchTags({ - AIServiceError: (err) => + AIServiceError: (err: AIService.AIServiceError) => Effect.fail(new ResponseServiceError({ cause: err, message: err.message })), }), ); diff --git a/packages/backend/src/services/apikey.ts b/packages/backend/src/services/apikey.ts index 390311a..bd192f7 100644 --- a/packages/backend/src/services/apikey.ts +++ b/packages/backend/src/services/apikey.ts @@ -1,4 +1,4 @@ -import type { ProviderModelPair } from "resolver"; +import type { VerifyApiKeyResult } from "common"; import { Context, Effect, Layer } from "effect"; import { ProviderModelParseError } from "resolver"; @@ -21,15 +21,6 @@ export interface ApiKeyResponse { readonly value?: string; } -export interface VerifyResult { - readonly valid: boolean; - readonly providers?: string[]; - readonly userId?: string; - readonly fallbackProviderModelPair?: ProviderModelPair; - readonly analysisTarget?: string; - readonly error?: string; -} - const toApiKeyResponse = ( key: { id: string; @@ -56,6 +47,42 @@ const tryAuth = (f: () => Promise, message: string) => catch: (cause) => new AuthApiError({ cause, message }), }); +const assembleUserContext = ( + secrets: { + getUserSecrets: (userId: string) => Effect.Effect< + { providers: string[]; disabledProviders: string[] }, + DatabaseServiceError + >; + getUserFallback: (userId: string) => Effect.Effect; + getUserAnalysisTarget: (userId: string) => Effect.Effect; + }, + userId: string, +) => + Effect.gen(function* () { + const userSecrets = yield* secrets + .getUserSecrets(userId) + .pipe( + Effect.catchTag("DatabaseServiceError", (err) => + Effect.logError("Failed to fetch user providers during verify").pipe( + Effect.annotateLogs("cause", String(err.cause)), + Effect.as({ providers: [] as string[], disabledProviders: [] as string[] }), + ), + ), + ); + const providers = userSecrets.providers.filter( + (p) => !userSecrets.disabledProviders.includes(p), + ); + + const fallbackProviderModelPairString = yield* secrets.getUserFallback(userId); + const fallbackProviderModelPair = yield* parseProviderModelImpl( + fallbackProviderModelPairString, + ); + + const analysisTarget = yield* secrets.getUserAnalysisTarget(userId); + + return { providers, fallbackProviderModelPair, analysisTarget }; + }); + interface ApiKeyServiceImpl { getKey: (headers: Headers) => Effect.Effect; @@ -75,7 +102,7 @@ interface ApiKeyServiceImpl { verifyKey: ( key: string, - ) => Effect.Effect; + ) => Effect.Effect; } export class ApiKeyService extends Context.Tag("ApiKeyService")< @@ -180,35 +207,14 @@ export const ApiKeyServiceLive = Layer.effect( }; } - let providers: string[] = []; - - const userSecrets = yield* secrets - .getUserSecrets(result.key!.referenceId) - .pipe( - Effect.catchTag("DatabaseServiceError", (err) => - Effect.logError("Failed to fetch user providers during verify").pipe( - Effect.annotateLogs("cause", String(err.cause)), - Effect.as({ providers: [] as string[], disabledProviders: [] as string[] }), - ), - ), - ); - providers = userSecrets.providers.filter( - (p) => !userSecrets.disabledProviders.includes(p), - ); - - const fallbackProviderModelPairString = yield* secrets.getUserFallback( - result.key!.referenceId, - ); - const fallbackProviderModelPair = yield* parseProviderModelImpl( - fallbackProviderModelPairString, - ); - - const analysisTarget = yield* secrets.getUserAnalysisTarget(result.key!.referenceId); + const userId = result.key!.referenceId; + const { providers, fallbackProviderModelPair, analysisTarget } = + yield* assembleUserContext(secrets, userId); return { valid: true as const, providers, - userId: result.key!.referenceId, + userId, fallbackProviderModelPair, analysisTarget, }; diff --git a/packages/common/src/api-key.ts b/packages/common/src/api-key.ts new file mode 100644 index 0000000..2483d60 --- /dev/null +++ b/packages/common/src/api-key.ts @@ -0,0 +1,10 @@ +import type { ProviderModelPair } from "./intent"; + +export interface VerifyApiKeyResult { + readonly valid: boolean; + readonly userId?: string; + readonly providers?: string[]; + readonly fallbackProviderModelPair?: ProviderModelPair; + readonly analysisTarget?: string; + readonly error?: string; +} diff --git a/packages/common/src/index.ts b/packages/common/src/index.ts index cc809bf..adb4384 100644 --- a/packages/common/src/index.ts +++ b/packages/common/src/index.ts @@ -1,3 +1,5 @@ export * from "./providers"; export * from "./schema"; export * from "./resolver"; +export * from "./intent"; +export * from "./api-key"; diff --git a/packages/common/src/intent.ts b/packages/common/src/intent.ts new file mode 100644 index 0000000..c8dd64a --- /dev/null +++ b/packages/common/src/intent.ts @@ -0,0 +1,53 @@ +import { Data, Schema } from "effect"; + +export const Intent = Schema.Literal( + "auto", + "academia", + "finance", + "health", + "legal", + "marketing", + "programming", + "roleplay", + "science", + "seo", + "technology", + "translation", + "trivia", +); +export type Intent = Schema.Schema.Type; + +export const INTENTS: ReadonlyArray = Intent.literals; + +/** All intents except "auto", used for data fetching categories. */ +export const CATEGORIES: ReadonlyArray> = INTENTS.filter( + (i): i is Exclude => i !== "auto", +); + +export const IntentPolicy = Schema.Literal( + "auto", + "most-popular", + "pricing-low-to-high", + "pricing-high-to-low", + "context-high-to-low", + "latency-low-to-high", + "throughput-high-to-low", +); +export type IntentPolicy = Schema.Schema.Type; + +export const INTENT_POLICIES: ReadonlyArray = IntentPolicy.literals; + +/** All policies except "auto", used for data fetching orders. */ +export const ORDERS: ReadonlyArray> = INTENT_POLICIES.filter( + (i): i is Exclude => i !== "auto", +); + +export class IntentPair extends Data.TaggedClass("IntentPair")<{ + readonly intent: Intent; + readonly intentPolicy: IntentPolicy; +}> {} + +export class ProviderModelPair extends Data.TaggedClass("ProviderModelPair")<{ + readonly model: string; + readonly provider: string; +}> {} diff --git a/packages/common/src/resolver.ts b/packages/common/src/resolver.ts index 667a84e..09db471 100644 --- a/packages/common/src/resolver.ts +++ b/packages/common/src/resolver.ts @@ -3,6 +3,7 @@ import { Schema } from "effect"; export const ResolvedResponseSchema = Schema.Struct({ model: Schema.String, provider: Schema.String, + category: Schema.NullOr(Schema.String), }); export type ResolvedResponse = Schema.Schema.Type; diff --git a/packages/resolver/src/data_manager/fetch.ts b/packages/resolver/src/data_manager/fetch.ts index c97815f..ca49068 100644 --- a/packages/resolver/src/data_manager/fetch.ts +++ b/packages/resolver/src/data_manager/fetch.ts @@ -1,11 +1,12 @@ +import { SUPPORTED_PROVIDERS } from "common"; import { Effect, Duration, Schema } from "effect"; + +import * as Redis from "../redis/index"; import { DataFetchError } from "../types"; import { ORDERS, CATEGORIES } from "../types"; -import { SUPPORTED_PROVIDERS } from "common"; -import { ProviderModelMapSchema } from "./schema/modelsdev"; -import { OpenRouterMapSchema } from "./schema/openrouter"; import { generateModelMap } from "./model_map"; -import * as Redis from "../redis/index"; +import { ProviderModelMapSchema, ProviderModelToCostSchema } from "./schema/modelsdev"; +import { OpenRouterMapSchema } from "./schema/openrouter"; const OPENROUTER_BASE = "https://openrouter.ai/api/frontend/models"; const MODELS_DEV_BASE = "https://models.dev/api.json"; @@ -129,6 +130,32 @@ const modelsDevAction = () => (json: Record) => ); yield* Redis.bulkSetModelsForProvider(supported); + + const parsedCost = yield* Schema.decodeUnknown(ProviderModelToCostSchema)(json).pipe( + Effect.tapError((err) => + Effect.logError("models.dev cost decode failed").pipe( + Effect.annotateLogs({ + service: "DataManager", + operation: "populate", + key: `models.dev`, + cause: String(err), + }), + ), + ), + ); + + const supportedCost: Record = {}; + for (const provider of SUPPORTED_PROVIDERS) { + const costs = parsedCost[provider]; + if (costs) { + for (const cost of costs) { + supportedCost[`${provider}/${cost.model}`] = { input: cost.input, output: cost.output }; + } + } + } + + yield* Redis.bulkSetProviderModelCost(supportedCost); + return supported; }); @@ -171,7 +198,6 @@ const populate = () => }), ); - // TODO: do a runtime check to see if they match our expected schema, and log any discrepancies for debugging purposes const modelMap = generateModelMap( openRouter.flat() as string[], modelsDev as Readonly>, diff --git a/packages/resolver/src/data_manager/model_map.ts b/packages/resolver/src/data_manager/model_map.ts index e4e6f8e..8041650 100644 --- a/packages/resolver/src/data_manager/model_map.ts +++ b/packages/resolver/src/data_manager/model_map.ts @@ -30,7 +30,7 @@ export function generateModelMap( for (const entry of parsedProviderModels) { if (modelsMatch(parsedSlug, entry.parsed)) { - matches.push({ provider: entry.provider, model: entry.model }); + matches.push({ provider: entry.provider, model: entry.model, category: null }); } } diff --git a/packages/resolver/src/data_manager/schema/modelsdev.ts b/packages/resolver/src/data_manager/schema/modelsdev.ts index 5f20d9a..2d84bcc 100644 --- a/packages/resolver/src/data_manager/schema/modelsdev.ts +++ b/packages/resolver/src/data_manager/schema/modelsdev.ts @@ -3,7 +3,14 @@ import { Schema } from "effect"; const ProviderSchema = Schema.Struct({ models: Schema.Record({ key: Schema.String, - value: Schema.Unknown, + value: Schema.Struct({ + cost: Schema.optional( + Schema.Struct({ + input: Schema.Number, + output: Schema.Number, + }), + ), + }), }), }); @@ -30,3 +37,32 @@ export const ProviderModelMapSchema = Schema.transform( encode: (a) => a, }, ); + +export const ProviderModelToCostSchema = Schema.transform( + ProvidersSchema, + Schema.Record({ + key: Schema.String, + value: Schema.Array( + Schema.Struct({ + input: Schema.Number, + output: Schema.Number, + model: Schema.String, + }), + ), + }), + { + strict: false, + decode: (providers) => + Object.fromEntries( + Object.entries(providers).map(([provider, { models }]) => [ + provider, + Object.entries(models).map(([modelName, { cost }]) => ({ + input: cost?.input ?? 0, + output: cost?.output ?? 0, + model: modelName, + })), + ]), + ), + encode: (a) => a, + }, +); diff --git a/packages/resolver/src/index.ts b/packages/resolver/src/index.ts index c6203f2..6187eeb 100644 --- a/packages/resolver/src/index.ts +++ b/packages/resolver/src/index.ts @@ -35,6 +35,7 @@ export class ResolverService extends Context.Tag("ResolverService")< readonly { readonly model: string; readonly provider: string; + readonly category: string | null; }[], | Redis.RedisError | ParseError @@ -44,6 +45,12 @@ export class ResolverService extends Context.Tag("ResolverService")< | NoProviderAvailableError | ProviderModelParseError >; + getCostForModel: ( + canonicalProviderModelName: string, + ) => Effect.Effect< + { input: number; output: number } | null, + ParseError | Redis.RedisError + >; } >() {} @@ -84,6 +91,13 @@ export const ResolverServiceLive = Layer.effect( return yield* Redis.getAllModelsGroupedByProvider(); }).pipe(Effect.provideService(Redis.Redis, redis)); }, + getCostForModel(canonicalProviderModelName) { + return Effect.gen(function* () { + const cost = yield* Redis.getCostForModel(canonicalProviderModelName); + if (Array.isArray(cost) && cost.length === 0) return null; + return cost as { input: number; output: number }; + }).pipe(Effect.provideService(Redis.Redis, redis)); + }, }); }), ).pipe(Layer.provide(Redis.fromEnv)); diff --git a/packages/resolver/src/parser/index.ts b/packages/resolver/src/parser/index.ts index 84a0092..5771e77 100644 --- a/packages/resolver/src/parser/index.ts +++ b/packages/resolver/src/parser/index.ts @@ -7,20 +7,8 @@ export * from "../parser/parse_provider_model"; export const parseImpl = (model: string) => Effect.gen(function* () { - yield* Effect.logDebug("Parsing model string").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parse", model }), - ); - const firstSlashIndex = model.indexOf("/"); if (firstSlashIndex === -1) { - yield* Effect.logWarning("Model string missing separator").pipe( - Effect.annotateLogs({ - service: "Parser", - operation: "parse", - model, - reason: "BadFormatting", - }), - ); return yield* new ProviderModelParseError({ reason: "BadFormatting", message: `Expected format "{}/{}", got: "${model}"`, @@ -29,26 +17,8 @@ export const parseImpl = (model: string) => const prefix = model.substring(0, firstSlashIndex); if (isIntent(prefix)) { - yield* Effect.logDebug("Parsed as intent").pipe( - Effect.annotateLogs({ - service: "Parser", - operation: "parse", - model, - prefix, - type: "IntentPair", - }), - ); return yield* parseIntentImpl(model); } - yield* Effect.logDebug("Parsed as provider/model").pipe( - Effect.annotateLogs({ - service: "Parser", - operation: "parse", - model, - prefix, - type: "ProviderModelPair", - }), - ); return yield* parseProviderModelImpl(model); }); diff --git a/packages/resolver/src/parser/match_models.ts b/packages/resolver/src/parser/match_models.ts index f6b4ba3..84c6398 100644 --- a/packages/resolver/src/parser/match_models.ts +++ b/packages/resolver/src/parser/match_models.ts @@ -1,4 +1,4 @@ -import type { ParsedModelId } from "./parse_model_id.ts"; +import type { ParsedModelId } from "./parse_model_id"; export function modelsMatch(a: ParsedModelId, b: ParsedModelId): boolean { return ( diff --git a/packages/resolver/src/parser/parse_intent.ts b/packages/resolver/src/parser/parse_intent.ts index 0ecb987..b020529 100644 --- a/packages/resolver/src/parser/parse_intent.ts +++ b/packages/resolver/src/parser/parse_intent.ts @@ -8,14 +8,6 @@ export const parseIntentImpl = (input: string) => const firstSlashIndex = input.indexOf("/"); if (firstSlashIndex === -1) { - yield* Effect.logWarning("Intent parse failed: missing separator").pipe( - Effect.annotateLogs({ - service: "Parser", - operation: "parseIntent", - input, - reason: "BadFormatting", - }), - ); return yield* new IntentParseError({ reason: "BadFormatting", message: `Expected format "intent/intentPolicy", got: "${input}"`, @@ -26,14 +18,6 @@ export const parseIntentImpl = (input: string) => const rawPolicy = input.substring(firstSlashIndex + 1); if (!rawIntent) { - yield* Effect.logWarning("Intent parse failed: empty intent").pipe( - Effect.annotateLogs({ - service: "Parser", - operation: "parseIntent", - input, - reason: "EmptyIntent", - }), - ); return yield* new IntentParseError({ reason: "EmptyIntent", message: `Intent must be non-empty, got: "${input}"`, @@ -41,14 +25,6 @@ export const parseIntentImpl = (input: string) => } if (!rawPolicy) { - yield* Effect.logWarning("Intent parse failed: empty policy").pipe( - Effect.annotateLogs({ - service: "Parser", - operation: "parseIntent", - input, - reason: "EmptyIntentPolicy", - }), - ); return yield* new IntentParseError({ reason: "EmptyIntentPolicy", message: `intentPolicy must be non-empty, got: "${input}"`, @@ -56,11 +32,6 @@ export const parseIntentImpl = (input: string) => } const intent = yield* Schema.decodeUnknown(Intent)(rawIntent.toLowerCase()).pipe( - Effect.tapError(() => - Effect.logWarning("Intent parse failed: invalid intent literal").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parseIntent", rawIntent }), - ), - ), Effect.mapError( () => new IntentParseError({ @@ -71,11 +42,6 @@ export const parseIntentImpl = (input: string) => ); const intentPolicy = yield* Schema.decodeUnknown(IntentPolicy)(rawPolicy.toLowerCase()).pipe( - Effect.tapError(() => - Effect.logWarning("Intent parse failed: invalid policy literal").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parseIntent", rawPolicy }), - ), - ), Effect.mapError( () => new IntentParseError({ @@ -85,9 +51,5 @@ export const parseIntentImpl = (input: string) => ), ); - yield* Effect.logDebug("Intent parsed").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parseIntent", intent, intentPolicy }), - ); - return new IntentPair({ intent, intentPolicy }); }); diff --git a/packages/resolver/src/parser/parse_provider_model.ts b/packages/resolver/src/parser/parse_provider_model.ts index c9096be..95b028c 100644 --- a/packages/resolver/src/parser/parse_provider_model.ts +++ b/packages/resolver/src/parser/parse_provider_model.ts @@ -6,9 +6,6 @@ export const parseProviderModelImpl = (input: string) => const firstSlashIndex = input.indexOf("/"); if (firstSlashIndex === -1) { - yield* Effect.logWarning("Provider/model parse failed: missing separator").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parseProviderModel", input, reason: "BadFormatting" }), - ); return yield* new ProviderModelParseError({ reason: "BadFormatting", message: `Expected format "provider/model", got: "${input}"`, @@ -19,9 +16,6 @@ export const parseProviderModelImpl = (input: string) => const model = input.substring(firstSlashIndex + 1); if (!provider) { - yield* Effect.logWarning("Provider/model parse failed: empty provider").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parseProviderModel", input, reason: "EmptyProvider" }), - ); return yield* new ProviderModelParseError({ reason: "EmptyProvider", message: `Provider must be non-empty, got: "${input}"`, @@ -29,18 +23,11 @@ export const parseProviderModelImpl = (input: string) => } if (!model) { - yield* Effect.logWarning("Provider/model parse failed: empty model").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parseProviderModel", input, reason: "EmptyModel" }), - ); return yield* new ProviderModelParseError({ reason: "EmptyModel", message: `Model must be non-empty, got: "${input}"`, }); } - yield* Effect.logDebug("Provider/model parsed").pipe( - Effect.annotateLogs({ service: "Parser", operation: "parseProviderModel", provider, model }), - ); - return new ProviderModelPair({ provider, model }); }); diff --git a/packages/resolver/src/redis/categories.ts b/packages/resolver/src/redis/categories.ts index fbd5fee..b6c7d32 100644 --- a/packages/resolver/src/redis/categories.ts +++ b/packages/resolver/src/redis/categories.ts @@ -1,11 +1,11 @@ import { Effect, Schema } from "effect"; import { Redis } from "."; -import { TTL } from "./consts"; +import { TTL, REDIS_PREFIX } from "./consts"; -const PREFIX = "enfinyte:categories"; const modelsSchemaParser = Schema.parseJson(Schema.Array(Schema.String)); -const buildKey = (category: string, order: string) => `${PREFIX}:${category}:${order}`; +const buildKey = (category: string, order: string) => + `${REDIS_PREFIX.categories}:${category}:${order}`; export const getModelsForCategoryAndOrder = (category: string, order: string) => Effect.gen(function* () { diff --git a/packages/resolver/src/redis/classification.ts b/packages/resolver/src/redis/classification.ts index 1a15101..dae18a3 100644 --- a/packages/resolver/src/redis/classification.ts +++ b/packages/resolver/src/redis/classification.ts @@ -4,12 +4,11 @@ import { hoursToMilliseconds } from "date-fns"; import { Effect, Schema } from "effect"; import { Redis } from "./index"; +import { REDIS_PREFIX } from "./consts"; const MAX_ENTRIES_PER_USER = 100; const TTL_MS = hoursToMilliseconds(1); -const KEY_PREFIX = "classification_cache"; - const resolvedResponseSchema = Schema.parseJson(Schema.Array(ResolvedResponseSchema)); export const getResolvedResponse = (userId: string, systemPromptText: string) => @@ -76,9 +75,9 @@ const hashText = (text: string): string => { }; const buildEntryKey = (userId: string, hashedString: string): string => { - return `${KEY_PREFIX}:${userId}:${hashedString}`; + return `${REDIS_PREFIX.classificationCache}:${userId}:${hashedString}`; }; const buildIndexKey = (userId: string): string => { - return `${KEY_PREFIX}:${userId}:index`; + return `${REDIS_PREFIX.classificationCache}:${userId}:index`; }; diff --git a/packages/resolver/src/redis/consts.ts b/packages/resolver/src/redis/consts.ts index 24f90ca..ff266d0 100644 --- a/packages/resolver/src/redis/consts.ts +++ b/packages/resolver/src/redis/consts.ts @@ -1,3 +1,12 @@ import { hoursToMilliseconds } from "date-fns"; export const TTL = hoursToMilliseconds(10); + +export const REDIS_PREFIX = { + lastFetchPoint: "enfinyte:lastFetchPoint", + categories: "enfinyte:categories", + modelToProviders: "enfinyte:model_to_providers:", + providerToModels: "enfinyte:provider_to_models:", + modelToCost: "enfinyte:model_to_cost:", + classificationCache: "classification_cache", +} as const; diff --git a/packages/resolver/src/redis/fetch_point.ts b/packages/resolver/src/redis/fetch_point.ts index 26abca3..814babb 100644 --- a/packages/resolver/src/redis/fetch_point.ts +++ b/packages/resolver/src/redis/fetch_point.ts @@ -1,13 +1,12 @@ import { Clock, Effect } from "effect"; -import { Redis } from "."; -import { TTL } from "./consts"; -const LAST_FETCH_POINT_KEY = "enfinyte:lastFetchPoint"; +import { Redis } from "."; +import { TTL, REDIS_PREFIX } from "./consts"; export const getLastFetchPoint = () => Effect.gen(function* () { const redis = yield* Redis; - const lastFetchPoint = yield* redis.use((client) => client.get(LAST_FETCH_POINT_KEY)); + const lastFetchPoint = yield* redis.use((client) => client.get(REDIS_PREFIX.lastFetchPoint)); return lastFetchPoint === null ? lastFetchPoint : parseInt(lastFetchPoint); }); @@ -15,5 +14,5 @@ export const markLastFetchPoint = () => Effect.gen(function* () { const redis = yield* Redis; const now = yield* Clock.currentTimeMillis; - yield* redis.use((client) => client.set(LAST_FETCH_POINT_KEY, now.toString(), "PX", TTL)); + yield* redis.use((client) => client.set(REDIS_PREFIX.lastFetchPoint, now.toString(), "PX", TTL)); }); diff --git a/packages/resolver/src/redis/index.ts b/packages/resolver/src/redis/index.ts index 5b2b06f..cec08d7 100644 --- a/packages/resolver/src/redis/index.ts +++ b/packages/resolver/src/redis/index.ts @@ -64,3 +64,4 @@ export * from "./model_to_providers"; export * from "./categories"; export * from "./provider_to_models"; export * from "./classification"; +export * from "./model_cost"; diff --git a/packages/resolver/src/redis/model_cost.ts b/packages/resolver/src/redis/model_cost.ts new file mode 100644 index 0000000..f79ce11 --- /dev/null +++ b/packages/resolver/src/redis/model_cost.ts @@ -0,0 +1,55 @@ +import { Effect, Schema } from "effect"; + +import { Redis } from "."; +import { TTL, REDIS_PREFIX } from "./consts"; + +const costSchemaParser = Schema.parseJson( + Schema.Struct({ + input: Schema.Number, + output: Schema.Number, + }), +); + +type Cost = typeof costSchemaParser.Type; + +export const getCostForModel = (canonicalProviderModelName: string) => + Effect.gen(function* () { + const redis = yield* Redis; + const costStr = yield* redis.use((client) => + client.get(REDIS_PREFIX.modelToCost + canonicalProviderModelName), + ); + if (!costStr) return []; + return yield* Schema.decodeUnknown(costSchemaParser)(costStr); + }); + +export const setCostForModel = (canonicalProviderModelName: string, cost: Cost) => + Effect.gen(function* () { + const redis = yield* Redis; + const stringifiedCost = yield* Schema.encode(costSchemaParser)(cost); + yield* redis.use((client) => + client.set( + REDIS_PREFIX.modelToCost + canonicalProviderModelName, + stringifiedCost, + "PX", + TTL, + ), + ); + }); + +export const bulkSetProviderModelCost = ( + entries: Record, +) => + Effect.gen(function* () { + const setterEffects = Object.entries(entries).map(([canonicalProvdierModelName, cost]) => + setCostForModel(canonicalProvdierModelName, cost), + ); + yield* Effect.all(setterEffects, { concurrency: 5 }); + }); + +export const deleteProviderModelCost = (canonicalProviderModelName: string) => + Effect.gen(function* () { + const redis = yield* Redis; + yield* redis.use((client) => + client.del(REDIS_PREFIX.modelToCost + canonicalProviderModelName), + ); + }); diff --git a/packages/resolver/src/redis/model_to_providers.ts b/packages/resolver/src/redis/model_to_providers.ts index 9261a2c..bac2cae 100644 --- a/packages/resolver/src/redis/model_to_providers.ts +++ b/packages/resolver/src/redis/model_to_providers.ts @@ -1,9 +1,8 @@ import { Effect, Schema } from "effect"; import { Redis } from "."; -import { TTL } from "./consts"; +import { TTL, REDIS_PREFIX } from "./consts"; -const PREFIX = "enfinyte:model_to_providers:"; const providersSchemaParser = Schema.parseJson( Schema.Array( Schema.Struct({ @@ -18,7 +17,9 @@ type providers = typeof providersSchemaParser.Type; export const getProvidersForModel = (canonicalModelName: string) => Effect.gen(function* () { const redis = yield* Redis; - const providersStr = yield* redis.use((client) => client.get(PREFIX + canonicalModelName)); + const providersStr = yield* redis.use((client) => + client.get(REDIS_PREFIX.modelToProviders + canonicalModelName), + ); if (!providersStr) return []; return yield* Schema.decodeUnknown(providersSchemaParser)(providersStr); }); @@ -28,7 +29,12 @@ export const setProvidersForModel = (canonicalModelName: string, providers: prov const redis = yield* Redis; const stringifiedProviders = yield* Schema.encode(providersSchemaParser)(providers); yield* redis.use((client) => - client.set(PREFIX + canonicalModelName, stringifiedProviders, "PX", TTL), + client.set( + REDIS_PREFIX.modelToProviders + canonicalModelName, + stringifiedProviders, + "PX", + TTL, + ), ); }); @@ -45,5 +51,5 @@ export const bulkSetProvidersForModels = ( export const deleteModel = (canonicalModelName: string) => Effect.gen(function* () { const redis = yield* Redis; - yield* redis.use((client) => client.del(PREFIX + canonicalModelName)); + yield* redis.use((client) => client.del(REDIS_PREFIX.modelToProviders + canonicalModelName)); }); diff --git a/packages/resolver/src/redis/provider_to_models.ts b/packages/resolver/src/redis/provider_to_models.ts index 06b9ddd..b10427b 100644 --- a/packages/resolver/src/redis/provider_to_models.ts +++ b/packages/resolver/src/redis/provider_to_models.ts @@ -1,9 +1,7 @@ import { Effect, Schema } from "effect"; import { SUPPORTED_PROVIDERS } from "common"; import { Redis } from "."; -import { TTL } from "./consts"; - -const PREFIX = "enfinyte:provider_to_models:"; +import { TTL, REDIS_PREFIX } from "./consts"; const modelsSchemaParser = Schema.parseJson(Schema.Array(Schema.String)); type models = typeof modelsSchemaParser.Type; @@ -12,7 +10,9 @@ export const getModelsForProvider = (provider: string) => Effect.gen(function* () { const redis = yield* Redis; - const modelsStr = yield* redis.use((client) => client.get(PREFIX + provider)); + const modelsStr = yield* redis.use((client) => + client.get(REDIS_PREFIX.providerToModels + provider), + ); if (!modelsStr) return []; return yield* Schema.decodeUnknown(modelsSchemaParser)(modelsStr); @@ -22,7 +22,9 @@ export const setModelsForProvider = (provider: string, models: readonly string[] Effect.gen(function* () { const redis = yield* Redis; const stringifiedModels = yield* Schema.encode(modelsSchemaParser)(models); - yield* redis.use((client) => client.set(PREFIX + provider, stringifiedModels, "PX", TTL)); + yield* redis.use((client) => + client.set(REDIS_PREFIX.providerToModels + provider, stringifiedModels, "PX", TTL), + ); }); export const bulkSetModelsForProvider = (entries: Record) => @@ -37,7 +39,7 @@ export const bulkSetModelsForProvider = (entries: Record Effect.gen(function* () { const redis = yield* Redis; - const keys = SUPPORTED_PROVIDERS.map((p: string) => PREFIX + p); + const keys = SUPPORTED_PROVIDERS.map((p: string) => REDIS_PREFIX.providerToModels + p); const values = yield* redis.use((client) => client.mget(...keys)); @@ -62,5 +64,5 @@ export const getAllModelsGroupedByProvider = () => export const deleteProvider = (provider: string) => Effect.gen(function* () { const redis = yield* Redis; - yield* redis.use((client) => client.del(PREFIX + provider)); + yield* redis.use((client) => client.del(REDIS_PREFIX.providerToModels + provider)); }); diff --git a/packages/resolver/src/resolver/index.ts b/packages/resolver/src/resolver/index.ts index b4b5248..8f22db8 100644 --- a/packages/resolver/src/resolver/index.ts +++ b/packages/resolver/src/resolver/index.ts @@ -1,8 +1,6 @@ import type { CreateResponseBody } from "common"; -import { Effect, Match, pipe } from "effect"; - -import type { IntentPair, ProviderModelPair } from "../types"; +import { Effect, pipe } from "effect"; import { parseImpl } from "../parser"; import { parseIntentImpl } from "../parser/parse_intent"; @@ -11,13 +9,6 @@ import { resolveAuto } from "./resolve_auto"; import { resolveIntentPair } from "./resolve_intent"; import { resolveProviderModelPair } from "./resolve_provider_model"; -const resolve = (userProviders: string[]) => - Match.type().pipe( - Match.tag("IntentPair", (pair) => resolveIntentPair(pair, userProviders)), - Match.tag("ProviderModelPair", resolveProviderModelPair), - Match.exhaustive, - ); - export const resolveImpl = ( options: Pick, userId: string, @@ -26,13 +17,6 @@ export const resolveImpl = ( ) => Effect.gen(function* () { if (typeof options.model !== "string") { - yield* Effect.logError("Invalid model type").pipe( - Effect.annotateLogs({ - service: "Resolver", - operation: "resolve", - modelType: typeof options.model, - }), - ); return yield* new ResolveError({ reason: "InvalidModelType", message: `Expected model to be a string, got ${typeof options.model}`, @@ -40,14 +24,6 @@ export const resolveImpl = ( } if (options.model.startsWith("auto")) { - yield* Effect.logInfo("Resolving via auto-classification").pipe( - Effect.annotateLogs({ - service: "Resolver", - operation: "resolve", - model: options.model, - providerCount: userProviders.length, - }), - ); return yield* pipe( options.model, parseIntentImpl, @@ -55,14 +31,12 @@ export const resolveImpl = ( ); } - yield* Effect.logInfo("Resolving via direct parse").pipe( - Effect.annotateLogs({ - service: "Resolver", - operation: "resolve", - model: options.model, - providerCount: userProviders.length, + return yield* pipe( + options.model, + parseImpl, + Effect.flatMap((parsed) => { + if (parsed._tag === "IntentPair") return resolveIntentPair(parsed, userProviders); + return resolveProviderModelPair(parsed); }), ); - - return yield* pipe(options.model, parseImpl, Effect.flatMap(resolve(userProviders))); }); diff --git a/packages/resolver/src/resolver/resolve_auto.ts b/packages/resolver/src/resolver/resolve_auto.ts index 85db3a7..c219445 100644 --- a/packages/resolver/src/resolver/resolve_auto.ts +++ b/packages/resolver/src/resolver/resolve_auto.ts @@ -14,108 +14,66 @@ import { resolveIntentPair } from "./resolve_intent"; const RETRY_POLICY = { times: 5 }; const LLM_MODEL = "moonshotai.kimi-k2.5"; -const getCategory = (prompt: string) => +const classifyWithLLM = ( + prompt: string, + systemPrompt: string, + schema: z.ZodType, + fieldName: string, + operationName: string, +) => Effect.tryPromise({ try: () => generateText({ model: bedrock(LLM_MODEL), - system: SYSTEM_PROMPT_CAT, - output: Output.object({ - schema: z.object({ - category: z.enum(CATEGORIES), - }), - }), - messages: [ - { - role: "user", - content: [ - { - type: "text", - text: prompt, - }, - ], - }, - ], - }).then((res) => res.output.category), + system: systemPrompt, + output: Output.object({ schema }), + messages: [{ role: "user", content: [{ type: "text", text: prompt }] }], + }).then((res) => (res.output as Record)[fieldName]!), catch: (error) => new DataFetchError({ reason: "APICallFailed", - message: "Failed to classify intent category", + message: `Failed to classify ${operationName}`, cause: error, }), }).pipe( Effect.tapError((err) => - Effect.logError("LLM category classification failed").pipe( + Effect.logError(`LLM ${operationName} classification failed`).pipe( Effect.annotateLogs({ service: "Resolver", - operation: "getCategory", + operation: operationName, llmModel: LLM_MODEL, cause: err.cause instanceof Error ? err.cause.message : String(err.cause), }), ), ), - Effect.tap((category) => - Effect.logDebug("LLM category classified").pipe( + Effect.tap((result) => + Effect.logDebug(`LLM ${operationName} classified`).pipe( Effect.annotateLogs({ service: "Resolver", - operation: "getCategory", + operation: operationName, llmModel: LLM_MODEL, - category, + result, }), ), ), ); +const getCategory = (prompt: string) => + classifyWithLLM( + prompt, + SYSTEM_PROMPT_CAT, + z.object({ category: z.enum(CATEGORIES) }), + "category", + "getCategory", + ); + const getPolicy = (prompt: string) => - Effect.tryPromise({ - try: () => - generateText({ - model: bedrock(LLM_MODEL), - system: SYSTEM_PROMPT_POL, - output: Output.object({ - schema: z.object({ - policy: z.enum(ORDERS), - }), - }), - messages: [ - { - role: "user", - content: [ - { - type: "text", - text: prompt, - }, - ], - }, - ], - }).then((res) => res.output.policy), - catch: (error) => - new DataFetchError({ - reason: "APICallFailed", - message: "Failed to classify intent policy", - cause: error, - }), - }).pipe( - Effect.tapError((err) => - Effect.logError("LLM policy classification failed").pipe( - Effect.annotateLogs({ - service: "Resolver", - operation: "getPolicy", - llmModel: LLM_MODEL, - cause: err.cause instanceof Error ? err.cause.message : String(err.cause), - }), - ), - ), - Effect.tap((policy) => - Effect.logDebug("LLM policy classified").pipe( - Effect.annotateLogs({ - service: "Resolver", - operation: "getPolicy", - llmModel: LLM_MODEL, - policy, - }), - ), - ), + classifyWithLLM( + prompt, + SYSTEM_PROMPT_POL, + z.object({ policy: z.enum(ORDERS) }), + "policy", + "getPolicy", ); const extractTextFromInput = ( @@ -148,9 +106,6 @@ const resolveWith = ( userProviders: string[], ) => Effect.gen(function* () { - console.log("---------------------"); - console.log(prompt); - console.log("---------------------"); yield* Effect.logInfo("Auto-classifying with LLM").pipe( Effect.annotateLogs({ service: "Resolver", @@ -186,13 +141,12 @@ const resolveWith = ( }), ); - if (pair.intentPolicy === "auto") { - yield* Effect.logDebug("Policy is auto, classifying via LLM").pipe( - Effect.annotateLogs({ service: "Resolver", operation: "resolveWith" }), - ); - - const policy = yield* Effect.retry(getPolicy(prompt.text), RETRY_POLICY); + const policy = + pair.intentPolicy === "auto" + ? yield* Effect.retry(getPolicy(prompt.text), RETRY_POLICY) + : pair.intentPolicy; + if (pair.intentPolicy === "auto") { yield* Effect.logInfo("Policy classified").pipe( Effect.annotateLogs({ service: "Resolver", @@ -200,34 +154,18 @@ const resolveWith = ( policy, }), ); - - const intentPair = new IntentPair({ - intent: category, - intentPolicy: policy, - }); - - const resolvedResponse = yield* resolveIntentPair(intentPair, userProviders); - - if (prompt.source === "per_system_prompt") { - yield* setResolvedResponse(userId, prompt.text, resolvedResponse); - - yield* Effect.logInfo("Prompt cached.").pipe( - Effect.annotateLogs({ - service: "Resolver", - operation: "resolveWith", - message: "Cache event.", - }), - ); - } - - return resolvedResponse; } - const intentPair = new IntentPair({ intent: category, intentPolicy: pair.intentPolicy }); + const intentPair = new IntentPair({ + intent: category as IntentPair["intent"], + intentPolicy: policy as IntentPair["intentPolicy"], + }); + const resolvedResponse = yield* resolveIntentPair(intentPair, userProviders); + const withCategory = resolvedResponse.map((p) => ({ ...p, category: category as string | null })); if (prompt.source === "per_system_prompt") { - yield* setResolvedResponse(userId, prompt.text, resolvedResponse); + yield* setResolvedResponse(userId, prompt.text, withCategory); yield* Effect.logInfo("Prompt cached.").pipe( Effect.annotateLogs({ @@ -238,7 +176,7 @@ const resolveWith = ( ); } - return resolvedResponse; + return withCategory; }); /** diff --git a/packages/resolver/src/resolver/resolve_intent.ts b/packages/resolver/src/resolver/resolve_intent.ts index 1c172ad..0ac00f4 100644 --- a/packages/resolver/src/resolver/resolve_intent.ts +++ b/packages/resolver/src/resolver/resolve_intent.ts @@ -6,23 +6,6 @@ import { getPotentialModelsForIntentPair } from "../data_manager"; import * as Redis from "../redis/index"; import { NoProviderAvailableError } from "../types"; -const getAllProvidersForModel = (modelNameSlug: string) => - Effect.gen(function* () { - return yield* Redis.getProvidersForModel(modelNameSlug); - }); - -const findMatchingMapping = (modelNameSlug: string, userProviderSet: ReadonlySet) => - Effect.gen(function* () { - const providersForModel = yield* getAllProvidersForModel(modelNameSlug); - return providersForModel.filter(({ provider }) => userProviderSet.has(provider)); - }); - -const getAllProvidersForPotentialModels = (modelNameSlugs: readonly string[]) => - Effect.gen(function* () { - const providerArrays = yield* Effect.all(modelNameSlugs.map(getAllProvidersForModel)); - return providerArrays.flat().map(({ provider }) => provider); - }); - export const resolveIntentPair = (pair: IntentPair, userProviders: string[]) => Effect.gen(function* () { yield* Effect.logDebug("Resolving intent pair").pipe( @@ -40,9 +23,16 @@ export const resolveIntentPair = (pair: IntentPair, userProviders: string[]) => const pairs = yield* Effect.map( Effect.all( - potentialModels.map((modelNameSlug) => findMatchingMapping(modelNameSlug, userProviderSet)), + potentialModels.map((modelNameSlug) => + Redis.getProvidersForModel(modelNameSlug).pipe( + Effect.map((providers) => + providers.filter(({ provider }) => userProviderSet.has(provider)), + ), + ), + ), ), - (results) => results.flat(), + (results) => + results.flat().map((p) => ({ ...p, category: pair.intent as string | null })), ); if (pairs.length > 0) { @@ -58,8 +48,8 @@ export const resolveIntentPair = (pair: IntentPair, userProviders: string[]) => } const availableProviders = yield* Effect.map( - getAllProvidersForPotentialModels(potentialModels), - Arr.dedupe, + Effect.all(potentialModels.map((m) => Redis.getProvidersForModel(m))), + (results) => Arr.dedupe(results.flat().map(({ provider }) => provider)), ); yield* Effect.logWarning("No matching provider found for intent").pipe( diff --git a/packages/resolver/src/resolver/resolve_provider_model.ts b/packages/resolver/src/resolver/resolve_provider_model.ts index d852056..948e792 100644 --- a/packages/resolver/src/resolver/resolve_provider_model.ts +++ b/packages/resolver/src/resolver/resolve_provider_model.ts @@ -2,7 +2,7 @@ import { Effect } from "effect"; import type { ProviderModelPair } from "../types"; export const resolveProviderModelPair = (pair: ProviderModelPair) => - Effect.succeed([{ model: pair.model, provider: pair.provider }]).pipe( + Effect.succeed([{ model: pair.model, provider: pair.provider, category: null as string | null }]).pipe( Effect.tap((resolved) => Effect.logDebug("Provider/model passthrough resolved").pipe( Effect.annotateLogs({ diff --git a/packages/resolver/src/types.ts b/packages/resolver/src/types.ts index b7966fe..e1e906c 100644 --- a/packages/resolver/src/types.ts +++ b/packages/resolver/src/types.ts @@ -1,56 +1,17 @@ -import { Data, Schema } from "effect"; - -export const Intent = Schema.Literal( - "auto", - "academia", - "finance", - "health", - "legal", - "marketing", - "programming", - "roleplay", - "science", - "seo", - "technology", - "translation", - "trivia", -); -export type Intent = Schema.Schema.Type; - -export const INTENTS: ReadonlyArray = Intent.literals; - -/** All intents except "auto", used for data fetching categories. */ -export const CATEGORIES: ReadonlyArray> = INTENTS.filter( - (i): i is Exclude => i !== "auto", -); - -export const IntentPolicy = Schema.Literal( - "auto", - "most-popular", - "pricing-low-to-high", - "pricing-high-to-low", - "context-high-to-low", - "latency-low-to-high", - "throughput-high-to-low", -); -export type IntentPolicy = Schema.Schema.Type; - -export const INTENT_POLICIES: ReadonlyArray = IntentPolicy.literals; - -/** All intents except "auto", used for data fetching orders. */ -export const ORDERS: ReadonlyArray> = INTENT_POLICIES.filter( - (i): i is Exclude => i !== "auto", -); - -export class IntentPair extends Data.TaggedClass("IntentPair")<{ - readonly intent: Intent; - readonly intentPolicy: IntentPolicy; -}> {} - -export class ProviderModelPair extends Data.TaggedClass("ProviderModelPair")<{ - readonly model: string; - readonly provider: string; -}> {} +import { Data } from "effect"; + +export { + Intent, + type Intent as IntentType, + INTENTS, + CATEGORIES, + IntentPolicy, + type IntentPolicy as IntentPolicyType, + INTENT_POLICIES, + ORDERS, + IntentPair, + ProviderModelPair, +} from "common"; export class ResolveError extends Data.TaggedError("ResolveError")<{ readonly reason: "InvalidModelType" | "UnsupportedInputType";