Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions lib/coding-agent/__tests__/handleMergeSuccess.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import { describe, it, expect, vi, beforeEach } from "vitest";

const mockDeletePRState = vi.fn();
vi.mock("../prState", () => ({
deleteCodingAgentPRState: (...args: unknown[]) => mockDeletePRState(...args),
}));

const mockUpsertAccountSnapshot = vi.fn();
vi.mock("@/lib/supabase/account_snapshots/upsertAccountSnapshot", () => ({
upsertAccountSnapshot: (...args: unknown[]) => mockUpsertAccountSnapshot(...args),
}));

const { handleMergeSuccess } = await import("../handleMergeSuccess");

beforeEach(() => {
vi.clearAllMocks();
mockUpsertAccountSnapshot.mockResolvedValue({ data: {}, error: null });
});

describe("handleMergeSuccess", () => {
it("deletes PR state and persists snapshot", async () => {
await handleMergeSuccess({
status: "pr_created",
branch: "agent/fix-bug",
snapshotId: "snap_abc123",
prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }],
});

expect(mockDeletePRState).toHaveBeenCalledWith("recoupable/api", "agent/fix-bug");
expect(mockUpsertAccountSnapshot).toHaveBeenCalledWith({
account_id: "04e3aba9-c130-4fb8-8b92-34e95d43e66b",
snapshot_id: "snap_abc123",
expires_at: expect.any(String),
});
});

it("deletes PR state for all repos when PRs span multiple repos", async () => {
await handleMergeSuccess({
status: "pr_created",
branch: "agent/fix-bug",
prs: [
{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" },
{ repo: "recoupable/chat", number: 10, url: "url", baseBranch: "test" },
{ repo: "recoupable/api", number: 43, url: "url", baseBranch: "test" },
],
});

expect(mockDeletePRState).toHaveBeenCalledTimes(2);
expect(mockDeletePRState).toHaveBeenCalledWith("recoupable/api", "agent/fix-bug");
expect(mockDeletePRState).toHaveBeenCalledWith("recoupable/chat", "agent/fix-bug");
});

it("skips snapshot persistence when snapshotId is not in state", async () => {
await handleMergeSuccess({
status: "pr_created",
branch: "agent/fix-bug",
prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }],
});

expect(mockDeletePRState).toHaveBeenCalledWith("recoupable/api", "agent/fix-bug");
expect(mockUpsertAccountSnapshot).not.toHaveBeenCalled();
});

it("logs error but does not throw when snapshot persistence fails", async () => {
mockUpsertAccountSnapshot.mockResolvedValue({
data: null,
error: { message: "db error", code: "500" },
});
const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {});

await handleMergeSuccess({
status: "pr_created",
branch: "agent/fix-bug",
snapshotId: "snap_abc123",
prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }],
});

expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining("failed to persist snapshot"),
expect.anything(),
);
consoleSpy.mockRestore();
});

it("does not throw when deleteCodingAgentPRState throws", async () => {
mockDeletePRState.mockRejectedValue(new Error("Redis connection failed"));
const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {});

await expect(
handleMergeSuccess({
status: "pr_created",
branch: "agent/fix-bug",
snapshotId: "snap_abc123",
prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }],
}),
).resolves.toBeUndefined();

expect(consoleSpy).toHaveBeenCalledWith(
"[coding-agent] post-merge cleanup failed:",
expect.any(Error),
);
consoleSpy.mockRestore();
});

it("skips PR state cleanup when branch is missing", async () => {
await handleMergeSuccess({
status: "pr_created",
prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }],
});

expect(mockDeletePRState).not.toHaveBeenCalled();
});
});
47 changes: 42 additions & 5 deletions lib/coding-agent/__tests__/onMergeAction.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ import { describe, it, expect, vi, beforeEach } from "vitest";

global.fetch = vi.fn();

