Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/chat/__tests__/handleChatStream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ describe("handleChatStream", () => {
expect(mockCreateUIMessageStream).toHaveBeenCalled();
expect(mockCreateUIMessageStreamResponse).toHaveBeenCalledWith({
stream: mockStream,
headers: {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS, PATCH",
"Access-Control-Allow-Headers": "Content-Type, Authorization, X-Requested-With, x-api-key",
},
});
expect(result).toBe(mockResponse);
});
Expand Down
79 changes: 24 additions & 55 deletions lib/chat/__tests__/setupToolsForRequest.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,8 @@ import { describe, it, expect, vi, beforeEach } from "vitest";
import { ChatRequestBody } from "../validateChatRequest";

// Mock external dependencies
vi.mock("@ai-sdk/mcp", () => ({
experimental_createMCPClient: vi.fn(),
}));

vi.mock("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({
StreamableHTTPClientTransport: vi.fn().mockImplementation(() => ({})),
}));

vi.mock("@modelcontextprotocol/sdk/server/mcp.js", () => ({
McpServer: vi.fn().mockImplementation(() => ({
connect: vi.fn(),
})),
}));

vi.mock("@modelcontextprotocol/sdk/inMemory.js", () => ({
InMemoryTransport: {
createLinkedPair: vi.fn().mockReturnValue([{}, {}]),
},
}));

