diff --git a/lib/chat/__tests__/handleChatGenerate.test.ts b/lib/chat/__tests__/handleChatGenerate.test.ts index 04b0a05..1a90d49 100644 --- a/lib/chat/__tests__/handleChatGenerate.test.ts +++ b/lib/chat/__tests__/handleChatGenerate.test.ts @@ -1,6 +1,13 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextResponse } from "next/server"; +import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId"; +import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; +import { setupChatRequest } from "@/lib/chat/setupChatRequest"; +import { handleChatCompletion } from "@/lib/chat/handleChatCompletion"; +import { generateText } from "ai"; +import { handleChatGenerate } from "../handleChatGenerate"; + // Mock all dependencies before importing the module under test vi.mock("@/lib/auth/getApiKeyAccountId", () => ({ getApiKeyAccountId: vi.fn(), @@ -26,26 +33,27 @@ vi.mock("@/lib/chat/setupChatRequest", () => ({ setupChatRequest: vi.fn(), })); +vi.mock("@/lib/chat/handleChatCompletion", () => ({ + handleChatCompletion: vi.fn(), +})); + vi.mock("ai", () => ({ generateText: vi.fn(), })); -import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId"; -import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; -import { setupChatRequest } from "@/lib/chat/setupChatRequest"; -import { generateText } from "ai"; -import { handleChatGenerate } from "../handleChatGenerate"; - const mockGetApiKeyAccountId = vi.mocked(getApiKeyAccountId); const mockValidateOverrideAccountId = vi.mocked(validateOverrideAccountId); const mockSetupChatRequest = vi.mocked(setupChatRequest); +const mockHandleChatCompletion = vi.mocked(handleChatCompletion); const mockGenerateText = vi.mocked(generateText); // Helper to create mock NextRequest -function createMockRequest( - body: unknown, - headers: Record = {}, -): Request { +/** + * + * @param body + * @param headers + */ +function createMockRequest(body: unknown, headers: Record = {}): Request { return { json: () => Promise.resolve(body), headers: { @@ -58,6 +66,8 @@ function createMockRequest( describe("handleChatGenerate", () => { beforeEach(() => { vi.clearAllMocks(); + // Default mock for handleChatCompletion to return a resolved Promise + mockHandleChatCompletion.mockResolvedValue(); }); afterEach(() => { @@ -68,10 +78,7 @@ describe("handleChatGenerate", () => { it("returns 400 error when neither messages nor prompt is provided", async () => { mockGetApiKeyAccountId.mockResolvedValue("account-123"); - const request = createMockRequest( - { roomId: "room-123" }, - { "x-api-key": "test-key" }, - ); + const request = createMockRequest({ roomId: "room-123" }, { "x-api-key": "test-key" }); const result = await handleChatGenerate(request as any); @@ -122,10 +129,7 @@ describe("handleChatGenerate", () => { }, } as any); - const request = createMockRequest( - { prompt: "Hello, world!" }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ prompt: "Hello, world!" }, { "x-api-key": "valid-key" }); const result = await handleChatGenerate(request as any); @@ -157,10 +161,7 @@ describe("handleChatGenerate", () => { } as any); const messages = [{ role: "user", content: "Hello" }]; - const request = createMockRequest( - { messages }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ messages }, { "x-api-key": "valid-key" }); await handleChatGenerate(request as any); @@ -237,10 +238,7 @@ describe("handleChatGenerate", () => { response: { messages: [], headers: {}, body: null }, } as any); - const request = createMockRequest( - { prompt: "Hello" }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); const result = await handleChatGenerate(request as any); @@ -256,10 +254,7 @@ describe("handleChatGenerate", () => { mockGetApiKeyAccountId.mockResolvedValue("account-123"); mockSetupChatRequest.mockRejectedValue(new Error("Setup failed")); - const request = createMockRequest( - { prompt: "Hello" }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); const result = await handleChatGenerate(request as any); @@ -284,10 +279,7 @@ describe("handleChatGenerate", () => { mockGenerateText.mockRejectedValue(new Error("Generation failed")); - const request = createMockRequest( - { prompt: "Hello" }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); const result = await handleChatGenerate(request as any); @@ -336,4 +328,167 @@ describe("handleChatGenerate", () => { ); }); }); + + describe("chat completion handling", () => { + it("calls handleChatCompletion after text generation", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "Hello!", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + response: { messages: [], headers: {}, body: null }, + } as any); + + mockHandleChatCompletion.mockResolvedValue(); + + const messages = [{ id: "msg-1", role: "user", parts: [{ type: "text", text: "Hi" }] }]; + const request = createMockRequest( + { messages, roomId: "room-123" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatGenerate(request as any); + + expect(mockHandleChatCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + messages, + roomId: "room-123", + accountId: "account-123", + }), + expect.arrayContaining([ + expect.objectContaining({ + role: "assistant", + parts: [{ type: "text", text: "Hello!" }], + }), + ]), + ); + }); + + it("passes artistId to handleChatCompletion when provided", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "Hello!", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + response: { messages: [], headers: {}, body: null }, + } as any); + + mockHandleChatCompletion.mockResolvedValue(); + + const messages = [{ id: "msg-1", role: "user", parts: [{ type: "text", text: "Hi" }] }]; + const request = createMockRequest( + { messages, roomId: "room-123", artistId: "artist-456" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatGenerate(request as any); + + expect(mockHandleChatCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + artistId: "artist-456", + }), + expect.arrayContaining([ + expect.objectContaining({ + role: "assistant", + parts: [{ type: "text", text: "Hello!" }], + }), + ]), + ); + }); + + it("does not throw when handleChatCompletion fails (graceful handling)", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "Hello!", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + response: { + messages: [{ id: "resp-1", role: "assistant", parts: [] }], + headers: {}, + body: null, + }, + } as any); + + // Make handleChatCompletion throw an error + mockHandleChatCompletion.mockRejectedValue(new Error("Completion handling failed")); + + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); + + // Should still return 200 - completion handling failure should not affect response + const result = await handleChatGenerate(request as any); + expect(result.status).toBe(200); + }); + + it("calls handleChatCompletion even when validation skips it for missing roomId", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + const mockResponseMessages = [ + { + id: "resp-1", + role: "assistant", + parts: [{ type: "text", text: "Hello!" }], + }, + ]; + + mockSetupChatRequest.mockResolvedValue({ + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + mockGenerateText.mockResolvedValue({ + text: "Hello!", + finishReason: "stop", + usage: { promptTokens: 10, completionTokens: 20 }, + response: { messages: mockResponseMessages, headers: {}, body: null }, + } as any); + + mockHandleChatCompletion.mockResolvedValue(); + + // No roomId provided + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); + + await handleChatGenerate(request as any); + + // handleChatCompletion should still be called (it handles room creation internally) + expect(mockHandleChatCompletion).toHaveBeenCalled(); + }); + }); }); diff --git a/lib/chat/__tests__/handleChatStream.test.ts b/lib/chat/__tests__/handleChatStream.test.ts index b78918e..147cd4a 100644 --- a/lib/chat/__tests__/handleChatStream.test.ts +++ b/lib/chat/__tests__/handleChatStream.test.ts @@ -1,6 +1,13 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import { NextResponse } from "next/server"; +import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId"; +import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; +import { setupChatRequest } from "@/lib/chat/setupChatRequest"; +import { handleChatCompletion } from "@/lib/chat/handleChatCompletion"; +import { createUIMessageStream, createUIMessageStreamResponse } from "ai"; +import { handleChatStream } from "../handleChatStream"; + // Mock all dependencies before importing the module under test vi.mock("@/lib/auth/getApiKeyAccountId", () => ({ getApiKeyAccountId: vi.fn(), @@ -26,28 +33,29 @@ vi.mock("@/lib/chat/setupChatRequest", () => ({ setupChatRequest: vi.fn(), })); +vi.mock("@/lib/chat/handleChatCompletion", () => ({ + handleChatCompletion: vi.fn(), +})); + vi.mock("ai", () => ({ createUIMessageStream: vi.fn(), createUIMessageStreamResponse: vi.fn(), })); -import { getApiKeyAccountId } from "@/lib/auth/getApiKeyAccountId"; -import { validateOverrideAccountId } from "@/lib/accounts/validateOverrideAccountId"; -import { setupChatRequest } from "@/lib/chat/setupChatRequest"; -import { createUIMessageStream, createUIMessageStreamResponse } from "ai"; -import { handleChatStream } from "../handleChatStream"; - const mockGetApiKeyAccountId = vi.mocked(getApiKeyAccountId); const mockValidateOverrideAccountId = vi.mocked(validateOverrideAccountId); const mockSetupChatRequest = vi.mocked(setupChatRequest); +const mockHandleChatCompletion = vi.mocked(handleChatCompletion); const mockCreateUIMessageStream = vi.mocked(createUIMessageStream); const mockCreateUIMessageStreamResponse = vi.mocked(createUIMessageStreamResponse); // Helper to create mock NextRequest -function createMockRequest( - body: unknown, - headers: Record = {}, -): Request { +/** + * + * @param body + * @param headers + */ +function createMockRequest(body: unknown, headers: Record = {}): Request { return { json: () => Promise.resolve(body), headers: { @@ -60,6 +68,8 @@ function createMockRequest( describe("handleChatStream", () => { beforeEach(() => { vi.clearAllMocks(); + // Default mock for handleChatCompletion to return a resolved Promise + mockHandleChatCompletion.mockResolvedValue(); }); afterEach(() => { @@ -70,10 +80,7 @@ describe("handleChatStream", () => { it("returns 400 error when neither messages nor prompt is provided", async () => { mockGetApiKeyAccountId.mockResolvedValue("account-123"); - const request = createMockRequest( - { roomId: "room-123" }, - { "x-api-key": "test-key" }, - ); + const request = createMockRequest({ roomId: "room-123" }, { "x-api-key": "test-key" }); const result = await handleChatStream(request as any); @@ -124,10 +131,7 @@ describe("handleChatStream", () => { const mockResponse = new Response(mockStream); mockCreateUIMessageStreamResponse.mockReturnValue(mockResponse); - const request = createMockRequest( - { prompt: "Hello, world!" }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ prompt: "Hello, world!" }, { "x-api-key": "valid-key" }); const result = await handleChatStream(request as any); @@ -166,10 +170,7 @@ describe("handleChatStream", () => { mockCreateUIMessageStreamResponse.mockReturnValue(new Response(mockStream)); const messages = [{ role: "user", content: "Hello" }]; - const request = createMockRequest( - { messages }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ messages }, { "x-api-key": "valid-key" }); await handleChatStream(request as any); @@ -236,10 +237,7 @@ describe("handleChatStream", () => { mockGetApiKeyAccountId.mockResolvedValue("account-123"); mockSetupChatRequest.mockRejectedValue(new Error("Setup failed")); - const request = createMockRequest( - { prompt: "Hello" }, - { "x-api-key": "valid-key" }, - ); + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); const result = await handleChatStream(request as any); @@ -294,4 +292,176 @@ describe("handleChatStream", () => { ); }); }); + + describe("chat completion handling", () => { + it("calls handleChatCompletion after streaming completes", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + const mockAgent = { + stream: vi.fn().mockResolvedValue({ + toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + text: Promise.resolve("Hello!"), + }), + tools: {}, + }; + + mockSetupChatRequest.mockResolvedValue({ + agent: mockAgent, + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + // Capture the execute callback and run it + let capturedExecute: ((options: { writer: { merge: () => void } }) => Promise) | null = + null; + mockCreateUIMessageStream.mockImplementation( + (options: { execute: typeof capturedExecute }) => { + capturedExecute = options.execute; + return new ReadableStream(); + }, + ); + mockCreateUIMessageStreamResponse.mockReturnValue(new Response(new ReadableStream())); + mockHandleChatCompletion.mockResolvedValue(); + + const messages = [{ id: "msg-1", role: "user", parts: [{ type: "text", text: "Hi" }] }]; + const request = createMockRequest( + { messages, roomId: "room-123" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatStream(request as any); + + // Execute the captured callback to simulate stream completion + if (capturedExecute) { + await capturedExecute({ writer: { merge: vi.fn() } }); + } + + expect(mockHandleChatCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + messages, + roomId: "room-123", + accountId: "account-123", + }), + expect.arrayContaining([ + expect.objectContaining({ + role: "assistant", + parts: [{ type: "text", text: "Hello!" }], + }), + ]), + ); + }); + + it("passes artistId to handleChatCompletion when provided", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + const mockAgent = { + stream: vi.fn().mockResolvedValue({ + toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + text: Promise.resolve("Hello!"), + }), + tools: {}, + }; + + mockSetupChatRequest.mockResolvedValue({ + agent: mockAgent, + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + let capturedExecute: ((options: { writer: { merge: () => void } }) => Promise) | null = + null; + mockCreateUIMessageStream.mockImplementation( + (options: { execute: typeof capturedExecute }) => { + capturedExecute = options.execute; + return new ReadableStream(); + }, + ); + mockCreateUIMessageStreamResponse.mockReturnValue(new Response(new ReadableStream())); + mockHandleChatCompletion.mockResolvedValue(); + + const messages = [{ id: "msg-1", role: "user", parts: [{ type: "text", text: "Hi" }] }]; + const request = createMockRequest( + { messages, roomId: "room-123", artistId: "artist-456" }, + { "x-api-key": "valid-key" }, + ); + + await handleChatStream(request as any); + + if (capturedExecute) { + await capturedExecute({ writer: { merge: vi.fn() } }); + } + + expect(mockHandleChatCompletion).toHaveBeenCalledWith( + expect.objectContaining({ + artistId: "artist-456", + }), + expect.arrayContaining([ + expect.objectContaining({ + role: "assistant", + parts: [{ type: "text", text: "Hello!" }], + }), + ]), + ); + }); + + it("does not throw when handleChatCompletion fails (graceful handling)", async () => { + mockGetApiKeyAccountId.mockResolvedValue("account-123"); + + const mockAgent = { + stream: vi.fn().mockResolvedValue({ + toUIMessageStream: vi.fn().mockReturnValue(new ReadableStream()), + usage: Promise.resolve({ inputTokens: 100, outputTokens: 50 }), + text: Promise.resolve("Hello!"), + }), + tools: {}, + }; + + mockSetupChatRequest.mockResolvedValue({ + agent: mockAgent, + model: "gpt-4", + instructions: "test", + system: "test", + messages: [], + experimental_generateMessageId: vi.fn(), + tools: {}, + providerOptions: {}, + } as any); + + let capturedExecute: ((options: { writer: { merge: () => void } }) => Promise) | null = + null; + mockCreateUIMessageStream.mockImplementation( + (options: { execute: typeof capturedExecute }) => { + capturedExecute = options.execute; + return new ReadableStream(); + }, + ); + mockCreateUIMessageStreamResponse.mockReturnValue(new Response(new ReadableStream())); + + // Make handleChatCompletion throw an error + mockHandleChatCompletion.mockRejectedValue(new Error("Completion handling failed")); + + const request = createMockRequest({ prompt: "Hello" }, { "x-api-key": "valid-key" }); + + // Should not throw + const result = await handleChatStream(request as any); + expect(result).toBeInstanceOf(Response); + + // Execute callback should not throw either + if (capturedExecute) { + await expect(capturedExecute({ writer: { merge: vi.fn() } })).resolves.toBeUndefined(); + } + }); + }); }); diff --git a/lib/chat/handleChatGenerate.ts b/lib/chat/handleChatGenerate.ts index d708bcf..6d5f119 100644 --- a/lib/chat/handleChatGenerate.ts +++ b/lib/chat/handleChatGenerate.ts @@ -1,8 +1,10 @@ import { NextRequest, NextResponse } from "next/server"; -import { generateText } from "ai"; +import { generateText, type UIMessage } from "ai"; import { validateChatRequest } from "./validateChatRequest"; import { setupChatRequest } from "./setupChatRequest"; +import { handleChatCompletion } from "./handleChatCompletion"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; +import generateUUID from "@/lib/uuid/generateUUID"; /** * Handles a non-streaming chat generate request. @@ -28,8 +30,18 @@ export async function handleChatGenerate(request: NextRequest): Promise { + // Silently catch - handleChatCompletion handles its own error reporting + }); return NextResponse.json( { diff --git a/lib/chat/handleChatStream.ts b/lib/chat/handleChatStream.ts index 396a66e..b3c9611 100644 --- a/lib/chat/handleChatStream.ts +++ b/lib/chat/handleChatStream.ts @@ -1,7 +1,8 @@ import { NextRequest, NextResponse } from "next/server"; -import { createUIMessageStream, createUIMessageStreamResponse } from "ai"; +import { createUIMessageStream, createUIMessageStreamResponse, type UIMessage } from "ai"; import { validateChatRequest } from "./validateChatRequest"; import { setupChatRequest } from "./setupChatRequest"; +import { handleChatCompletion } from "./handleChatCompletion"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; import generateUUID from "@/lib/uuid/generateUUID"; @@ -30,14 +31,26 @@ export async function handleChatStream(request: NextRequest): Promise const stream = createUIMessageStream({ originalMessages: body.messages, generateId: generateUUID, - execute: async (options) => { + execute: async options => { const { writer } = options; const result = await agent.stream(chatConfig); writer.merge(result.toUIMessageStream()); - // Note: Credit handling and chat completion handling will be added - // as part of the handleChatCredits and handleChatCompletion migrations + + // Construct UIMessage from streaming result for handleChatCompletion + const text = await result.text; + const assistantMessage: UIMessage = { + id: generateUUID(), + role: "assistant", + parts: [{ type: "text", text }], + }; + + // Handle post-completion tasks (room creation, memory storage, notifications) + // Errors are handled gracefully within handleChatCompletion + handleChatCompletion(body, [assistantMessage]).catch(() => { + // Silently catch - handleChatCompletion handles its own error reporting + }); }, - onError: (e) => { + onError: e => { console.error("/api/chat onError:", e); return JSON.stringify({ status: "error",