diff --git a/lib/chat/__tests__/handleChatStream.test.ts b/lib/chat/__tests__/handleChatStream.test.ts index 5943feed..b29127dc 100644 --- a/lib/chat/__tests__/handleChatStream.test.ts +++ b/lib/chat/__tests__/handleChatStream.test.ts @@ -162,6 +162,11 @@ describe("handleChatStream", () => { expect(mockCreateUIMessageStream).toHaveBeenCalled(); expect(mockCreateUIMessageStreamResponse).toHaveBeenCalledWith({ stream: mockStream, + headers: { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", + "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Requested-With, x-api-key", + }, }); expect(result).toBe(mockResponse); }); diff --git a/lib/chat/__tests__/setupToolsForRequest.test.ts b/lib/chat/__tests__/setupToolsForRequest.test.ts index d4107b30..c2d36bb8 100644 --- a/lib/chat/__tests__/setupToolsForRequest.test.ts +++ b/lib/chat/__tests__/setupToolsForRequest.test.ts @@ -2,28 +2,8 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; import { ChatRequestBody } from "../validateChatRequest"; // Mock external dependencies -vi.mock("@ai-sdk/mcp", () => ({ - experimental_createMCPClient: vi.fn(), -})); - -vi.mock("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ - StreamableHTTPClientTransport: vi.fn().mockImplementation(() => ({})), -})); - -vi.mock("@modelcontextprotocol/sdk/server/mcp.js", () => ({ - McpServer: vi.fn().mockImplementation(() => ({ - connect: vi.fn(), - })), -})); - -vi.mock("@modelcontextprotocol/sdk/inMemory.js", () => ({ - InMemoryTransport: { - createLinkedPair: vi.fn().mockReturnValue([{}, {}]), - }, -})); - -vi.mock("@/lib/mcp/tools", () => ({ - registerAllTools: vi.fn(), +vi.mock("@/lib/mcp/getMcpTools", () => ({ + getMcpTools: vi.fn(), })); vi.mock("@/lib/agents/googleSheetsAgent", () => ({ @@ -32,10 +12,10 @@ vi.mock("@/lib/agents/googleSheetsAgent", () => ({ // Import after mocks import { setupToolsForRequest } from "../setupToolsForRequest"; -import { experimental_createMCPClient } from "@ai-sdk/mcp"; +import { getMcpTools } from "@/lib/mcp/getMcpTools"; import { getGoogleSheetsTools } from "@/lib/agents/googleSheetsAgent"; -const mockCreateMCPClient = vi.mocked(experimental_createMCPClient); +const mockGetMcpTools = vi.mocked(getMcpTools); const mockGetGoogleSheetsTools = vi.mocked(getGoogleSheetsTools); describe("setupToolsForRequest", () => { @@ -56,32 +36,32 @@ describe("setupToolsForRequest", () => { beforeEach(() => { vi.clearAllMocks(); - // Default mock for MCP client - mockCreateMCPClient.mockResolvedValue({ - tools: vi.fn().mockResolvedValue(mockMcpTools), - } as any); + // Default mock for MCP tools + mockGetMcpTools.mockResolvedValue(mockMcpTools); // Default mock for Google Sheets tools - returns login tool (not authenticated) mockGetGoogleSheetsTools.mockResolvedValue(mockGoogleSheetsLoginTool); }); describe("MCP tools integration", () => { - it("creates MCP client with correct URL", async () => { + it("calls getMcpTools with authToken", async () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], }; await setupToolsForRequest(body); - expect(mockCreateMCPClient).toHaveBeenCalled(); + expect(mockGetMcpTools).toHaveBeenCalledWith("test-token-123"); }); it("fetches tools from MCP client", async () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], }; @@ -91,7 +71,7 @@ describe("setupToolsForRequest", () => { expect(result).toHaveProperty("tool2"); }); - it("passes accountId to MCP client via authenticated transport", async () => { + it("skips MCP tools when authToken is not provided", async () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, @@ -100,25 +80,7 @@ describe("setupToolsForRequest", () => { await setupToolsForRequest(body); - // Verify MCP client was created with a transport that includes auth info - expect(mockCreateMCPClient).toHaveBeenCalledWith( - expect.objectContaining({ - transport: expect.any(Object), - }), - ); - }); - - it("passes orgId to MCP client via authenticated transport", async () => { - const body: ChatRequestBody = { - accountId: "account-123", - orgId: "org-456", - messages: [{ id: "1", role: "user", content: "Hello" }], - }; - - await setupToolsForRequest(body); - - // Verify MCP client was created - expect(mockCreateMCPClient).toHaveBeenCalled(); + expect(mockGetMcpTools).not.toHaveBeenCalled(); }); }); @@ -127,6 +89,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Create a spreadsheet" }], }; @@ -141,6 +104,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Create a spreadsheet" }], }; @@ -156,6 +120,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Create a spreadsheet" }], }; @@ -172,6 +137,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], }; @@ -185,11 +151,9 @@ describe("setupToolsForRequest", () => { }); it("Google Sheets tools take precedence over MCP tools with same name", async () => { - mockCreateMCPClient.mockResolvedValue({ - tools: vi.fn().mockResolvedValue({ - googlesheets_create: { description: "MCP version", parameters: {} }, - }), - } as any); + mockGetMcpTools.mockResolvedValue({ + googlesheets_create: { description: "MCP version", parameters: {} }, + }); mockGetGoogleSheetsTools.mockResolvedValue({ googlesheets_create: { description: "Composio version", parameters: {} }, @@ -198,6 +162,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], }; @@ -215,6 +180,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], excludeTools: ["tool1"], }; @@ -231,6 +197,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], excludeTools: ["tool1", "googlesheets_create"], }; @@ -247,6 +214,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], }; @@ -260,6 +228,7 @@ describe("setupToolsForRequest", () => { const body: ChatRequestBody = { accountId: "account-123", orgId: null, + authToken: "test-token-123", messages: [{ id: "1", role: "user", content: "Hello" }], excludeTools: [], }; diff --git a/lib/chat/handleChatStream.ts b/lib/chat/handleChatStream.ts index 56cfb963..fe971374 100644 --- a/lib/chat/handleChatStream.ts +++ b/lib/chat/handleChatStream.ts @@ -58,7 +58,7 @@ export async function handleChatStream(request: NextRequest): Promise }, }); - return createUIMessageStreamResponse({ stream }); + return createUIMessageStreamResponse({ stream, headers: getCorsHeaders() }); } catch (e) { console.error("/api/chat Global error:", e); return NextResponse.json( diff --git a/lib/chat/setupToolsForRequest.ts b/lib/chat/setupToolsForRequest.ts index a5fac3d0..15746300 100644 --- a/lib/chat/setupToolsForRequest.ts +++ b/lib/chat/setupToolsForRequest.ts @@ -1,36 +1,23 @@ import { ToolSet } from "ai"; import { filterExcludedTools } from "./filterExcludedTools"; import { ChatRequestBody } from "./validateChatRequest"; -import { experimental_createMCPClient as createMCPClient } from "@ai-sdk/mcp"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js"; -import { registerAllTools } from "@/lib/mcp/tools"; import { getGoogleSheetsTools } from "@/lib/agents/googleSheetsAgent"; +import { getMcpTools } from "@/lib/mcp/getMcpTools"; /** * Sets up and filters tools for a chat request. * Aggregates tools from: - * - MCP server (in-process via in-memory transport, no HTTP overhead) + * - MCP server (via HTTP transport to /api/mcp for proper auth) * - Google Sheets (via Composio integration) * * @param body - The chat request body * @returns Filtered tool set ready for use */ export async function setupToolsForRequest(body: ChatRequestBody): Promise { - const { excludeTools } = body; + const { excludeTools, authToken } = body; - // Create in-memory MCP server and client (no HTTP call needed) - const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); - - const server = new McpServer({ - name: "recoup-mcp", - version: "0.0.1", - }); - registerAllTools(server); - await server.connect(serverTransport); - - const mcpClient = await createMCPClient({ transport: clientTransport }); - const mcpClientTools = (await mcpClient.tools()) as ToolSet; + // Only fetch MCP tools if we have an auth token + const mcpClientTools = authToken ? await getMcpTools(authToken) : {}; // Fetch Google Sheets tools (authenticated tools or login tool) const googleSheetsTools = await getGoogleSheetsTools(body); diff --git a/lib/chat/validateChatRequest.ts b/lib/chat/validateChatRequest.ts index c1a7ec6b..114656c3 100644 --- a/lib/chat/validateChatRequest.ts +++ b/lib/chat/validateChatRequest.ts @@ -48,6 +48,7 @@ type BaseChatRequestBody = z.infer; export type ChatRequestBody = BaseChatRequestBody & { accountId: string; orgId: string | null; + authToken?: string; }; /** @@ -192,10 +193,14 @@ export async function validateChatRequest( memoryId: lastMessage.id, }); + // Extract the auth token to forward to MCP server + const authToken = hasApiKey ? apiKey! : authHeader!.replace(/^Bearer\s+/i, ""); + return { ...validatedBody, accountId, orgId, roomId: finalRoomId, + authToken, } as ChatRequestBody; } diff --git a/lib/mcp/__tests__/getMcpTools.test.ts b/lib/mcp/__tests__/getMcpTools.test.ts new file mode 100644 index 00000000..0422fbc0 --- /dev/null +++ b/lib/mcp/__tests__/getMcpTools.test.ts @@ -0,0 +1,63 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +vi.mock("@ai-sdk/mcp", () => ({ + experimental_createMCPClient: vi.fn(), +})); + +vi.mock("@/lib/networking/getBaseUrl", () => ({ + getBaseUrl: vi.fn().mockReturnValue("https://test.vercel.app"), +})); + +import { getMcpTools } from "../getMcpTools"; +import { experimental_createMCPClient } from "@ai-sdk/mcp"; + +const mockCreateMCPClient = vi.mocked(experimental_createMCPClient); + +describe("getMcpTools", () => { + const mockTools = { + tool1: { description: "Tool 1", parameters: {} }, + tool2: { description: "Tool 2", parameters: {} }, + }; + + beforeEach(() => { + vi.clearAllMocks(); + + mockCreateMCPClient.mockResolvedValue({ + tools: vi.fn().mockResolvedValue(mockTools), + } as any); + }); + + it("creates MCP client with HTTP transport config", async () => { + await getMcpTools("test-token"); + + expect(mockCreateMCPClient).toHaveBeenCalledWith({ + transport: { + type: "http", + url: "https://test.vercel.app/mcp", + headers: { + Authorization: "Bearer test-token", + }, + }, + }); + }); + + it("returns tools from MCP client", async () => { + const result = await getMcpTools("test-token"); + + expect(result).toEqual(mockTools); + }); + + it("passes different auth tokens correctly", async () => { + await getMcpTools("different-token"); + + expect(mockCreateMCPClient).toHaveBeenCalledWith({ + transport: { + type: "http", + url: "https://test.vercel.app/mcp", + headers: { + Authorization: "Bearer different-token", + }, + }, + }); + }); +}); diff --git a/lib/mcp/getMcpTools.ts b/lib/mcp/getMcpTools.ts new file mode 100644 index 00000000..2e1fae6a --- /dev/null +++ b/lib/mcp/getMcpTools.ts @@ -0,0 +1,23 @@ +import { ToolSet } from "ai"; +import { experimental_createMCPClient as createMCPClient } from "@ai-sdk/mcp"; +import { getBaseUrl } from "@/lib/networking/getBaseUrl"; + +/** + * Fetches MCP tools via HTTP transport with authentication. + * + * @param authToken - The auth token to use for MCP endpoint authentication + * @returns The MCP tools as a ToolSet + */ +export async function getMcpTools(authToken: string): Promise { + const mcpClient = await createMCPClient({ + transport: { + type: "http", + url: `${getBaseUrl()}/mcp`, + headers: { + Authorization: `Bearer ${authToken}`, + }, + }, + }); + + return (await mcpClient.tools()) as ToolSet; +} diff --git a/lib/networking/__tests__/getBaseUrl.test.ts b/lib/networking/__tests__/getBaseUrl.test.ts new file mode 100644 index 00000000..fa678ea1 --- /dev/null +++ b/lib/networking/__tests__/getBaseUrl.test.ts @@ -0,0 +1,39 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { getBaseUrl } from "../getBaseUrl"; + +describe("getBaseUrl", () => { + const originalEnv = process.env; + + beforeEach(() => { + vi.resetModules(); + process.env = { ...originalEnv }; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + it("returns HTTPS URL when VERCEL_URL is set", () => { + process.env.VERCEL_URL = "my-app.vercel.app"; + + const result = getBaseUrl(); + + expect(result).toBe("https://my-app.vercel.app"); + }); + + it("returns localhost when VERCEL_URL is not set", () => { + delete process.env.VERCEL_URL; + + const result = getBaseUrl(); + + expect(result).toBe("http://localhost:3000"); + }); + + it("returns localhost when VERCEL_URL is empty string", () => { + process.env.VERCEL_URL = ""; + + const result = getBaseUrl(); + + expect(result).toBe("http://localhost:3000"); + }); +}); diff --git a/lib/networking/getBaseUrl.ts b/lib/networking/getBaseUrl.ts new file mode 100644 index 00000000..112d94e5 --- /dev/null +++ b/lib/networking/getBaseUrl.ts @@ -0,0 +1,12 @@ +/** + * Gets the base URL for the current API server. + * Uses VERCEL_URL in Vercel deployments, falls back to localhost. + * + * @returns The base URL string + */ +export function getBaseUrl(): string { + if (process.env.VERCEL_URL) { + return `https://${process.env.VERCEL_URL}`; + } + return "http://localhost:3000"; +} diff --git a/lib/networking/getCorsHeaders.ts b/lib/networking/getCorsHeaders.ts index 0a6c3a02..233b32df 100644 --- a/lib/networking/getCorsHeaders.ts +++ b/lib/networking/getCorsHeaders.ts @@ -7,6 +7,6 @@ export function getCorsHeaders(): Record { return { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH", - "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Requested-With", + "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Requested-With, x-api-key", }; }