vi.mock("@/lib/mcp/tools", () => ({
registerAllTools: vi.fn(),
vi.mock("@/lib/mcp/getMcpTools", () => ({
getMcpTools: vi.fn(),
}));

vi.mock("@/lib/agents/googleSheetsAgent", () => ({
Expand All @@ -32,10 +12,10 @@ vi.mock("@/lib/agents/googleSheetsAgent", () => ({

// Import after mocks
import { setupToolsForRequest } from "../setupToolsForRequest";
import { experimental_createMCPClient } from "@ai-sdk/mcp";
import { getMcpTools } from "@/lib/mcp/getMcpTools";
import { getGoogleSheetsTools } from "@/lib/agents/googleSheetsAgent";

const mockCreateMCPClient = vi.mocked(experimental_createMCPClient);
const mockGetMcpTools = vi.mocked(getMcpTools);
const mockGetGoogleSheetsTools = vi.mocked(getGoogleSheetsTools);

describe("setupToolsForRequest", () => {
Expand All @@ -56,32 +36,32 @@ describe("setupToolsForRequest", () => {
beforeEach(() => {
vi.clearAllMocks();

// Default mock for MCP client
mockCreateMCPClient.mockResolvedValue({
tools: vi.fn().mockResolvedValue(mockMcpTools),
} as any);
// Default mock for MCP tools
mockGetMcpTools.mockResolvedValue(mockMcpTools);

// Default mock for Google Sheets tools - returns login tool (not authenticated)
mockGetGoogleSheetsTools.mockResolvedValue(mockGoogleSheetsLoginTool);
});

describe("MCP tools integration", () => {
it("creates MCP client with correct URL", async () => {
it("calls getMcpTools with authToken", async () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
};

await setupToolsForRequest(body);

expect(mockCreateMCPClient).toHaveBeenCalled();
expect(mockGetMcpTools).toHaveBeenCalledWith("test-token-123");
});

it("fetches tools from MCP client", async () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
};

Expand All @@ -91,7 +71,7 @@ describe("setupToolsForRequest", () => {
expect(result).toHaveProperty("tool2");
});

it("passes accountId to MCP client via authenticated transport", async () => {
it("skips MCP tools when authToken is not provided", async () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
Expand All @@ -100,25 +80,7 @@ describe("setupToolsForRequest", () => {

await setupToolsForRequest(body);

// Verify MCP client was created with a transport that includes auth info
expect(mockCreateMCPClient).toHaveBeenCalledWith(
expect.objectContaining({
transport: expect.any(Object),
}),
);
});

it("passes orgId to MCP client via authenticated transport", async () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: "org-456",
messages: [{ id: "1", role: "user", content: "Hello" }],
};

await setupToolsForRequest(body);

// Verify MCP client was created
expect(mockCreateMCPClient).toHaveBeenCalled();
expect(mockGetMcpTools).not.toHaveBeenCalled();
});
});

Expand All @@ -127,6 +89,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Create a spreadsheet" }],
};

Expand All @@ -141,6 +104,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Create a spreadsheet" }],
};

Expand All @@ -156,6 +120,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Create a spreadsheet" }],
};

Expand All @@ -172,6 +137,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
};

Expand All @@ -185,11 +151,9 @@ describe("setupToolsForRequest", () => {
});

it("Google Sheets tools take precedence over MCP tools with same name", async () => {
mockCreateMCPClient.mockResolvedValue({
tools: vi.fn().mockResolvedValue({
googlesheets_create: { description: "MCP version", parameters: {} },
}),
} as any);
mockGetMcpTools.mockResolvedValue({
googlesheets_create: { description: "MCP version", parameters: {} },
});

mockGetGoogleSheetsTools.mockResolvedValue({
googlesheets_create: { description: "Composio version", parameters: {} },
Expand All @@ -198,6 +162,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
};

Expand All @@ -215,6 +180,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
excludeTools: ["tool1"],
};
Expand All @@ -231,6 +197,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
excludeTools: ["tool1", "googlesheets_create"],
};
Expand All @@ -247,6 +214,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
};

Expand All @@ -260,6 +228,7 @@ describe("setupToolsForRequest", () => {
const body: ChatRequestBody = {
accountId: "account-123",
orgId: null,
authToken: "test-token-123",
messages: [{ id: "1", role: "user", content: "Hello" }],
excludeTools: [],
};
Expand Down
2 changes: 1 addition & 1 deletion lib/chat/handleChatStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ export async function handleChatStream(request: NextRequest): Promise<Response>
},
});

return createUIMessageStreamResponse({ stream });
return createUIMessageStreamResponse({ stream, headers: getCorsHeaders() });
} catch (e) {
console.error("/api/chat Global error:", e);
return NextResponse.json(
Expand Down
23 changes: 5 additions & 18 deletions lib/chat/setupToolsForRequest.ts
Original file line number Diff line number Diff line change
@@ -1,36 +1,23 @@
import { ToolSet } from "ai";
import { filterExcludedTools } from "./filterExcludedTools";
import { ChatRequestBody } from "./validateChatRequest";
import { experimental_createMCPClient as createMCPClient } from "@ai-sdk/mcp";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js";
import { registerAllTools } from "@/lib/mcp/tools";
import { getGoogleSheetsTools } from "@/lib/agents/googleSheetsAgent";
import { getMcpTools } from "@/lib/mcp/getMcpTools";

/**
* Sets up and filters tools for a chat request.
* Aggregates tools from:
* - MCP server (in-process via in-memory transport, no HTTP overhead)
* - MCP server (via HTTP transport to /api/mcp for proper auth)
* - Google Sheets (via Composio integration)
*
* @param body - The chat request body
* @returns Filtered tool set ready for use
*/
export async function setupToolsForRequest(body: ChatRequestBody): Promise<ToolSet> {
const { excludeTools } = body;
const { excludeTools, authToken } = body;

// Create in-memory MCP server and client (no HTTP call needed)
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();

const server = new McpServer({
name: "recoup-mcp",
version: "0.0.1",
});
registerAllTools(server);
await server.connect(serverTransport);

const mcpClient = await createMCPClient({ transport: clientTransport });
const mcpClientTools = (await mcpClient.tools()) as ToolSet;
// Only fetch MCP tools if we have an auth token
const mcpClientTools = authToken ? await getMcpTools(authToken) : {};

// Fetch Google Sheets tools (authenticated tools or login tool)
const googleSheetsTools = await getGoogleSheetsTools(body);
Expand Down
5 changes: 5 additions & 0 deletions lib/chat/validateChatRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type BaseChatRequestBody = z.infer<typeof chatRequestSchema>;
export type ChatRequestBody = BaseChatRequestBody & {
accountId: string;
orgId: string | null;
authToken?: string;
};

/**
Expand Down Expand Up @@ -192,10 +193,14 @@ export async function validateChatRequest(
memoryId: lastMessage.id,
});

// Extract the auth token to forward to MCP server
const authToken = hasApiKey ? apiKey! : authHeader!.replace(/^Bearer\s+/i, "");

return {
...validatedBody,
accountId,
orgId,
roomId: finalRoomId,
authToken,
} as ChatRequestBody;
}
63 changes: 63 additions & 0 deletions lib/mcp/__tests__/getMcpTools.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { describe, it, expect, vi, beforeEach } from "vitest";

vi.mock("@ai-sdk/mcp", () => ({
experimental_createMCPClient: vi.fn(),
}));

vi.mock("@/lib/networking/getBaseUrl", () => ({
getBaseUrl: vi.fn().mockReturnValue("https://test.vercel.app"),
}));

import { getMcpTools } from "../getMcpTools";
import { experimental_createMCPClient } from "@ai-sdk/mcp";

const mockCreateMCPClient = vi.mocked(experimental_createMCPClient);

describe("getMcpTools", () => {
const mockTools = {
tool1: { description: "Tool 1", parameters: {} },
tool2: { description: "Tool 2", parameters: {} },
};

beforeEach(() => {
vi.clearAllMocks();

mockCreateMCPClient.mockResolvedValue({
tools: vi.fn().mockResolvedValue(mockTools),
} as any);
});

it("creates MCP client with HTTP transport config", async () => {
await getMcpTools("test-token");

expect(mockCreateMCPClient).toHaveBeenCalledWith({
transport: {
type: "http",
url: "https://test.vercel.app/mcp",
headers: {
Authorization: "Bearer test-token",
},
},
});
});

it("returns tools from MCP client", async () => {
const result = await getMcpTools("test-token");

expect(result).toEqual(mockTools);
});

it("passes different auth tokens correctly", async () => {
await getMcpTools("different-token");

expect(mockCreateMCPClient).toHaveBeenCalledWith({
transport: {
type: "http",
url: "https://test.vercel.app/mcp",
headers: {
Authorization: "Bearer different-token",
},
},
});
});
});
23 changes: 23 additions & 0 deletions lib/mcp/getMcpTools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { ToolSet } from "ai";
import { experimental_createMCPClient as createMCPClient } from "@ai-sdk/mcp";
import { getBaseUrl } from "@/lib/networking/getBaseUrl";

/**
* Fetches MCP tools via HTTP transport with authentication.
*
* @param authToken - The auth token to use for MCP endpoint authentication
* @returns The MCP tools as a ToolSet
*/
export async function getMcpTools(authToken: string): Promise<ToolSet> {
const mcpClient = await createMCPClient({
transport: {
type: "http",
url: `${getBaseUrl()}/mcp`,
headers: {
Authorization: `Bearer ${authToken}`,
},
},
});

return (await mcpClient.tools()) as ToolSet;
}
Loading
Loading