const mockDeletePRState = vi.fn();
vi.mock("../prState", () => ({
deleteCodingAgentPRState: (...args: unknown[]) => mockDeletePRState(...args),
const mockHandleMergeSuccess = vi.fn();
vi.mock("../handleMergeSuccess", () => ({
handleMergeSuccess: (...args: unknown[]) => mockHandleMergeSuccess(...args),
}));

const { registerOnMergeAction } = await import("../handlers/onMergeAction");

beforeEach(() => {
vi.clearAllMocks();
process.env.GITHUB_TOKEN = "ghp_test";
mockHandleMergeSuccess.mockResolvedValue(undefined);
});

/**
Expand All @@ -30,7 +31,7 @@ describe("registerOnMergeAction", () => {
expect(bot.onAction).toHaveBeenCalledWith("merge_all_prs", expect.any(Function));
});

it("squash-merges PRs, cleans up shared state, and posts results", async () => {
it("squash-merges PRs, calls handleMergeSuccess, and posts results", async () => {
vi.mocked(fetch).mockResolvedValue({ ok: true } as Response);

const bot = createMockBot();
Expand All @@ -42,6 +43,7 @@ describe("registerOnMergeAction", () => {
status: "pr_created",
prompt: "fix bug",
branch: "agent/fix-bug",
snapshotId: "snap_abc123",
prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }],
}),
post: vi.fn(),
Expand All @@ -55,10 +57,44 @@ describe("registerOnMergeAction", () => {
expect.objectContaining({ method: "PUT" }),
);
expect(mockThread.setState).toHaveBeenCalledWith({ status: "merged" });
expect(mockDeletePRState).toHaveBeenCalledWith("recoupable/api", "agent/fix-bug");
expect(mockHandleMergeSuccess).toHaveBeenCalledWith(
expect.objectContaining({ branch: "agent/fix-bug", snapshotId: "snap_abc123" }),
);
expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("merged"));
});

it("does not call handleMergeSuccess when a merge fails", async () => {
vi.mocked(fetch).mockResolvedValue({
ok: false,
status: 409,
text: () => Promise.resolve(JSON.stringify({ message: "merge conflict" })),
} as unknown as Response);

const bot = createMockBot();
registerOnMergeAction(bot);
const handler = bot.onAction.mock.calls[0][1];
const consoleSpy = vi.spyOn(console, "error").mockImplementation(() => {});

const mockThread = {
state: Promise.resolve({
status: "pr_created",
prompt: "fix bug",
branch: "agent/fix-bug",
snapshotId: "snap_abc123",
prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }],
}),
post: vi.fn(),
setState: vi.fn(),
};

await handler({ thread: mockThread });

expect(mockHandleMergeSuccess).not.toHaveBeenCalled();
expect(mockThread.setState).toHaveBeenCalledWith({ status: "pr_created" });
expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("failed"));
consoleSpy.mockRestore();
});

it("posts no PRs message when state has no PRs", async () => {
const bot = createMockBot();
registerOnMergeAction(bot);
Expand All @@ -74,5 +110,6 @@ describe("registerOnMergeAction", () => {

expect(mockThread.post).toHaveBeenCalledWith("No PRs to merge.");
expect(fetch).not.toHaveBeenCalled();
expect(mockHandleMergeSuccess).not.toHaveBeenCalled();
});
});
35 changes: 35 additions & 0 deletions lib/coding-agent/handleMergeSuccess.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { deleteCodingAgentPRState } from "./prState";
import { upsertAccountSnapshot } from "@/lib/supabase/account_snapshots/upsertAccountSnapshot";
import { RECOUP_ORG_ID, SNAPSHOT_EXPIRY_MS } from "@/lib/const";
import type { CodingAgentThreadState } from "./types";

/**
* Handles post-merge cleanup after all PRs merged successfully.
* Deletes the shared PR state keys for all repos and persists the latest
* snapshot via upsertAccountSnapshot.
*/
export async function handleMergeSuccess(state: CodingAgentThreadState): Promise<void> {
try {
if (state.branch && state.prs?.length) {
const repos = [...new Set(state.prs.map(pr => pr.repo))];
await Promise.all(repos.map(repo => deleteCodingAgentPRState(repo, state.branch!)));
}

if (state.snapshotId) {
const snapshotResult = await upsertAccountSnapshot({
account_id: RECOUP_ORG_ID,
snapshot_id: state.snapshotId,
expires_at: new Date(Date.now() + SNAPSHOT_EXPIRY_MS).toISOString(),
});

if (snapshotResult.error) {
console.error(
`[coding-agent] failed to persist snapshot for ${RECOUP_ORG_ID}:`,
snapshotResult.error,
);
}
}
} catch (error) {
console.error("[coding-agent] post-merge cleanup failed:", error);
}
}
15 changes: 10 additions & 5 deletions lib/coding-agent/handlers/onMergeAction.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import type { CodingAgentBot } from "../bot";
import { deleteCodingAgentPRState } from "../prState";
import type { CodingAgentThreadState } from "../types";
import { handleMergeSuccess } from "../handleMergeSuccess";

/**
* Registers the "Merge All PRs" button action handler on the bot.
* Squash-merges each PR via the GitHub API.
* Squash-merges each PR via the GitHub API, then delegates to
* handleMergeSuccess to clean up PR state and persist the latest snapshot.
*
* @param bot
*/
Expand Down Expand Up @@ -51,10 +52,14 @@ export function registerOnMergeAction(bot: CodingAgentBot) {
}
}

await thread.setState({ status: "merged" });
if (state.branch && state.prs?.[0]?.repo) {
await deleteCodingAgentPRState(state.prs[0].repo, state.branch);
const allMerged = results.every(r => r.endsWith("merged"));

// On failure, revert to pr_created so handleFeedback still accepts replies
await thread.setState({ status: allMerged ? "merged" : "pr_created" });
if (allMerged) {
await handleMergeSuccess(state);
}

await thread.post(`Merge results:\n${results.map(r => `- ${r}`).join("\n")}`);
});
}
3 changes: 3 additions & 0 deletions lib/const.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ export const RECOUP_API_KEY = process.env.RECOUP_API_KEY || "";
export const FLAMINGO_GENERATE_URL =
"https://sidney-78147--music-flamingo-musicflamingo-generate.modal.run";

