diff --git a/CLAUDE.md b/CLAUDE.md index 1649398e..8aef7bf2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -266,6 +266,8 @@ const { accountId, orgId, authToken } = authResult; ### MCP Tools +**CRITICAL: Never manually extract `accountId` from `extra.authInfo` (e.g. `authInfo?.extra?.accountId`).** Always use `resolveAccountId()` — it handles validation, org-key overrides, and access control in one place. + ```typescript import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; @@ -275,6 +277,14 @@ const { accountId, error } = await resolveAccountId({ authInfo, accountIdOverride: undefined, }); + +if (error) { + return getToolResultError(error); +} + +if (!accountId) { + return getToolResultError("Failed to resolve account ID"); +} ``` This ensures: diff --git a/lib/mcp/tools/index.ts b/lib/mcp/tools/index.ts index 36b462d7..8259ed44 100644 --- a/lib/mcp/tools/index.ts +++ b/lib/mcp/tools/index.ts @@ -19,6 +19,7 @@ import { registerSendEmailTool } from "./registerSendEmailTool"; import { registerAllArtistTools } from "./artists"; import { registerAllChatsTools } from "./chats"; import { registerAllPulseTools } from "./pulse"; +import { registerAllSandboxTools } from "./sandbox"; /** * Registers all MCP tools on the server. @@ -37,6 +38,7 @@ export const registerAllTools = (server: McpServer): void => { registerAllFileTools(server); registerAllImageTools(server); registerAllPulseTools(server); + registerAllSandboxTools(server); registerAllSearchTools(server); registerAllSora2Tools(server); registerAllSpotifyTools(server); diff --git a/lib/mcp/tools/sandbox/__tests__/registerRunSandboxCommandTool.test.ts b/lib/mcp/tools/sandbox/__tests__/registerRunSandboxCommandTool.test.ts new file mode 100644 index 00000000..5cb94fb9 --- /dev/null +++ b/lib/mcp/tools/sandbox/__tests__/registerRunSandboxCommandTool.test.ts @@ -0,0 +1,199 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; + +import { registerRunSandboxCommandTool } from "../registerRunSandboxCommandTool"; + +const mockProcessCreateSandbox = vi.fn(); +const mockResolveAccountId = vi.fn(); + +vi.mock("@/lib/sandbox/processCreateSandbox", () => ({ + processCreateSandbox: (...args: unknown[]) => mockProcessCreateSandbox(...args), +})); + +vi.mock("@/lib/mcp/resolveAccountId", () => ({ + resolveAccountId: (...args: unknown[]) => mockResolveAccountId(...args), +})); + +type ServerRequestHandlerExtra = RequestHandlerExtra; + +/** + * Creates a mock extra object with optional authInfo. + * + * @param authInfo + * @param authInfo.accountId + * @param authInfo.orgId + */ +function createMockExtra(authInfo?: { + accountId?: string; + orgId?: string | null; +}): ServerRequestHandlerExtra { + return { + authInfo: authInfo + ? { + token: "test-token", + scopes: ["mcp:tools"], + clientId: authInfo.accountId, + extra: { + accountId: authInfo.accountId, + orgId: authInfo.orgId ?? null, + }, + } + : undefined, + } as unknown as ServerRequestHandlerExtra; +} + +describe("registerRunSandboxCommandTool", () => { + let mockServer: McpServer; + let registeredHandler: (args: unknown, extra: ServerRequestHandlerExtra) => Promise; + + beforeEach(() => { + vi.clearAllMocks(); + + mockServer = { + registerTool: vi.fn((name, config, handler) => { + registeredHandler = handler; + }), + } as unknown as McpServer; + + registerRunSandboxCommandTool(mockServer); + }); + + it("registers the run_sandbox_command tool", () => { + expect(mockServer.registerTool).toHaveBeenCalledWith( + "run_sandbox_command", + expect.objectContaining({ + description: expect.any(String), + }), + expect.any(Function), + ); + }); + + it("returns error when resolveAccountId returns an error", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: null, + error: "Authentication required. Provide an API key via Authorization: Bearer header, or provide account_id from the system prompt context.", + }); + + const result = await registeredHandler( + { command: "ls" }, + createMockExtra(), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Authentication required"), + }, + ], + }); + }); + + it("returns error when resolveAccountId returns null accountId without error", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: null, + error: null, + }); + + const result = await registeredHandler( + { command: "ls" }, + createMockExtra(), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Failed to resolve account ID"), + }, + ], + }); + }); + + it("calls processCreateSandbox with command and returns success", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: "acc_123", + error: null, + }); + mockProcessCreateSandbox.mockResolvedValue({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + runId: "run_abc123", + }); + + const result = await registeredHandler( + { command: "npm install", args: ["express"], cwd: "/app" }, + createMockExtra({ accountId: "acc_123" }), + ); + + expect(mockProcessCreateSandbox).toHaveBeenCalledWith({ + accountId: "acc_123", + command: "npm install", + args: ["express"], + cwd: "/app", + }); + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining('"sandboxId":"sbx_123"'), + }, + ], + }); + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining('"runId":"run_abc123"'), + }, + ], + }); + }); + + it("passes authInfo to resolveAccountId", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: "acc_123", + error: null, + }); + mockProcessCreateSandbox.mockResolvedValue({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + + const extra = createMockExtra({ accountId: "acc_123" }); + await registeredHandler({ command: "ls" }, extra); + + expect(mockResolveAccountId).toHaveBeenCalledWith({ + authInfo: extra.authInfo, + accountIdOverride: undefined, + }); + }); + + it("returns error when processCreateSandbox throws", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: "acc_123", + error: null, + }); + mockProcessCreateSandbox.mockRejectedValue(new Error("Sandbox creation failed")); + + const result = await registeredHandler( + { command: "ls" }, + createMockExtra({ accountId: "acc_123" }), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Sandbox creation failed"), + }, + ], + }); + }); +}); diff --git a/lib/mcp/tools/sandbox/index.ts b/lib/mcp/tools/sandbox/index.ts new file mode 100644 index 00000000..260489ba --- /dev/null +++ b/lib/mcp/tools/sandbox/index.ts @@ -0,0 +1,11 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { registerRunSandboxCommandTool } from "./registerRunSandboxCommandTool"; + +/** + * Registers all sandbox-related MCP tools on the server. + * + * @param server - The MCP server instance to register tools on. + */ +export const registerAllSandboxTools = (server: McpServer): void => { + registerRunSandboxCommandTool(server); +}; diff --git a/lib/mcp/tools/sandbox/registerRunSandboxCommandTool.ts b/lib/mcp/tools/sandbox/registerRunSandboxCommandTool.ts new file mode 100644 index 00000000..f84d4853 --- /dev/null +++ b/lib/mcp/tools/sandbox/registerRunSandboxCommandTool.ts @@ -0,0 +1,61 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; +import { z } from "zod"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; +import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; +import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; +import { getToolResultError } from "@/lib/mcp/getToolResultError"; +import { processCreateSandbox } from "@/lib/sandbox/processCreateSandbox"; + +const runSandboxCommandSchema = z.object({ + command: z.string().describe("The command to run in the sandbox."), + args: z.array(z.string()).optional().describe("Arguments for the command."), + cwd: z.string().optional().describe("Working directory for the command."), +}); + +/** + * Registers the "run_sandbox_command" tool on the MCP server. + * Creates a sandbox and runs a command in it. + * + * @param server - The MCP server instance to register the tool on. + */ +export function registerRunSandboxCommandTool(server: McpServer): void { + server.registerTool( + "run_sandbox_command", + { + description: + "Create a sandbox and run a command in it. Returns the sandbox ID and a run ID to track progress.", + inputSchema: runSandboxCommandSchema, + }, + async (args, extra: RequestHandlerExtra) => { + const authInfo = extra.authInfo as McpAuthInfo | undefined; + const { accountId, error } = await resolveAccountId({ + authInfo, + accountIdOverride: undefined, + }); + + if (error) { + return getToolResultError(error); + } + + if (!accountId) { + return getToolResultError("Failed to resolve account ID"); + } + + try { + const result = await processCreateSandbox({ + accountId, + command: args.command, + args: args.args, + cwd: args.cwd, + }); + + return getToolResultSuccess(result); + } catch (error) { + const message = error instanceof Error ? error.message : "Failed to create sandbox"; + return getToolResultError(message); + } + }, + ); +} diff --git a/lib/mcp/tools/tasks/__tests__/registerGetTaskRunStatusTool.test.ts b/lib/mcp/tools/tasks/__tests__/registerGetTaskRunStatusTool.test.ts new file mode 100644 index 00000000..8557432f --- /dev/null +++ b/lib/mcp/tools/tasks/__tests__/registerGetTaskRunStatusTool.test.ts @@ -0,0 +1,215 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; + +import { registerGetTaskRunStatusTool } from "../registerGetTaskRunStatusTool"; + +const mockRetrieveTaskRun = vi.fn(); +const mockResolveAccountId = vi.fn(); + +vi.mock("@/lib/trigger/retrieveTaskRun", () => ({ + retrieveTaskRun: (...args: unknown[]) => mockRetrieveTaskRun(...args), +})); + +vi.mock("@/lib/mcp/resolveAccountId", () => ({ + resolveAccountId: (...args: unknown[]) => mockResolveAccountId(...args), +})); + +type ServerRequestHandlerExtra = RequestHandlerExtra; + +/** + * Creates a mock extra object with optional authInfo. + * + * @param authInfo + * @param authInfo.accountId + * @param authInfo.orgId + */ +function createMockExtra(authInfo?: { + accountId?: string; + orgId?: string | null; +}): ServerRequestHandlerExtra { + return { + authInfo: authInfo + ? { + token: "test-token", + scopes: ["mcp:tools"], + clientId: authInfo.accountId, + extra: { + accountId: authInfo.accountId, + orgId: authInfo.orgId ?? null, + }, + } + : undefined, + } as unknown as ServerRequestHandlerExtra; +} + +describe("registerGetTaskRunStatusTool", () => { + let mockServer: McpServer; + let registeredHandler: (args: unknown, extra: ServerRequestHandlerExtra) => Promise; + + beforeEach(() => { + vi.clearAllMocks(); + + mockServer = { + registerTool: vi.fn((name, config, handler) => { + registeredHandler = handler; + }), + } as unknown as McpServer; + + registerGetTaskRunStatusTool(mockServer); + }); + + it("registers the get_task_run_status tool", () => { + expect(mockServer.registerTool).toHaveBeenCalledWith( + "get_task_run_status", + expect.objectContaining({ + description: expect.any(String), + }), + expect.any(Function), + ); + }); + + it("returns error when resolveAccountId returns an error", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: null, + error: "Authentication required. Provide an API key via Authorization: Bearer header, or provide account_id from the system prompt context.", + }); + + const result = await registeredHandler( + { runId: "run_123" }, + createMockExtra(), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Authentication required"), + }, + ], + }); + }); + + it("returns error when resolveAccountId returns null accountId without error", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: null, + error: null, + }); + + const result = await registeredHandler( + { runId: "run_123" }, + createMockExtra(), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Failed to resolve account ID"), + }, + ], + }); + }); + + it("passes authInfo to resolveAccountId", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: "acc_123", + error: null, + }); + mockRetrieveTaskRun.mockResolvedValue({ + status: "complete", + data: null, + metadata: null, + taskIdentifier: "test-task", + createdAt: "2024-01-01T00:00:00.000Z", + startedAt: null, + finishedAt: null, + durationMs: null, + }); + + const extra = createMockExtra({ accountId: "acc_123" }); + await registeredHandler({ runId: "run_123" }, extra); + + expect(mockResolveAccountId).toHaveBeenCalledWith({ + authInfo: extra.authInfo, + accountIdOverride: undefined, + }); + }); + + it("returns task run status on success", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: "acc_123", + error: null, + }); + mockRetrieveTaskRun.mockResolvedValue({ + status: "complete", + data: { output: "done" }, + metadata: { logs: ["step 1", "step 2"] }, + taskIdentifier: "run-sandbox-command", + createdAt: "2024-01-01T00:00:00.000Z", + startedAt: "2024-01-01T00:00:01.000Z", + finishedAt: "2024-01-01T00:00:05.000Z", + durationMs: 4000, + }); + + const result = await registeredHandler( + { runId: "run_123" }, + createMockExtra({ accountId: "acc_123" }), + ); + + expect(mockRetrieveTaskRun).toHaveBeenCalledWith("run_123"); + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining('"status":"complete"'), + }, + ], + }); + }); + + it("returns error when task run is not found", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: "acc_123", + error: null, + }); + mockRetrieveTaskRun.mockResolvedValue(null); + + const result = await registeredHandler( + { runId: "run_nonexistent" }, + createMockExtra({ accountId: "acc_123" }), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("not found"), + }, + ], + }); + }); + + it("returns error when retrieveTaskRun throws", async () => { + mockResolveAccountId.mockResolvedValue({ + accountId: "acc_123", + error: null, + }); + mockRetrieveTaskRun.mockRejectedValue(new Error("Trigger API error")); + + const result = await registeredHandler( + { runId: "run_123" }, + createMockExtra({ accountId: "acc_123" }), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Trigger API error"), + }, + ], + }); + }); +}); diff --git a/lib/mcp/tools/tasks/index.ts b/lib/mcp/tools/tasks/index.ts index 32052eef..fa1d1870 100644 --- a/lib/mcp/tools/tasks/index.ts +++ b/lib/mcp/tools/tasks/index.ts @@ -3,6 +3,7 @@ import { registerCreateTaskTool } from "./registerCreateTaskTool"; import { registerGetTasksTool } from "./registerGetTasksTool"; import { registerUpdateTaskTool } from "./registerUpdateTaskTool"; import { registerDeleteTaskTool } from "./registerDeleteTaskTool"; +import { registerGetTaskRunStatusTool } from "./registerGetTaskRunStatusTool"; /** * Registers all task-related MCP tools on the server. @@ -14,4 +15,5 @@ export const registerAllTaskTools = (server: McpServer): void => { registerGetTasksTool(server); registerUpdateTaskTool(server); registerDeleteTaskTool(server); + registerGetTaskRunStatusTool(server); }; diff --git a/lib/mcp/tools/tasks/registerGetTaskRunStatusTool.ts b/lib/mcp/tools/tasks/registerGetTaskRunStatusTool.ts new file mode 100644 index 00000000..83014c21 --- /dev/null +++ b/lib/mcp/tools/tasks/registerGetTaskRunStatusTool.ts @@ -0,0 +1,58 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; +import { z } from "zod"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; +import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; +import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; +import { getToolResultError } from "@/lib/mcp/getToolResultError"; +import { retrieveTaskRun } from "@/lib/trigger/retrieveTaskRun"; + +const getTaskRunStatusSchema = z.object({ + runId: z.string().describe("The unique identifier of the task run to check."), +}); + +/** + * Registers the "get_task_run_status" tool on the MCP server. + * Retrieves the status of a Trigger.dev task run. + * + * @param server - The MCP server instance to register the tool on. + */ +export function registerGetTaskRunStatusTool(server: McpServer): void { + server.registerTool( + "get_task_run_status", + { + description: + "Get the status of a task run by its run ID. Returns status, metadata, logs, and timestamps.", + inputSchema: getTaskRunStatusSchema, + }, + async (args, extra: RequestHandlerExtra) => { + const authInfo = extra.authInfo as McpAuthInfo | undefined; + const { accountId, error } = await resolveAccountId({ + authInfo, + accountIdOverride: undefined, + }); + + if (error) { + return getToolResultError(error); + } + + if (!accountId) { + return getToolResultError("Failed to resolve account ID"); + } + + try { + const result = await retrieveTaskRun(args.runId); + + if (!result) { + return getToolResultError(`Task run with ID "${args.runId}" not found.`); + } + + return getToolResultSuccess(result); + } catch (error) { + const message = error instanceof Error ? error.message : "Failed to retrieve task run"; + return getToolResultError(message); + } + }, + ); +} diff --git a/lib/sandbox/__tests__/processCreateSandbox.test.ts b/lib/sandbox/__tests__/processCreateSandbox.test.ts new file mode 100644 index 00000000..296e29cc --- /dev/null +++ b/lib/sandbox/__tests__/processCreateSandbox.test.ts @@ -0,0 +1,224 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +import { processCreateSandbox } from "../processCreateSandbox"; +import { createSandbox } from "@/lib/sandbox/createSandbox"; +import { insertAccountSandbox } from "@/lib/supabase/account_sandboxes/insertAccountSandbox"; +import { triggerRunSandboxCommand } from "@/lib/trigger/triggerRunSandboxCommand"; +import { selectAccountSnapshots } from "@/lib/supabase/account_snapshots/selectAccountSnapshots"; + +vi.mock("@/lib/sandbox/createSandbox", () => ({ + createSandbox: vi.fn(), +})); + +vi.mock("@/lib/supabase/account_sandboxes/insertAccountSandbox", () => ({ + insertAccountSandbox: vi.fn(), +})); + +vi.mock("@/lib/trigger/triggerRunSandboxCommand", () => ({ + triggerRunSandboxCommand: vi.fn(), +})); + +vi.mock("@/lib/supabase/account_snapshots/selectAccountSnapshots", () => ({ + selectAccountSnapshots: vi.fn(), +})); + +describe("processCreateSandbox", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("creates sandbox without command and returns result without runId", async () => { + vi.mocked(selectAccountSnapshots).mockResolvedValue([]); + vi.mocked(createSandbox).mockResolvedValue({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + vi.mocked(insertAccountSandbox).mockResolvedValue({ + data: { + id: "record_123", + account_id: "acc_123", + sandbox_id: "sbx_123", + created_at: "2024-01-01T00:00:00.000Z", + }, + error: null, + }); + + const result = await processCreateSandbox({ accountId: "acc_123" }); + + expect(result).toEqual({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + expect(triggerRunSandboxCommand).not.toHaveBeenCalled(); + }); + + it("creates sandbox with command and returns result with runId", async () => { + vi.mocked(selectAccountSnapshots).mockResolvedValue([]); + vi.mocked(createSandbox).mockResolvedValue({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + vi.mocked(insertAccountSandbox).mockResolvedValue({ + data: { + id: "record_123", + account_id: "acc_123", + sandbox_id: "sbx_123", + created_at: "2024-01-01T00:00:00.000Z", + }, + error: null, + }); + vi.mocked(triggerRunSandboxCommand).mockResolvedValue({ + id: "run_abc123", + }); + + const result = await processCreateSandbox({ + accountId: "acc_123", + command: "ls", + args: ["-la"], + cwd: "/home", + }); + + expect(result).toEqual({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + runId: "run_abc123", + }); + expect(triggerRunSandboxCommand).toHaveBeenCalledWith({ + command: "ls", + args: ["-la"], + cwd: "/home", + sandboxId: "sbx_123", + accountId: "acc_123", + }); + }); + + it("uses snapshot when account has one", async () => { + vi.mocked(selectAccountSnapshots).mockResolvedValue([ + { + id: "snap_record_123", + account_id: "acc_123", + snapshot_id: "snap_xyz", + created_at: "2024-01-01T00:00:00.000Z", + }, + ]); + vi.mocked(createSandbox).mockResolvedValue({ + sandboxId: "sbx_456", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + vi.mocked(insertAccountSandbox).mockResolvedValue({ + data: { + id: "record_123", + account_id: "acc_123", + sandbox_id: "sbx_456", + created_at: "2024-01-01T00:00:00.000Z", + }, + error: null, + }); + + await processCreateSandbox({ accountId: "acc_123" }); + + expect(createSandbox).toHaveBeenCalledWith({ + source: { type: "snapshot", snapshotId: "snap_xyz" }, + }); + }); + + it("calls createSandbox with empty params when no snapshot", async () => { + vi.mocked(selectAccountSnapshots).mockResolvedValue([]); + vi.mocked(createSandbox).mockResolvedValue({ + sandboxId: "sbx_456", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + vi.mocked(insertAccountSandbox).mockResolvedValue({ + data: { + id: "record_123", + account_id: "acc_123", + sandbox_id: "sbx_456", + created_at: "2024-01-01T00:00:00.000Z", + }, + error: null, + }); + + await processCreateSandbox({ accountId: "acc_123" }); + + expect(createSandbox).toHaveBeenCalledWith({}); + }); + + it("inserts account_sandbox record", async () => { + vi.mocked(selectAccountSnapshots).mockResolvedValue([]); + vi.mocked(createSandbox).mockResolvedValue({ + sandboxId: "sbx_789", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + vi.mocked(insertAccountSandbox).mockResolvedValue({ + data: { + id: "record_123", + account_id: "acc_123", + sandbox_id: "sbx_789", + created_at: "2024-01-01T00:00:00.000Z", + }, + error: null, + }); + + await processCreateSandbox({ accountId: "acc_123" }); + + expect(insertAccountSandbox).toHaveBeenCalledWith({ + account_id: "acc_123", + sandbox_id: "sbx_789", + }); + }); + + it("throws when createSandbox fails", async () => { + vi.mocked(selectAccountSnapshots).mockResolvedValue([]); + vi.mocked(createSandbox).mockRejectedValue(new Error("Sandbox creation failed")); + + await expect(processCreateSandbox({ accountId: "acc_123" })).rejects.toThrow( + "Sandbox creation failed", + ); + }); + + it("returns result without runId when triggerRunSandboxCommand fails", async () => { + vi.mocked(selectAccountSnapshots).mockResolvedValue([]); + vi.mocked(createSandbox).mockResolvedValue({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + vi.mocked(insertAccountSandbox).mockResolvedValue({ + data: { + id: "record_123", + account_id: "acc_123", + sandbox_id: "sbx_123", + created_at: "2024-01-01T00:00:00.000Z", + }, + error: null, + }); + vi.mocked(triggerRunSandboxCommand).mockRejectedValue(new Error("Task trigger failed")); + + const result = await processCreateSandbox({ + accountId: "acc_123", + command: "ls", + }); + + expect(result).toEqual({ + sandboxId: "sbx_123", + sandboxStatus: "running", + timeout: 600000, + createdAt: "2024-01-01T00:00:00.000Z", + }); + }); +}); diff --git a/lib/sandbox/createSandboxPostHandler.ts b/lib/sandbox/createSandboxPostHandler.ts index 7375061e..a4f4b414 100644 --- a/lib/sandbox/createSandboxPostHandler.ts +++ b/lib/sandbox/createSandboxPostHandler.ts @@ -1,12 +1,8 @@ import type { NextRequest } from "next/server"; import { NextResponse } from "next/server"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; -import { createSandbox } from "@/lib/sandbox/createSandbox"; import { validateSandboxBody } from "@/lib/sandbox/validateSandboxBody"; -import { insertAccountSandbox } from "@/lib/supabase/account_sandboxes/insertAccountSandbox"; -import { triggerRunSandboxCommand } from "@/lib/trigger/triggerRunSandboxCommand"; - -import { selectAccountSnapshots } from "@/lib/supabase/account_snapshots/selectAccountSnapshots"; +import { processCreateSandbox } from "@/lib/sandbox/processCreateSandbox"; /** * Handler for POST /api/sandboxes. @@ -27,47 +23,12 @@ export async function createSandboxPostHandler(request: NextRequest): Promise; +type ProcessCreateSandboxResult = SandboxCreatedResponse & { runId?: string }; + +/** + * Shared domain logic for creating a sandbox and optionally running a command. + * Used by both POST /api/sandboxes handler and the run_sandbox_command MCP tool. + * + * @param input - The sandbox creation parameters + * @returns The sandbox creation result with optional runId + */ +export async function processCreateSandbox( + input: ProcessCreateSandboxInput, +): Promise { + const { accountId, command, args, cwd } = input; + + // Get account's most recent snapshot if available + const accountSnapshots = await selectAccountSnapshots(accountId); + const snapshotId = accountSnapshots[0]?.snapshot_id; + + // Create sandbox (from snapshot if valid, otherwise fresh) + const result = await createSandbox( + snapshotId ? { source: { type: "snapshot", snapshotId } } : {}, + ); + + await insertAccountSandbox({ + account_id: accountId, + sandbox_id: result.sandboxId, + }); + + // Trigger the command execution task if a command was provided + let runId: string | undefined; + if (command) { + try { + const handle = await triggerRunSandboxCommand({ + command, + args, + cwd, + sandboxId: result.sandboxId, + accountId, + }); + runId = handle.id; + } catch (triggerError) { + console.error("Failed to trigger run-sandbox-command task:", triggerError); + runId = undefined; + } + } + + return { + ...result, + ...(runId && { runId }), + }; +}