diff --git a/src/acp-agent.ts b/src/acp-agent.ts index ef326936..10385823 100644 --- a/src/acp-agent.ts +++ b/src/acp-agent.ts @@ -7,6 +7,8 @@ import { ClientCapabilities, InitializeRequest, InitializeResponse, + LoadSessionRequest, + LoadSessionResponse, ndJsonStream, NewSessionRequest, NewSessionResponse, @@ -172,6 +174,9 @@ export class ClaudeAcpAgent implements Agent { return { protocolVersion: 1, agentCapabilities: { + // TODO: Migrate to session/resume once it's added to the ACP schema. + // See: https://github.com/agentclientprotocol/agent-client-protocol/pull/234 + loadSession: true, promptCapabilities: { image: true, embeddedContext: true, @@ -262,7 +267,9 @@ export class ClaudeAcpAgent implements Agent { cwd: params.cwd, includePartialMessages: true, mcpServers: { ...(userProvidedOptions?.mcpServers || {}), ...mcpServers }, - // Set our own session id + // NOTE: --session-id is not part of the public SDK API. + // We use it to synchronize session IDs between ACP and Claude Code SDK + // for session persistence. This may break in future SDK versions. extraArgs: { ...userProvidedOptions?.extraArgs, "session-id": sessionId }, // If we want bypassPermissions to be an option, we have to allow it here. // But it doesn't work in root mode, so we only activate it if it will work. @@ -413,6 +420,157 @@ export class ClaudeAcpAgent implements Agent { }; } + async loadSession(params: LoadSessionRequest): Promise { + if ( + fs.existsSync(path.resolve(os.homedir(), ".claude.json.backup")) && + !fs.existsSync(path.resolve(os.homedir(), ".claude.json")) + ) { + throw RequestError.authRequired(); + } + + // Use session ID from request instead of generating a new one + const sessionId = params.sessionId; + const input = new Pushable(); + + const mcpServers: Record = {}; + if (Array.isArray(params.mcpServers)) { + for (const server of params.mcpServers) { + if ("type" in server) { + mcpServers[server.name] = { + type: server.type, + url: server.url, + headers: server.headers + ? Object.fromEntries(server.headers.map((e) => [e.name, e.value])) + : undefined, + }; + } else { + mcpServers[server.name] = { + type: "stdio", + command: server.command, + args: server.args, + env: server.env + ? Object.fromEntries(server.env.map((e) => [e.name, e.value])) + : undefined, + }; + } + } + } + + const server = createMcpServer(this, sessionId, this.clientCapabilities); + mcpServers["acp"] = { + type: "sdk", + name: "acp", + instance: server, + }; + + const permissionMode = "default"; + + const options: Options = { + systemPrompt: { type: "preset", preset: "claude_code" }, + settingSources: ["user", "project", "local"], + stderr: (err) => this.logger.error(err), + cwd: params.cwd, + includePartialMessages: true, + mcpServers, + allowDangerouslySkipPermissions: !IS_ROOT, + permissionMode, + canUseTool: this.canUseTool(sessionId), + executable: process.execPath as any, + ...(process.env.CLAUDE_CODE_EXECUTABLE && { + pathToClaudeCodeExecutable: process.env.CLAUDE_CODE_EXECUTABLE, + }), + // Resume the existing session instead of creating new + resume: sessionId, + hooks: { + PostToolUse: [ + { + hooks: [createPostToolUseHook(this.logger)], + }, + ], + }, + }; + + const allowedTools = []; + const disallowedTools = []; + + if (this.clientCapabilities?.fs?.readTextFile) { + allowedTools.push(toolNames.read); + disallowedTools.push("Read"); + } + if (this.clientCapabilities?.fs?.writeTextFile) { + disallowedTools.push("Write", "Edit"); + } + if (this.clientCapabilities?.terminal) { + allowedTools.push(toolNames.bashOutput, toolNames.killShell); + disallowedTools.push("Bash", "BashOutput", "KillShell"); + } + + if (allowedTools.length > 0) { + options.allowedTools = allowedTools; + } + if (disallowedTools.length > 0) { + options.disallowedTools = disallowedTools; + } + + const q = query({ + prompt: input, + options, + }); + + this.sessions[sessionId] = { + query: q, + input: input, + cancelled: false, + permissionMode, + }; + + const availableCommands = await getAvailableSlashCommands(q); + const models = await getAvailableModels(q); + + setTimeout(() => { + this.client.sessionUpdate({ + sessionId, + update: { + sessionUpdate: "available_commands_update", + availableCommands, + }, + }); + }, 0); + + const availableModes = [ + { + id: "default", + name: "Always Ask", + description: "Prompts for permission on first use of each tool", + }, + { + id: "acceptEdits", + name: "Accept Edits", + description: "Automatically accepts file edit permissions for the session", + }, + { + id: "plan", + name: "Plan Mode", + description: "Claude can analyze but not modify files or execute commands", + }, + ]; + if (!IS_ROOT) { + availableModes.push({ + id: "bypassPermissions", + name: "Bypass Permissions", + description: "Skips all permission prompts", + }); + } + + return { + models, + modes: { + currentModeId: permissionMode, + availableModes, + }, + }; + } + async authenticate(_params: AuthenticateRequest): Promise { throw new Error("Method not implemented."); } diff --git a/src/tests/acp-agent.test.ts b/src/tests/acp-agent.test.ts index 0cb2ee42..7db5cd82 100644 --- a/src/tests/acp-agent.test.ts +++ b/src/tests/acp-agent.test.ts @@ -912,3 +912,161 @@ describe("permission requests", () => { } }); }); + +describe.skipIf(!process.env.RUN_INTEGRATION_TESTS)("session/load integration", () => { + class TestClient implements Client { + agent: Agent; + files: Map = new Map(); + receivedText: string = ""; + resolveAvailableCommands: (commands: AvailableCommand[]) => void; + availableCommandsPromise: Promise; + + constructor(agent: Agent) { + this.agent = agent; + this.resolveAvailableCommands = () => {}; + this.availableCommandsPromise = new Promise((resolve) => { + this.resolveAvailableCommands = resolve; + }); + } + + takeReceivedText() { + const text = this.receivedText; + this.receivedText = ""; + return text; + } + + async requestPermission(params: RequestPermissionRequest): Promise { + const optionId = params.options.find((p) => p.kind === "allow_once")!.optionId; + return { outcome: { outcome: "selected", optionId } }; + } + + async sessionUpdate(params: SessionNotification): Promise { + switch (params.update.sessionUpdate) { + case "agent_message_chunk": { + if (params.update.content.type === "text") { + this.receivedText += params.update.content.text; + } + break; + } + case "available_commands_update": + this.resolveAvailableCommands(params.update.availableCommands); + break; + default: + break; + } + } + + async writeTextFile(params: WriteTextFileRequest): Promise { + this.files.set(params.path, params.content); + return {}; + } + + async readTextFile(params: ReadTextFileRequest): Promise { + const content = this.files.get(params.path) ?? ""; + return { content }; + } + } + + function startSubprocess() { + const child = spawn("npm", ["run", "--silent", "dev"], { + stdio: ["pipe", "pipe", "inherit"], + env: process.env, + }); + return child; + } + + function killAndWait(child: ReturnType): Promise { + return new Promise((resolve) => { + if (child.exitCode !== null) { + resolve(); + return; + } + child.on("exit", () => resolve()); + child.kill(); + }); + } + + it("should resume session with loadSession after process restart", async () => { + // Compile first + const valid = spawnSync("tsc", { stdio: "inherit" }); + if (valid.status) { + throw new Error("failed to compile"); + } + + let child1: ReturnType | null = null; + let child2: ReturnType | null = null; + + try { + // Create initial session and establish context + child1 = startSubprocess(); + let client1: TestClient; + const input1 = nodeToWebWritable(child1.stdin!); + const output1 = nodeToWebReadable(child1.stdout!); + const stream1 = ndJsonStream(input1, output1); + const connection1 = new ClientSideConnection((agent) => { + client1 = new TestClient(agent); + return client1; + }, stream1); + + await connection1.initialize({ + protocolVersion: 1, + clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, + }); + + const newSessionResponse = await connection1.newSession({ + cwd: __dirname, + mcpServers: [], + }); + const sessionId = newSessionResponse.sessionId; + + // Send a message to establish context + await connection1.prompt({ + prompt: [{ type: "text", text: "I am storing the code XYZ-789 in this conversation. Reply with exactly: 'Code stored: XYZ-789'" }], + sessionId, + }); + const firstResponse = client1!.takeReceivedText(); + expect(firstResponse).toContain("XYZ-789"); + + // Kill the subprocess (simulating disconnection) + await killAndWait(child1); + child1 = null; + + // Start new subprocess and load the session + child2 = startSubprocess(); + let client2: TestClient; + const input2 = nodeToWebWritable(child2.stdin!); + const output2 = nodeToWebReadable(child2.stdout!); + const stream2 = ndJsonStream(input2, output2); + const connection2 = new ClientSideConnection((agent) => { + client2 = new TestClient(agent); + return client2; + }, stream2); + + await connection2.initialize({ + protocolVersion: 1, + clientCapabilities: { fs: { readTextFile: true, writeTextFile: true } }, + }); + + // Load the same session using original sessionId + const loadResponse = await connection2.loadSession({ + sessionId, + cwd: __dirname, + mcpServers: [], + }); + + expect(loadResponse.modes).toBeDefined(); + + // Verify context is preserved by asking about the stored code + await connection2.prompt({ + prompt: [{ type: "text", text: "What code did I store in this conversation? Reply with just the code." }], + sessionId, + }); + + const response = client2!.takeReceivedText(); + expect(response).toContain("XYZ-789"); + } finally { + if (child1) await killAndWait(child1); + if (child2) await killAndWait(child2); + } + }, 120000); +});