/** Snapshot expiration duration (7 days) */
export const SNAPSHOT_EXPIRY_MS = 7 * 24 * 60 * 60 * 1000;

// EVALS
export const EVAL_ACCOUNT_ID = "fb678396-a68f-4294-ae50-b8cacf9ce77b";
export const EVAL_ACCESS_TOKEN = process.env.EVAL_ACCESS_TOKEN || "";
Expand Down
3 changes: 2 additions & 1 deletion lib/sandbox/updateSnapshotPatchHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { NextResponse } from "next/server";
import { getCorsHeaders } from "@/lib/networking/getCorsHeaders";
import { validateSnapshotPatchBody } from "@/lib/sandbox/validateSnapshotPatchBody";
import { upsertAccountSnapshot } from "@/lib/supabase/account_snapshots/upsertAccountSnapshot";
import { SNAPSHOT_EXPIRY_MS } from "@/lib/const";
import { selectAccountSnapshots } from "@/lib/supabase/account_snapshots/selectAccountSnapshots";

/**
Expand Down Expand Up @@ -30,7 +31,7 @@ export async function updateSnapshotPatchHandler(request: NextRequest): Promise<
account_id: validated.accountId,
...(validated.snapshotId && {
snapshot_id: validated.snapshotId,
expires_at: new Date(Date.now() + 365 * 24 * 60 * 60 * 1000).toISOString(),
expires_at: new Date(Date.now() + SNAPSHOT_EXPIRY_MS).toISOString(),
}),
...(validated.githubRepo && { github_repo: validated.githubRepo }),
});
Expand Down