diff --git a/lib/coding-agent/__tests__/buildPRStateKey.test.ts b/lib/coding-agent/__tests__/buildPRStateKey.test.ts new file mode 100644 index 00000000..2e8c6b94 --- /dev/null +++ b/lib/coding-agent/__tests__/buildPRStateKey.test.ts @@ -0,0 +1,10 @@ +import { describe, it, expect } from "vitest"; +import { buildPRStateKey } from "../prState/buildPRStateKey"; + +describe("buildPRStateKey", () => { + it("builds the correct key", () => { + expect(buildPRStateKey("recoupable/api", "agent/fix-bug")).toBe( + "coding-agent:pr:recoupable/api:agent/fix-bug", + ); + }); +}); diff --git a/lib/coding-agent/__tests__/deleteCodingAgentPRState.test.ts b/lib/coding-agent/__tests__/deleteCodingAgentPRState.test.ts new file mode 100644 index 00000000..2015c854 --- /dev/null +++ b/lib/coding-agent/__tests__/deleteCodingAgentPRState.test.ts @@ -0,0 +1,22 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +const mockDel = vi.fn(); + +vi.mock("@/lib/redis/connection", () => ({ + default: { + del: (...args: unknown[]) => mockDel(...args), + }, +})); + +const { deleteCodingAgentPRState } = await import("../prState/deleteCodingAgentPRState"); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("deleteCodingAgentPRState", () => { + it("deletes the key from Redis", async () => { + await deleteCodingAgentPRState("recoupable/api", "agent/fix-bug"); + expect(mockDel).toHaveBeenCalledWith("coding-agent:pr:recoupable/api:agent/fix-bug"); + }); +}); diff --git a/lib/coding-agent/__tests__/getCodingAgentPRState.test.ts b/lib/coding-agent/__tests__/getCodingAgentPRState.test.ts new file mode 100644 index 00000000..ab64107a --- /dev/null +++ b/lib/coding-agent/__tests__/getCodingAgentPRState.test.ts @@ -0,0 +1,45 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +const mockGet = vi.fn(); + +vi.mock("@/lib/redis/connection", () => ({ + default: { + get: (...args: unknown[]) => mockGet(...args), + }, +})); + +const { getCodingAgentPRState } = await import("../prState/getCodingAgentPRState"); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("getCodingAgentPRState", () => { + it("returns null when key does not exist", async () => { + mockGet.mockResolvedValue(null); + const result = await getCodingAgentPRState("recoupable/api", "agent/fix-bug"); + expect(result).toBeNull(); + expect(mockGet).toHaveBeenCalledWith("coding-agent:pr:recoupable/api:agent/fix-bug"); + }); + + it("returns parsed state when key exists", async () => { + const state = { + status: "pr_created", + snapshotId: "snap_abc", + branch: "agent/fix-bug", + repo: "recoupable/api", + prs: [ + { + repo: "recoupable/api", + number: 42, + url: "https://github.com/recoupable/api/pull/42", + baseBranch: "test", + }, + ], + }; + mockGet.mockResolvedValue(JSON.stringify(state)); + + const result = await getCodingAgentPRState("recoupable/api", "agent/fix-bug"); + expect(result).toEqual(state); + }); +}); diff --git a/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts b/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts index 64e30c7b..2a5a9da4 100644 --- a/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts +++ b/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts @@ -22,11 +22,11 @@ vi.mock("chat", () => { const parts = threadId.split(":"); return `${parts[0]}:${parts[1]}`; }), - Card: vi.fn((opts) => ({ type: "card", ...opts })), - CardText: vi.fn((text) => ({ type: "text", text })), - Actions: vi.fn((children) => ({ type: "actions", children })), - Button: vi.fn((opts) => ({ type: "button", ...opts })), - LinkButton: vi.fn((opts) => ({ type: "link-button", ...opts })), + Card: vi.fn(opts => ({ type: "card", ...opts })), + CardText: vi.fn(text => ({ type: "text", text })), + Actions: vi.fn(children => ({ type: "actions", children })), + Button: vi.fn(opts => ({ type: "button", ...opts })), + LinkButton: vi.fn(opts => ({ type: "link-button", ...opts })), }; }); @@ -34,6 +34,11 @@ vi.mock("../bot", () => ({ codingAgentBot: {}, })); +const mockSetPRState = vi.fn(); +vi.mock("../prState", () => ({ + setCodingAgentPRState: (...args: unknown[]) => mockSetPRState(...args), +})); + const { handleCodingAgentCallback } = await import("../handleCodingAgentCallback"); beforeEach(() => { @@ -136,7 +141,14 @@ describe("handleCodingAgentCallback", () => { it("posts updated card with PR buttons for updated status", async () => { mockState = { status: "updating", - prs: [{ repo: "recoupable/api", number: 42, url: "https://github.com/recoupable/api/pull/42", baseBranch: "test" }], + prs: [ + { + repo: "recoupable/api", + number: 42, + url: "https://github.com/recoupable/api/pull/42", + baseBranch: "test", + }, + ], }; const body = { @@ -149,7 +161,9 @@ describe("handleCodingAgentCallback", () => { const response = await handleCodingAgentCallback(request); expect(response.status).toBe(200); - expect(mockSetState).toHaveBeenCalledWith(expect.objectContaining({ status: "pr_created", snapshotId: "snap_new" })); + expect(mockSetState).toHaveBeenCalledWith( + expect.objectContaining({ status: "pr_created", snapshotId: "snap_new" }), + ); expect(mockPost).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); }); }); diff --git a/lib/coding-agent/__tests__/handlePRCreated.test.ts b/lib/coding-agent/__tests__/handlePRCreated.test.ts index 6f80972a..09444996 100644 --- a/lib/coding-agent/__tests__/handlePRCreated.test.ts +++ b/lib/coding-agent/__tests__/handlePRCreated.test.ts @@ -10,11 +10,16 @@ vi.mock("../getThread", () => ({ })); vi.mock("chat", () => ({ - Card: vi.fn((opts) => ({ type: "card", ...opts })), - CardText: vi.fn((text) => ({ type: "text", text })), - Actions: vi.fn((children) => ({ type: "actions", children })), - Button: vi.fn((opts) => ({ type: "button", ...opts })), - LinkButton: vi.fn((opts) => ({ type: "link-button", ...opts })), + Card: vi.fn(opts => ({ type: "card", ...opts })), + CardText: vi.fn(text => ({ type: "text", text })), + Actions: vi.fn(children => ({ type: "actions", children })), + Button: vi.fn(opts => ({ type: "button", ...opts })), + LinkButton: vi.fn(opts => ({ type: "link-button", ...opts })), +})); + +const mockSetPRState = vi.fn(); +vi.mock("../prState", () => ({ + setCodingAgentPRState: (...args: unknown[]) => mockSetPRState(...args), })); describe("handlePRCreated", () => { @@ -26,10 +31,19 @@ describe("handlePRCreated", () => { status: "pr_created", branch: "agent/fix-bug", snapshotId: "snap_abc", - prs: [{ repo: "recoupable/api", number: 42, url: "https://github.com/recoupable/api/pull/42", baseBranch: "test" }], + prs: [ + { + repo: "recoupable/api", + number: 42, + url: "https://github.com/recoupable/api/pull/42", + baseBranch: "test", + }, + ], }); - expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); + expect(mockThread.post).toHaveBeenCalledWith( + expect.objectContaining({ card: expect.anything() }), + ); const { Button } = await import("chat"); expect(Button).toHaveBeenCalledWith( @@ -43,5 +57,16 @@ describe("handlePRCreated", () => { snapshotId: "snap_abc", }), ); + + expect(mockSetPRState).toHaveBeenCalledWith( + "recoupable/api", + "agent/fix-bug", + expect.objectContaining({ + status: "pr_created", + snapshotId: "snap_abc", + branch: "agent/fix-bug", + repo: "recoupable/api", + }), + ); }); }); diff --git a/lib/coding-agent/__tests__/handlers.test.ts b/lib/coding-agent/__tests__/handlers.test.ts index 5cd30f50..b21ebfe2 100644 --- a/lib/coding-agent/__tests__/handlers.test.ts +++ b/lib/coding-agent/__tests__/handlers.test.ts @@ -12,10 +12,15 @@ vi.mock("@/lib/trigger/triggerUpdatePR", () => ({ })); vi.mock("chat", () => ({ - Card: vi.fn((opts) => ({ type: "card", ...opts })), - CardText: vi.fn((text) => ({ type: "text", text })), - Actions: vi.fn((children) => ({ type: "actions", children })), - LinkButton: vi.fn((opts) => ({ type: "link-button", ...opts })), + Card: vi.fn(opts => ({ type: "card", ...opts })), + CardText: vi.fn(text => ({ type: "text", text })), + Actions: vi.fn(children => ({ type: "actions", children })), + LinkButton: vi.fn(opts => ({ type: "link-button", ...opts })), +})); + +vi.mock("../prState", () => ({ + getCodingAgentPRState: vi.fn().mockResolvedValue(null), + setCodingAgentPRState: vi.fn(), })); const { registerOnNewMention } = await import("../handlers/onNewMention"); @@ -61,7 +66,9 @@ describe("registerOnNewMention", () => { expect(mockThread.subscribe).toHaveBeenCalledOnce(); expect(mockTriggerCodingAgent).toHaveBeenCalled(); - expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); + expect(mockThread.post).toHaveBeenCalledWith( + expect.objectContaining({ card: expect.anything() }), + ); expect(mockThread.setState).toHaveBeenCalledWith( expect.objectContaining({ status: "running", @@ -82,7 +89,14 @@ describe("registerOnNewMention", () => { prompt: "original prompt", snapshotId: "snap_abc", branch: "agent/fix-bug", - prs: [{ repo: "recoupable/tasks", number: 56, url: "https://github.com/recoupable/tasks/pull/56", baseBranch: "main" }], + prs: [ + { + repo: "recoupable/tasks", + number: 56, + url: "https://github.com/recoupable/tasks/pull/56", + baseBranch: "main", + }, + ], }), subscribe: vi.fn(), post: vi.fn(), @@ -104,8 +118,53 @@ describe("registerOnNewMention", () => { repo: "recoupable/tasks", }), ); - expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); - expect(mockThread.setState).toHaveBeenCalledWith(expect.objectContaining({ status: "updating" })); + expect(mockThread.post).toHaveBeenCalledWith( + expect.objectContaining({ card: expect.anything() }), + ); + expect(mockThread.setState).toHaveBeenCalledWith( + expect.objectContaining({ status: "updating" }), + ); + }); + + it("resolves PR state from shared key when thread state is null and raw has repo/branch", async () => { + const { getCodingAgentPRState } = await import("../prState"); + vi.mocked(getCodingAgentPRState).mockResolvedValue({ + status: "pr_created", + snapshotId: "snap_abc", + branch: "agent/fix-bug", + repo: "recoupable/api", + prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }], + }); + + const bot = createMockBot(); + registerOnNewMention(bot); + const handler = bot.onNewMention.mock.calls[0][0]; + + const mockThread = { + id: "github:recoupable/api:42", + state: Promise.resolve(null), + subscribe: vi.fn(), + post: vi.fn(), + setState: vi.fn(), + }; + const mockMessage = { + text: "make the button blue", + author: { id: "sweetmantech" }, + raw: { repo: "recoupable/api", branch: "agent/fix-bug" }, + }; + + await handler(mockThread, mockMessage); + + expect(getCodingAgentPRState).toHaveBeenCalledWith("recoupable/api", "agent/fix-bug"); + expect(mockTriggerUpdatePR).toHaveBeenCalledWith( + expect.objectContaining({ + feedback: "make the button blue", + snapshotId: "snap_abc", + branch: "agent/fix-bug", + repo: "recoupable/api", + }), + ); + expect(mockTriggerCodingAgent).not.toHaveBeenCalled(); }); it("tells user to wait when thread is already running", async () => { diff --git a/lib/coding-agent/__tests__/onMergeAction.test.ts b/lib/coding-agent/__tests__/onMergeAction.test.ts index 2b249cf5..f6007f79 100644 --- a/lib/coding-agent/__tests__/onMergeAction.test.ts +++ b/lib/coding-agent/__tests__/onMergeAction.test.ts @@ -2,6 +2,11 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; global.fetch = vi.fn(); +const mockDeletePRState = vi.fn(); +vi.mock("../prState", () => ({ + deleteCodingAgentPRState: (...args: unknown[]) => mockDeletePRState(...args), +})); + const { registerOnMergeAction } = await import("../handlers/onMergeAction"); beforeEach(() => { @@ -9,6 +14,9 @@ beforeEach(() => { process.env.GITHUB_TOKEN = "ghp_test"; }); +/** + * + */ function createMockBot() { return { onAction: vi.fn(), @@ -22,7 +30,7 @@ describe("registerOnMergeAction", () => { expect(bot.onAction).toHaveBeenCalledWith("merge_all_prs", expect.any(Function)); }); - it("squash-merges PRs and posts results", async () => { + it("squash-merges PRs, cleans up shared state, and posts results", async () => { vi.mocked(fetch).mockResolvedValue({ ok: true } as Response); const bot = createMockBot(); @@ -33,6 +41,7 @@ describe("registerOnMergeAction", () => { state: Promise.resolve({ status: "pr_created", prompt: "fix bug", + branch: "agent/fix-bug", prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }], }), post: vi.fn(), @@ -46,6 +55,7 @@ describe("registerOnMergeAction", () => { expect.objectContaining({ method: "PUT" }), ); expect(mockThread.setState).toHaveBeenCalledWith({ status: "merged" }); + expect(mockDeletePRState).toHaveBeenCalledWith("recoupable/api", "agent/fix-bug"); expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("merged")); }); diff --git a/lib/coding-agent/__tests__/onSubscribedMessage.test.ts b/lib/coding-agent/__tests__/onSubscribedMessage.test.ts index 8ba6b565..78a772e9 100644 --- a/lib/coding-agent/__tests__/onSubscribedMessage.test.ts +++ b/lib/coding-agent/__tests__/onSubscribedMessage.test.ts @@ -4,12 +4,19 @@ vi.mock("@/lib/trigger/triggerUpdatePR", () => ({ triggerUpdatePR: vi.fn().mockResolvedValue({ id: "run_456" }), })); +vi.mock("../prState", () => ({ + setCodingAgentPRState: vi.fn(), +})); + const { registerOnSubscribedMessage } = await import("../handlers/onSubscribedMessage"); beforeEach(() => { vi.clearAllMocks(); }); +/** + * + */ function createMockBot() { return { onSubscribedMessage: vi.fn(), @@ -45,8 +52,12 @@ describe("registerOnSubscribedMessage", () => { await handler(mockThread, { text: "make the button blue", author: { userId: "U111" } }); - expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); - expect(mockThread.setState).toHaveBeenCalledWith(expect.objectContaining({ status: "updating" })); + expect(mockThread.post).toHaveBeenCalledWith( + expect.objectContaining({ card: expect.anything() }), + ); + expect(mockThread.setState).toHaveBeenCalledWith( + expect.objectContaining({ status: "updating" }), + ); expect(triggerUpdatePR).toHaveBeenCalledWith( expect.objectContaining({ feedback: "make the button blue", diff --git a/lib/coding-agent/__tests__/resolvePRState.test.ts b/lib/coding-agent/__tests__/resolvePRState.test.ts new file mode 100644 index 00000000..308ef700 --- /dev/null +++ b/lib/coding-agent/__tests__/resolvePRState.test.ts @@ -0,0 +1,89 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +const mockGetPRState = vi.fn(); + +vi.mock("../prState", () => ({ + getCodingAgentPRState: (...args: unknown[]) => mockGetPRState(...args), +})); + +const { resolvePRState } = await import("../resolvePRState"); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +/** + * + * @param state + */ +function createMockThread(state: unknown) { + return { + id: "slack:C123:ts", + get state() { + return Promise.resolve(state); + }, + } as any; +} + +describe("resolvePRState", () => { + it("returns thread state when available", async () => { + const threadState = { + status: "pr_created", + prompt: "fix bug", + branch: "agent/fix", + snapshotId: "snap_1", + prs: [{ repo: "recoupable/api", number: 1, url: "url", baseBranch: "test" }], + }; + const thread = createMockThread(threadState); + + const result = await resolvePRState(thread); + expect(result).toEqual(threadState); + expect(mockGetPRState).not.toHaveBeenCalled(); + }); + + it("falls back to shared PR state when thread state is null", async () => { + const thread = createMockThread(null); + const prState = { + status: "pr_created", + snapshotId: "snap_1", + branch: "agent/fix", + repo: "recoupable/api", + prs: [{ repo: "recoupable/api", number: 1, url: "url", baseBranch: "test" }], + }; + mockGetPRState.mockResolvedValue(prState); + + const result = await resolvePRState(thread, { repo: "recoupable/api", branch: "agent/fix" }); + + expect(mockGetPRState).toHaveBeenCalledWith("recoupable/api", "agent/fix"); + expect(result).toEqual({ + status: "pr_created", + prompt: "", + branch: "agent/fix", + snapshotId: "snap_1", + prs: prState.prs, + }); + }); + + it("returns null when neither thread state nor PR context exists", async () => { + const thread = createMockThread(null); + const result = await resolvePRState(thread); + expect(result).toBeNull(); + }); + + it("returns null when PR context has no match in Redis", async () => { + const thread = createMockThread(null); + mockGetPRState.mockResolvedValue(null); + + const result = await resolvePRState(thread, { repo: "recoupable/api", branch: "agent/fix" }); + expect(result).toBeNull(); + }); + + it("ignores PR context when thread state exists", async () => { + const threadState = { status: "running", prompt: "fix bug" }; + const thread = createMockThread(threadState); + + const result = await resolvePRState(thread, { repo: "recoupable/api", branch: "agent/fix" }); + expect(result).toEqual(threadState); + expect(mockGetPRState).not.toHaveBeenCalled(); + }); +}); diff --git a/lib/coding-agent/__tests__/setCodingAgentPRState.test.ts b/lib/coding-agent/__tests__/setCodingAgentPRState.test.ts new file mode 100644 index 00000000..b7141d1b --- /dev/null +++ b/lib/coding-agent/__tests__/setCodingAgentPRState.test.ts @@ -0,0 +1,34 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +const mockSet = vi.fn(); + +vi.mock("@/lib/redis/connection", () => ({ + default: { + set: (...args: unknown[]) => mockSet(...args), + }, +})); + +const { setCodingAgentPRState } = await import("../prState/setCodingAgentPRState"); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe("setCodingAgentPRState", () => { + it("stores serialized state in Redis", async () => { + const state = { + status: "pr_created" as const, + snapshotId: "snap_abc", + branch: "agent/fix-bug", + repo: "recoupable/api", + prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }], + }; + + await setCodingAgentPRState("recoupable/api", "agent/fix-bug", state); + + expect(mockSet).toHaveBeenCalledWith( + "coding-agent:pr:recoupable/api:agent/fix-bug", + JSON.stringify(state), + ); + }); +}); diff --git a/lib/coding-agent/handleCodingAgentCallback.ts b/lib/coding-agent/handleCodingAgentCallback.ts index 5397419a..22a06a0f 100644 --- a/lib/coding-agent/handleCodingAgentCallback.ts +++ b/lib/coding-agent/handleCodingAgentCallback.ts @@ -4,6 +4,7 @@ import { validateCodingAgentCallback } from "./validateCodingAgentCallback"; import { getThread } from "./getThread"; import { handlePRCreated } from "./handlePRCreated"; import { buildPRCard } from "./buildPRCard"; +import { setCodingAgentPRState } from "./prState"; import type { CodingAgentThreadState } from "./types"; /** @@ -63,6 +64,16 @@ export async function handleCodingAgentCallback(request: Request): Promise `- ${r}`).join("\n")}`); }); } diff --git a/lib/coding-agent/handlers/onNewMention.ts b/lib/coding-agent/handlers/onNewMention.ts index ea020146..63514474 100644 --- a/lib/coding-agent/handlers/onNewMention.ts +++ b/lib/coding-agent/handlers/onNewMention.ts @@ -1,19 +1,27 @@ import type { CodingAgentBot } from "../bot"; import { buildTaskCard } from "../buildTaskCard"; import { triggerCodingAgent } from "@/lib/trigger/triggerCodingAgent"; +import { resolvePRState } from "../resolvePRState"; import { handleFeedback } from "./handleFeedback"; /** * Registers the onNewMention handler on the bot. - * If the thread already has PRs, treats the mention as feedback and - * triggers the update-pr task. Otherwise, starts a new coding agent task. + * If the thread already has PRs (via thread state or shared PR state key), + * treats the mention as feedback and triggers the update-pr task. + * Otherwise, starts a new coding agent task. + * + * For GitHub PR comments, message.meta may contain { repo, branch } to look up + * the shared PR state key when thread state is empty. * * @param bot */ export function registerOnNewMention(bot: CodingAgentBot) { bot.onNewMention(async (thread, message) => { try { - const state = await thread.state; + const raw = message.raw as { repo?: string; branch?: string } | undefined; + const prContext = + raw?.repo && raw?.branch ? { repo: raw.repo, branch: raw.branch } : undefined; + const state = await resolvePRState(thread, prContext); if (await handleFeedback(thread, message.text, state)) return; @@ -25,7 +33,11 @@ export function registerOnNewMention(bot: CodingAgentBot) { callbackThreadId: thread.id, }); - const card = buildTaskCard("Task Started", `Starting work on: "${prompt}"\n\nI'll reply here when done.`, handle.id); + const card = buildTaskCard( + "Task Started", + `Starting work on: "${prompt}"\n\nI'll reply here when done.`, + handle.id, + ); await thread.post({ card }); await thread.setState({ diff --git a/lib/coding-agent/prState/buildPRStateKey.ts b/lib/coding-agent/prState/buildPRStateKey.ts new file mode 100644 index 00000000..570a44ab --- /dev/null +++ b/lib/coding-agent/prState/buildPRStateKey.ts @@ -0,0 +1,11 @@ +const KEY_PREFIX = "coding-agent:pr"; + +/** + * Builds the Redis key for a given repo and branch. + * + * @param repo + * @param branch + */ +export function buildPRStateKey(repo: string, branch: string): string { + return `${KEY_PREFIX}:${repo}:${branch}`; +} diff --git a/lib/coding-agent/prState/deleteCodingAgentPRState.ts b/lib/coding-agent/prState/deleteCodingAgentPRState.ts new file mode 100644 index 00000000..87cf0eef --- /dev/null +++ b/lib/coding-agent/prState/deleteCodingAgentPRState.ts @@ -0,0 +1,13 @@ +import redis from "@/lib/redis/connection"; +import { buildPRStateKey } from "./buildPRStateKey"; + +/** + * Deletes the shared PR state for a repo/branch from Redis. + * + * @param repo + * @param branch + */ +export async function deleteCodingAgentPRState(repo: string, branch: string): Promise { + const key = buildPRStateKey(repo, branch); + await redis.del(key); +} diff --git a/lib/coding-agent/prState/getCodingAgentPRState.ts b/lib/coding-agent/prState/getCodingAgentPRState.ts new file mode 100644 index 00000000..69e68753 --- /dev/null +++ b/lib/coding-agent/prState/getCodingAgentPRState.ts @@ -0,0 +1,19 @@ +import redis from "@/lib/redis/connection"; +import { buildPRStateKey } from "./buildPRStateKey"; +import type { CodingAgentPRState } from "./types"; + +/** + * Gets the shared PR state for a repo/branch from Redis. + * + * @param repo + * @param branch + */ +export async function getCodingAgentPRState( + repo: string, + branch: string, +): Promise { + const key = buildPRStateKey(repo, branch); + const raw = await redis.get(key); + if (!raw) return null; + return JSON.parse(raw) as CodingAgentPRState; +} diff --git a/lib/coding-agent/prState/index.ts b/lib/coding-agent/prState/index.ts new file mode 100644 index 00000000..a5afc875 --- /dev/null +++ b/lib/coding-agent/prState/index.ts @@ -0,0 +1,5 @@ +export { buildPRStateKey } from "./buildPRStateKey"; +export { getCodingAgentPRState } from "./getCodingAgentPRState"; +export { setCodingAgentPRState } from "./setCodingAgentPRState"; +export { deleteCodingAgentPRState } from "./deleteCodingAgentPRState"; +export type { CodingAgentPRState } from "./types"; diff --git a/lib/coding-agent/prState/setCodingAgentPRState.ts b/lib/coding-agent/prState/setCodingAgentPRState.ts new file mode 100644 index 00000000..f7ffaac6 --- /dev/null +++ b/lib/coding-agent/prState/setCodingAgentPRState.ts @@ -0,0 +1,19 @@ +import redis from "@/lib/redis/connection"; +import { buildPRStateKey } from "./buildPRStateKey"; +import type { CodingAgentPRState } from "./types"; + +/** + * Sets the shared PR state for a repo/branch in Redis. + * + * @param repo + * @param branch + * @param state + */ +export async function setCodingAgentPRState( + repo: string, + branch: string, + state: CodingAgentPRState, +): Promise { + const key = buildPRStateKey(repo, branch); + await redis.set(key, JSON.stringify(state)); +} diff --git a/lib/coding-agent/prState/types.ts b/lib/coding-agent/prState/types.ts new file mode 100644 index 00000000..84fc8c16 --- /dev/null +++ b/lib/coding-agent/prState/types.ts @@ -0,0 +1,9 @@ +import type { CodingAgentPR } from "../types"; + +export interface CodingAgentPRState { + status: "running" | "pr_created" | "updating" | "merged" | "failed" | "no_changes"; + snapshotId?: string; + branch: string; + repo: string; + prs?: CodingAgentPR[]; +} diff --git a/lib/coding-agent/resolvePRState.ts b/lib/coding-agent/resolvePRState.ts new file mode 100644 index 00000000..78f8bb2d --- /dev/null +++ b/lib/coding-agent/resolvePRState.ts @@ -0,0 +1,40 @@ +import type { Thread } from "chat"; +import { getCodingAgentPRState, type CodingAgentPRState } from "./prState"; +import type { CodingAgentThreadState } from "./types"; + +export interface PRContext { + repo?: string; + branch?: string; +} + +/** + * Resolves the coding agent state from either the thread state or the shared PR state key. + * When a GitHub PR comment triggers onNewMention, the thread may not have state yet, + * but we can look up the shared key using repo/branch from the PR webhook context. + * + * @param thread - The chat thread + * @param prContext - Optional PR context with repo/branch (from GitHub webhook) + * @returns The thread state (preferred) or shared PR state, or null + */ +export async function resolvePRState( + thread: Thread, + prContext?: PRContext, +): Promise { + const threadState = await thread.state; + if (threadState) return threadState; + + if (prContext?.repo && prContext?.branch) { + const prState = await getCodingAgentPRState(prContext.repo, prContext.branch); + if (prState) { + return { + status: prState.status, + prompt: "", + branch: prState.branch, + snapshotId: prState.snapshotId, + prs: prState.prs, + }; + } + } + + return null; +}