Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions src/acp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
PROTOCOL_VERSION,
ndJsonStream,
} from "./acp.js";
import type { AnyMessage } from "./acp.js";

describe("Connection", () => {
let clientToAgent: TransformStream<Uint8Array, Uint8Array>;
Expand Down Expand Up @@ -971,6 +972,107 @@ describe("Connection", () => {
expect(closeLog).toContain("client connection closed (signal)");
});

class MinimalTestClient implements Client {
async writeTextFile(
_: WriteTextFileRequest,
): Promise<WriteTextFileResponse> {
return {};
}
async readTextFile(_: ReadTextFileRequest): Promise<ReadTextFileResponse> {
return { content: "test" };
}
async requestPermission(
_: RequestPermissionRequest,
): Promise<RequestPermissionResponse> {
return {
outcome: {
outcome: "selected",
optionId: "allow",
},
};
}
async sessionUpdate(_: SessionNotification): Promise<void> {
// no-op
}
}

it("rejects pending requests when the stream errors", async () => {
let readableController!: ReadableStreamDefaultController<AnyMessage>;

const connection = new ClientSideConnection(() => new MinimalTestClient(), {
readable: new ReadableStream<AnyMessage>({
start(controller) {
readableController = controller;
},
}),
writable: new WritableStream<AnyMessage>({
async write() {
// no-op
},
}),
});

const requestPromise = connection.newSession({
cwd: "/test",
mcpServers: [],
});
const error = new Error("stream exploded");

readableController.error(error);

await expect(requestPromise).rejects.toThrow("stream exploded");
await expect(connection.closed).resolves.toBeUndefined();
expect(connection.signal.aborted).toBe(true);
});

it("rejects pending requests when the writable stream errors", async () => {
const writeError = new Error("write failed");

const connection = new ClientSideConnection(() => new MinimalTestClient(), {
readable: new ReadableStream<AnyMessage>({
// Never produces messages; stays open.
start() {},
}),
writable: new WritableStream<AnyMessage>({
async write() {
throw writeError;
},
}),
});

const requestPromise = connection.newSession({
cwd: "/test",
mcpServers: [],
});

await expect(requestPromise).rejects.toThrow("write failed");
await expect(connection.closed).resolves.toBeUndefined();
expect(connection.signal.aborted).toBe(true);
});

it("rejects requests issued after the connection is closed", async () => {
const connection = new ClientSideConnection(() => new MinimalTestClient(), {
readable: new ReadableStream<AnyMessage>({
start(controller) {
// Close the readable stream immediately so the connection closes.
controller.close();
},
}),
writable: new WritableStream<AnyMessage>({
async write() {
// no-op
},
}),
});

await connection.closed;
expect(connection.signal.aborted).toBe(true);

await expect(
connection.newSession({ cwd: "/test", mcpServers: [] }),
).rejects.toThrow("ACP connection closed");
});

it("supports removing signal event listeners", async () => {
const closeLog: string[] = [];

Expand Down
106 changes: 71 additions & 35 deletions src/acp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ import type {
AnyResponse,
Result,
ErrorResponse,
PendingResponse,
RequestHandler,
NotificationHandler,
} from "./jsonrpc.js";

type ConnectionPendingResponse = {
resolve: (response: unknown) => void;
reject: (error: unknown) => void;
};
Comment on lines +18 to +21
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PendingResponse is protocol-shaped: it only covers JSON-RPC error responses.

ConnectionPendingResponse is connection-shaped: on stream shutdown we may need to reject with a raw transport error before any JSON-RPC envelope exists. Decided to keep that broader type local to acp.ts to avoid widening the shared protocol types.


/**
* An agent-side connection to a client.
*
Expand Down Expand Up @@ -931,7 +935,8 @@ export class ClientSideConnection implements Agent {
export type { AnyMessage } from "./jsonrpc.js";

class Connection {
#pendingResponses: Map<string | number | null, PendingResponse> = new Map();
#pendingResponses: Map<string | number | null, ConnectionPendingResponse> =
new Map();
#nextRequestId: number = 0;
#requestHandler: RequestHandler;
#notificationHandler: NotificationHandler;
Expand All @@ -951,7 +956,7 @@ class Connection {
this.#closedPromise = new Promise((resolve) => {
this.#abortController.signal.addEventListener("abort", () => resolve());
});
this.#receive();
void this.#receive();
}

/**
Expand Down Expand Up @@ -985,44 +990,64 @@ class Connection {
}

async #receive() {
const reader = this.#stream.readable.getReader();
let closeError: unknown = undefined;

try {
while (true) {
const { value: message, done } = await reader.read();
if (done) {
break;
}
if (!message) {
continue;
}
const reader = this.#stream.readable.getReader();
try {
while (!this.#abortController.signal.aborted) {
const { value: message, done } = await reader.read();
if (done) {
break;
}
if (!message) {
continue;
}

try {
this.#processMessage(message);
} catch (err) {
console.error(
"Unexpected error during message processing:",
message,
err,
);
// Only send error response if the message had an id (was a request)
if ("id" in message && message.id !== undefined) {
this.#sendMessage({
jsonrpc: "2.0",
id: message.id,
error: {
code: -32700,
message: "Parse error",
},
});
try {
this.#processMessage(message);
} catch (err) {
console.error(
"Unexpected error during message processing:",
message,
err,
);
// Only send error response if the message had an id (was a request)
if ("id" in message && message.id !== undefined) {
this.#sendMessage({
jsonrpc: "2.0",
id: message.id,
error: {
code: -32700,
message: "Parse error",
},
});
}
}
}
} finally {
reader.releaseLock();
}
} catch (error) {
closeError = error;
} finally {
reader.releaseLock();
this.#abortController.abort();
this.#close(closeError);
}
}

#close(error?: unknown) {
if (this.#abortController.signal.aborted) {
return;
}

const closeError: unknown = error ?? new Error("ACP connection closed");
for (const pendingResponse of this.#pendingResponses.values()) {
pendingResponse.reject(closeError);
}
this.#pendingResponses.clear();
this.#abortController.abort(closeError);
}

async #processMessage(message: AnyMessage) {
if ("method" in message && "id" in message) {
// It's a request
Expand Down Expand Up @@ -1140,7 +1165,8 @@ class Connection {
if ("result" in response) {
pendingResponse.resolve(response.result);
} else if ("error" in response) {
pendingResponse.reject(response.error);
const { code, message, data } = response.error;
pendingResponse.reject(new RequestError(code, message, data));
}
this.#pendingResponses.delete(response.id);
} else {
Expand All @@ -1149,6 +1175,7 @@ class Connection {
}

async sendRequest<Req, Resp>(method: string, params?: Req): Promise<Resp> {
this.#throwIfClosed();
const id = this.#nextRequestId++;
const responsePromise = new Promise((resolve, reject) => {
this.#pendingResponses.set(id, { resolve, reject });
Expand All @@ -1158,9 +1185,19 @@ class Connection {
}

async sendNotification<N>(method: string, params?: N): Promise<void> {
this.#throwIfClosed();
await this.#sendMessage({ jsonrpc: "2.0", method, params });
}

#throwIfClosed() {
if (this.#abortController.signal.aborted) {
throw (
this.#abortController.signal.reason ??
new Error("ACP connection closed")
);
}
}

async #sendMessage(message: AnyMessage) {
this.#writeQueue = this.#writeQueue
.then(async () => {
Expand All @@ -1172,8 +1209,7 @@ class Connection {
}
})
.catch((error) => {
// Continue processing writes on error
console.error("ACP write error:", error);
this.#close(error);
});
return this.#writeQueue;
}
Expand Down
5 changes: 0 additions & 5 deletions src/jsonrpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ export type ErrorResponse = {
data?: unknown;
};

export type PendingResponse = {
resolve: (response: unknown) => void;
reject: (error: ErrorResponse) => void;
};

export type RequestHandler = (
method: string,
params: unknown,
Expand Down