diff --git a/app/api/coding-agent/[platform]/route.ts b/app/api/coding-agent/[platform]/route.ts index 5e95bff8..dd79629b 100644 --- a/app/api/coding-agent/[platform]/route.ts +++ b/app/api/coding-agent/[platform]/route.ts @@ -7,11 +7,11 @@ import "@/lib/coding-agent/handlers/registerHandlers"; * POST /api/coding-agent/[platform] * * Webhook endpoint for the coding agent bot. - * Handles Slack webhooks via dynamic [platform] segment. + * Currently handles Slack webhooks via dynamic [platform] segment. * * @param request - The incoming webhook request * @param params.params - * @param params - Route params containing the platform name (slack or github) + * @param params - Route params containing the platform name */ export async function POST( request: NextRequest, diff --git a/app/api/coding-agent/__tests__/route.test.ts b/app/api/coding-agent/__tests__/route.test.ts index d23eb468..bc34ced2 100644 --- a/app/api/coding-agent/__tests__/route.test.ts +++ b/app/api/coding-agent/__tests__/route.test.ts @@ -13,7 +13,6 @@ vi.mock("@/lib/coding-agent/bot", () => ({ codingAgentBot: { webhooks: { slack: vi.fn().mockResolvedValue(new Response("ok", { status: 200 })), - github: vi.fn().mockResolvedValue(new Response("ok", { status: 200 })), }, }, })); diff --git a/app/api/coding-agent/callback/route.ts b/app/api/coding-agent/callback/route.ts index 46349018..bb029f13 100644 --- a/app/api/coding-agent/callback/route.ts +++ b/app/api/coding-agent/callback/route.ts @@ -1,4 +1,6 @@ import type { NextRequest } from "next/server"; +import redis from "@/lib/redis/connection"; +import "@/lib/coding-agent/bot"; import { handleCodingAgentCallback } from "@/lib/coding-agent/handleCodingAgentCallback"; /** @@ -10,5 +12,15 @@ import { handleCodingAgentCallback } from "@/lib/coding-agent/handleCodingAgentC * @param request - The incoming callback request */ export async function POST(request: NextRequest) { + if (redis.status !== "ready") { + if (redis.status === "wait") { + await redis.connect(); + } else { + await new Promise((resolve, reject) => { + redis.once("ready", resolve); + redis.once("error", reject); + }); + } + } return handleCodingAgentCallback(request); } diff --git a/lib/coding-agent/__tests__/bot.test.ts b/lib/coding-agent/__tests__/bot.test.ts index 3d320f1b..ede2020d 100644 --- a/lib/coding-agent/__tests__/bot.test.ts +++ b/lib/coding-agent/__tests__/bot.test.ts @@ -38,6 +38,7 @@ describe("createCodingAgentBot", () => { vi.clearAllMocks(); process.env.SLACK_BOT_TOKEN = "xoxb-test"; process.env.SLACK_SIGNING_SECRET = "test-signing-secret"; + process.env.GITHUB_TOKEN = "ghp_test"; process.env.REDIS_URL = "redis://localhost:6379"; process.env.CODING_AGENT_CALLBACK_SECRET = "test-callback-secret"; }); diff --git a/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts b/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts index b0173c9e..64e30c7b 100644 --- a/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts +++ b/lib/coding-agent/__tests__/handleCodingAgentCallback.test.ts @@ -6,11 +6,15 @@ vi.mock("@/lib/networking/getCorsHeaders", () => ({ const mockPost = vi.fn(); const mockSetState = vi.fn(); +let mockState: unknown = null; vi.mock("chat", () => { const ThreadImpl = vi.fn().mockImplementation(() => ({ post: mockPost, setState: mockSetState, + get state() { + return Promise.resolve(mockState); + }, })); return { ThreadImpl, @@ -18,6 +22,11 @@ vi.mock("chat", () => { const parts = threadId.split(":"); return `${parts[0]}:${parts[1]}`; }), + Card: vi.fn((opts) => ({ type: "card", ...opts })), + CardText: vi.fn((text) => ({ type: "text", text })), + Actions: vi.fn((children) => ({ type: "actions", children })), + Button: vi.fn((opts) => ({ type: "button", ...opts })), + LinkButton: vi.fn((opts) => ({ type: "link-button", ...opts })), }; }); @@ -29,6 +38,7 @@ const { handleCodingAgentCallback } = await import("../handleCodingAgentCallback beforeEach(() => { vi.clearAllMocks(); + mockState = null; process.env.CODING_AGENT_CALLBACK_SECRET = "test-secret"; }); @@ -89,11 +99,11 @@ describe("handleCodingAgentCallback", () => { const response = await handleCodingAgentCallback(request); expect(response.status).toBe(200); - expect(mockPost).toHaveBeenCalled(); + expect(mockPost).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); expect(mockSetState).toHaveBeenCalledWith(expect.objectContaining({ status: "pr_created" })); }); - it("posts no-changes message for no_changes status", async () => { + it("posts no-changes message and resets state for no_changes status", async () => { const body = { threadId: "slack:C123:1234567890.123456", status: "no_changes", @@ -104,10 +114,11 @@ describe("handleCodingAgentCallback", () => { const response = await handleCodingAgentCallback(request); expect(response.status).toBe(200); + expect(mockSetState).toHaveBeenCalledWith(expect.objectContaining({ status: "no_changes" })); expect(mockPost).toHaveBeenCalledWith(expect.stringContaining("No changes")); }); - it("posts error message for failed status", async () => { + it("posts error message and resets state for failed status", async () => { const body = { threadId: "slack:C123:1234567890.123456", status: "failed", @@ -118,7 +129,27 @@ describe("handleCodingAgentCallback", () => { const response = await handleCodingAgentCallback(request); expect(response.status).toBe(200); + expect(mockSetState).toHaveBeenCalledWith(expect.objectContaining({ status: "failed" })); expect(mockPost).toHaveBeenCalledWith(expect.stringContaining("Sandbox timed out")); }); + it("posts updated card with PR buttons for updated status", async () => { + mockState = { + status: "updating", + prs: [{ repo: "recoupable/api", number: 42, url: "https://github.com/recoupable/api/pull/42", baseBranch: "test" }], + }; + + const body = { + threadId: "slack:C123:1234567890.123456", + status: "updated", + snapshotId: "snap_new", + }; + const request = makeRequest(body); + + const response = await handleCodingAgentCallback(request); + + expect(response.status).toBe(200); + expect(mockSetState).toHaveBeenCalledWith(expect.objectContaining({ status: "pr_created", snapshotId: "snap_new" })); + expect(mockPost).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); + }); }); diff --git a/lib/coding-agent/__tests__/handlePRCreated.test.ts b/lib/coding-agent/__tests__/handlePRCreated.test.ts index 8208c8bb..6f80972a 100644 --- a/lib/coding-agent/__tests__/handlePRCreated.test.ts +++ b/lib/coding-agent/__tests__/handlePRCreated.test.ts @@ -9,8 +9,16 @@ vi.mock("../getThread", () => ({ getThread: vi.fn(() => mockThread), })); +vi.mock("chat", () => ({ + Card: vi.fn((opts) => ({ type: "card", ...opts })), + CardText: vi.fn((text) => ({ type: "text", text })), + Actions: vi.fn((children) => ({ type: "actions", children })), + Button: vi.fn((opts) => ({ type: "button", ...opts })), + LinkButton: vi.fn((opts) => ({ type: "link-button", ...opts })), +})); + describe("handlePRCreated", () => { - it("posts PR links and updates thread state", async () => { + it("posts a card with PR links and merge button", async () => { const { handlePRCreated } = await import("../handlePRCreated"); await handlePRCreated("slack:C123:ts", { @@ -21,7 +29,13 @@ describe("handlePRCreated", () => { prs: [{ repo: "recoupable/api", number: 42, url: "https://github.com/recoupable/api/pull/42", baseBranch: "test" }], }); - expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("recoupable/api#42")); + expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); + + const { Button } = await import("chat"); + expect(Button).toHaveBeenCalledWith( + expect.objectContaining({ id: "merge_all_prs", label: "Merge All PRs" }), + ); + expect(mockThread.setState).toHaveBeenCalledWith( expect.objectContaining({ status: "pr_created", diff --git a/lib/coding-agent/__tests__/handlers.test.ts b/lib/coding-agent/__tests__/handlers.test.ts index e750d149..5cd30f50 100644 --- a/lib/coding-agent/__tests__/handlers.test.ts +++ b/lib/coding-agent/__tests__/handlers.test.ts @@ -1,7 +1,21 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; +const mockTriggerCodingAgent = vi.fn().mockResolvedValue({ id: "run_123" }); +const mockTriggerUpdatePR = vi.fn().mockResolvedValue({ id: "run_456" }); + vi.mock("@/lib/trigger/triggerCodingAgent", () => ({ - triggerCodingAgent: vi.fn().mockResolvedValue({ id: "run_123" }), + triggerCodingAgent: mockTriggerCodingAgent, +})); + +vi.mock("@/lib/trigger/triggerUpdatePR", () => ({ + triggerUpdatePR: mockTriggerUpdatePR, +})); + +vi.mock("chat", () => ({ + Card: vi.fn((opts) => ({ type: "card", ...opts })), + CardText: vi.fn((text) => ({ type: "text", text })), + Actions: vi.fn((children) => ({ type: "actions", children })), + LinkButton: vi.fn((opts) => ({ type: "link-button", ...opts })), })); const { registerOnNewMention } = await import("../handlers/onNewMention"); @@ -26,13 +40,14 @@ describe("registerOnNewMention", () => { expect(bot.onNewMention).toHaveBeenCalledOnce(); }); - it("posts acknowledgment and triggers coding agent task", async () => { + it("posts acknowledgment and triggers coding agent task when no existing state", async () => { const bot = createMockBot(); registerOnNewMention(bot); const handler = bot.onNewMention.mock.calls[0][0]; const mockThread = { id: "slack:C123:1234567890.123456", + state: Promise.resolve(null), subscribe: vi.fn(), post: vi.fn(), setState: vi.fn(), @@ -45,7 +60,8 @@ describe("registerOnNewMention", () => { await handler(mockThread, mockMessage); expect(mockThread.subscribe).toHaveBeenCalledOnce(); - expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("Starting work")); + expect(mockTriggerCodingAgent).toHaveBeenCalled(); + expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); expect(mockThread.setState).toHaveBeenCalledWith( expect.objectContaining({ status: "running", @@ -53,4 +69,63 @@ describe("registerOnNewMention", () => { }), ); }); + + it("triggers update-pr instead of coding-agent when thread has pr_created state", async () => { + const bot = createMockBot(); + registerOnNewMention(bot); + const handler = bot.onNewMention.mock.calls[0][0]; + + const mockThread = { + id: "github:recoupable/tasks:56", + state: Promise.resolve({ + status: "pr_created", + prompt: "original prompt", + snapshotId: "snap_abc", + branch: "agent/fix-bug", + prs: [{ repo: "recoupable/tasks", number: 56, url: "https://github.com/recoupable/tasks/pull/56", baseBranch: "main" }], + }), + subscribe: vi.fn(), + post: vi.fn(), + setState: vi.fn(), + }; + const mockMessage = { + text: "remove the Project Structure changes", + author: { id: "sweetmantech" }, + }; + + await handler(mockThread, mockMessage); + + expect(mockTriggerCodingAgent).not.toHaveBeenCalled(); + expect(mockTriggerUpdatePR).toHaveBeenCalledWith( + expect.objectContaining({ + feedback: "remove the Project Structure changes", + snapshotId: "snap_abc", + branch: "agent/fix-bug", + repo: "recoupable/tasks", + }), + ); + expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); + expect(mockThread.setState).toHaveBeenCalledWith(expect.objectContaining({ status: "updating" })); + }); + + it("tells user to wait when thread is already running", async () => { + const bot = createMockBot(); + registerOnNewMention(bot); + const handler = bot.onNewMention.mock.calls[0][0]; + + const mockThread = { + id: "github:recoupable/tasks:56", + state: Promise.resolve({ status: "running", prompt: "original" }), + subscribe: vi.fn(), + post: vi.fn(), + setState: vi.fn(), + }; + const mockMessage = { text: "any update?", author: { id: "sweetmantech" } }; + + await handler(mockThread, mockMessage); + + expect(mockTriggerCodingAgent).not.toHaveBeenCalled(); + expect(mockTriggerUpdatePR).not.toHaveBeenCalled(); + expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("still working")); + }); }); diff --git a/lib/coding-agent/__tests__/onMergeAction.test.ts b/lib/coding-agent/__tests__/onMergeAction.test.ts new file mode 100644 index 00000000..2b249cf5 --- /dev/null +++ b/lib/coding-agent/__tests__/onMergeAction.test.ts @@ -0,0 +1,68 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +global.fetch = vi.fn(); + +const { registerOnMergeAction } = await import("../handlers/onMergeAction"); + +beforeEach(() => { + vi.clearAllMocks(); + process.env.GITHUB_TOKEN = "ghp_test"; +}); + +function createMockBot() { + return { + onAction: vi.fn(), + } as any; +} + +describe("registerOnMergeAction", () => { + it("registers merge_all_prs action handler", () => { + const bot = createMockBot(); + registerOnMergeAction(bot); + expect(bot.onAction).toHaveBeenCalledWith("merge_all_prs", expect.any(Function)); + }); + + it("squash-merges PRs and posts results", async () => { + vi.mocked(fetch).mockResolvedValue({ ok: true } as Response); + + const bot = createMockBot(); + registerOnMergeAction(bot); + const handler = bot.onAction.mock.calls[0][1]; + + const mockThread = { + state: Promise.resolve({ + status: "pr_created", + prompt: "fix bug", + prs: [{ repo: "recoupable/api", number: 42, url: "url", baseBranch: "test" }], + }), + post: vi.fn(), + setState: vi.fn(), + }; + + await handler({ thread: mockThread }); + + expect(fetch).toHaveBeenCalledWith( + "https://api.github.com/repos/recoupable/api/pulls/42/merge", + expect.objectContaining({ method: "PUT" }), + ); + expect(mockThread.setState).toHaveBeenCalledWith({ status: "merged" }); + expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("merged")); + }); + + it("posts no PRs message when state has no PRs", async () => { + const bot = createMockBot(); + registerOnMergeAction(bot); + const handler = bot.onAction.mock.calls[0][1]; + + const mockThread = { + state: Promise.resolve({ status: "pr_created", prompt: "fix bug" }), + post: vi.fn(), + setState: vi.fn(), + }; + + await handler({ thread: mockThread }); + + expect(mockThread.post).toHaveBeenCalledWith("No PRs to merge."); + expect(fetch).not.toHaveBeenCalled(); + }); +}); diff --git a/lib/coding-agent/__tests__/onSubscribedMessage.test.ts b/lib/coding-agent/__tests__/onSubscribedMessage.test.ts index f4647420..8ba6b565 100644 --- a/lib/coding-agent/__tests__/onSubscribedMessage.test.ts +++ b/lib/coding-agent/__tests__/onSubscribedMessage.test.ts @@ -45,7 +45,7 @@ describe("registerOnSubscribedMessage", () => { await handler(mockThread, { text: "make the button blue", author: { userId: "U111" } }); - expect(mockThread.post).toHaveBeenCalledWith(expect.stringContaining("feedback")); + expect(mockThread.post).toHaveBeenCalledWith(expect.objectContaining({ card: expect.anything() })); expect(mockThread.setState).toHaveBeenCalledWith(expect.objectContaining({ status: "updating" })); expect(triggerUpdatePR).toHaveBeenCalledWith( expect.objectContaining({ diff --git a/lib/coding-agent/__tests__/validateCodingAgentCallback.test.ts b/lib/coding-agent/__tests__/validateCodingAgentCallback.test.ts index eccff62e..525aa4ff 100644 --- a/lib/coding-agent/__tests__/validateCodingAgentCallback.test.ts +++ b/lib/coding-agent/__tests__/validateCodingAgentCallback.test.ts @@ -44,7 +44,16 @@ describe("validateCodingAgentCallback", () => { expect(result).not.toBeInstanceOf(NextResponse); }); -}); + it("accepts updated status with new snapshotId", () => { + const body = { + threadId: "slack:C123:1234567890.123456", + status: "updated", + snapshotId: "snap_new456", + }; + const result = validateCodingAgentCallback(body); + expect(result).not.toBeInstanceOf(NextResponse); + }); + }); describe("invalid payloads", () => { it("rejects missing threadId", () => { diff --git a/lib/coding-agent/__tests__/validateEnv.test.ts b/lib/coding-agent/__tests__/validateEnv.test.ts index 27391354..37278c74 100644 --- a/lib/coding-agent/__tests__/validateEnv.test.ts +++ b/lib/coding-agent/__tests__/validateEnv.test.ts @@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; const REQUIRED_VARS = [ "SLACK_BOT_TOKEN", "SLACK_SIGNING_SECRET", + "GITHUB_TOKEN", "REDIS_URL", "CODING_AGENT_CALLBACK_SECRET", ]; diff --git a/lib/coding-agent/buildPRCard.ts b/lib/coding-agent/buildPRCard.ts new file mode 100644 index 00000000..9ca6343d --- /dev/null +++ b/lib/coding-agent/buildPRCard.ts @@ -0,0 +1,23 @@ +import { Card, CardText, Actions, Button, LinkButton } from "chat"; +import type { CodingAgentPR } from "./types"; + +/** + * Builds a Card with PR review links and a Merge All PRs button. + * + * @param title - Card title (e.g. "PRs Created", "PRs Updated") + * @param prs - Array of PRs to build review links for + */ +export function buildPRCard(title: string, prs: CodingAgentPR[]) { + return Card({ + title, + children: [ + CardText(`${prs.map(pr => `- ${pr.repo}#${pr.number} → \`${pr.baseBranch}\``).join("\n")}\n\nReply in this thread to give feedback.`), + Actions([ + ...prs.map(pr => + LinkButton({ url: pr.url, label: `Review ${pr.repo}#${pr.number}` }), + ), + Button({ id: "merge_all_prs", label: "Merge All PRs", style: "primary" }), + ]), + ], + }); +} diff --git a/lib/coding-agent/buildTaskCard.ts b/lib/coding-agent/buildTaskCard.ts new file mode 100644 index 00000000..f0b16548 --- /dev/null +++ b/lib/coding-agent/buildTaskCard.ts @@ -0,0 +1,20 @@ +import { Card, CardText, Actions, LinkButton } from "chat"; + +/** + * Builds a Card with a message and a View Task button. + * + * @param title - Card title (e.g. "Task Started", "Updating PRs") + * @param message - Body text + * @param runId - Trigger.dev run ID for the View Task link + */ +export function buildTaskCard(title: string, message: string, runId: string) { + return Card({ + title, + children: [ + CardText(message), + Actions([ + LinkButton({ url: `https://chat.recoupable.com/tasks/${runId}`, label: "View Task" }), + ]), + ], + }); +} diff --git a/lib/coding-agent/handleCodingAgentCallback.ts b/lib/coding-agent/handleCodingAgentCallback.ts index 3796780d..5397419a 100644 --- a/lib/coding-agent/handleCodingAgentCallback.ts +++ b/lib/coding-agent/handleCodingAgentCallback.ts @@ -3,6 +3,8 @@ import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; import { validateCodingAgentCallback } from "./validateCodingAgentCallback"; import { getThread } from "./getThread"; import { handlePRCreated } from "./handlePRCreated"; +import { buildPRCard } from "./buildPRCard"; +import type { CodingAgentThreadState } from "./types"; /** * Handles coding agent task callback from Trigger.dev. @@ -46,12 +48,23 @@ export async function handleCodingAgentCallback(request: Request): Promise `- [${pr.repo}#${pr.number}](${pr.url}) → \`${pr.baseBranch}\``) - .join("\n"); + const card = buildPRCard("PRs Created", body.prs ?? []); - await thread.post( - `PRs created:\n${prLinks}\n\nReply in this thread to give feedback, or click Merge when ready.`, - ); + await thread.post({ card }); await thread.setState({ status: "pr_created", diff --git a/lib/coding-agent/handlers/handleFeedback.ts b/lib/coding-agent/handlers/handleFeedback.ts new file mode 100644 index 00000000..4c1c554d --- /dev/null +++ b/lib/coding-agent/handlers/handleFeedback.ts @@ -0,0 +1,40 @@ +import type { Thread } from "chat"; +import { buildTaskCard } from "../buildTaskCard"; +import { triggerUpdatePR } from "@/lib/trigger/triggerUpdatePR"; +import type { CodingAgentThreadState } from "../types"; + +/** + * Handles a message in a thread that already has state. + * Returns true if the message was handled (busy or feedback), false otherwise. + * + * @param thread - The chat thread + * @param messageText - The user's message text + * @param state - The current thread state + */ +export async function handleFeedback( + thread: Thread, + messageText: string, + state: CodingAgentThreadState | null, +): Promise { + if (state?.status === "running" || state?.status === "updating") { + await thread.post("I'm still working on this. I'll let you know when I'm done."); + return true; + } + + if (state?.status === "pr_created" && state.snapshotId && state.branch && state.prs?.length) { + await thread.setState({ status: "updating" }); + const handle = await triggerUpdatePR({ + feedback: messageText, + snapshotId: state.snapshotId, + branch: state.branch, + repo: state.prs[0].repo, + callbackThreadId: thread.id, + }); + + const card = buildTaskCard("Updating PRs", "Got your feedback. Updating the PRs...", handle.id); + await thread.post({ card }); + return true; + } + + return false; +} diff --git a/lib/coding-agent/handlers/onMergeAction.ts b/lib/coding-agent/handlers/onMergeAction.ts new file mode 100644 index 00000000..5e883db4 --- /dev/null +++ b/lib/coding-agent/handlers/onMergeAction.ts @@ -0,0 +1,54 @@ +import type { CodingAgentBot } from "../bot"; +import type { CodingAgentThreadState } from "../types"; + +/** + * Registers the "Merge All PRs" button action handler on the bot. + * Squash-merges each PR via the GitHub API. + * + * @param bot + */ +export function registerOnMergeAction(bot: CodingAgentBot) { + bot.onAction("merge_all_prs", async event => { + const thread = event.thread; + const state = (await thread.state) as CodingAgentThreadState | null; + + if (!state?.prs?.length) { + await thread.post("No PRs to merge."); + return; + } + + const token = process.env.GITHUB_TOKEN; + if (!token) { + await thread.post("Missing GITHUB_TOKEN — cannot merge PRs."); + return; + } + + const results: string[] = []; + + for (const pr of state.prs) { + const [owner, repo] = pr.repo.split("/"); + const response = await fetch( + `https://api.github.com/repos/${owner}/${repo}/pulls/${pr.number}/merge`, + { + method: "PUT", + headers: { + Authorization: `Bearer ${token}`, + Accept: "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + }, + body: JSON.stringify({ merge_method: "squash" }), + }, + ); + + if (response.ok) { + results.push(`${pr.repo}#${pr.number} merged`); + } else { + const error = await response.json(); + results.push(`${pr.repo}#${pr.number} failed: ${error.message}`); + } + } + + await thread.setState({ status: "merged" }); + await thread.post(`Merge results:\n${results.map(r => `- ${r}`).join("\n")}`); + }); +} diff --git a/lib/coding-agent/handlers/onNewMention.ts b/lib/coding-agent/handlers/onNewMention.ts index 20ecd15e..ea020146 100644 --- a/lib/coding-agent/handlers/onNewMention.ts +++ b/lib/coding-agent/handlers/onNewMention.ts @@ -1,26 +1,33 @@ import type { CodingAgentBot } from "../bot"; +import { buildTaskCard } from "../buildTaskCard"; import { triggerCodingAgent } from "@/lib/trigger/triggerCodingAgent"; +import { handleFeedback } from "./handleFeedback"; /** * Registers the onNewMention handler on the bot. - * Subscribes to the thread and triggers the coding agent Trigger.dev task. + * If the thread already has PRs, treats the mention as feedback and + * triggers the update-pr task. Otherwise, starts a new coding agent task. * * @param bot */ export function registerOnNewMention(bot: CodingAgentBot) { bot.onNewMention(async (thread, message) => { - const prompt = message.text; - try { - await thread.subscribe(); + const state = await thread.state; - await thread.post(`Starting work on: "${prompt}"\n\nI'll reply here when done.`); + if (await handleFeedback(thread, message.text, state)) return; + + const prompt = message.text; + await thread.subscribe(); const handle = await triggerCodingAgent({ prompt, callbackThreadId: thread.id, }); + const card = buildTaskCard("Task Started", `Starting work on: "${prompt}"\n\nI'll reply here when done.`, handle.id); + await thread.post({ card }); + await thread.setState({ status: "running", prompt, diff --git a/lib/coding-agent/handlers/onSubscribedMessage.ts b/lib/coding-agent/handlers/onSubscribedMessage.ts index 1846c541..7b769706 100644 --- a/lib/coding-agent/handlers/onSubscribedMessage.ts +++ b/lib/coding-agent/handlers/onSubscribedMessage.ts @@ -1,37 +1,16 @@ import type { CodingAgentBot } from "../bot"; -import { triggerUpdatePR } from "@/lib/trigger/triggerUpdatePR"; +import { handleFeedback } from "./handleFeedback"; /** * Registers the onSubscribedMessage handler on the bot. - * If the agent has created PRs, treats the message as feedback and - * triggers the update-pr task. If the agent is currently working, - * tells the user to wait. + * Delegates to handleFeedback for busy/update-pr logic. * * @param bot */ export function registerOnSubscribedMessage(bot: CodingAgentBot) { bot.onSubscribedMessage(async (thread, message) => { const state = await thread.state; - if (!state) return; - - if (state.status === "running" || state.status === "updating") { - await thread.post("I'm still working on this. I'll let you know when I'm done."); - return; - } - - if (state.status === "pr_created" && state.snapshotId && state.branch && state.prs?.length) { - await thread.post("Got your feedback. Updating the PRs..."); - - await thread.setState({ status: "updating" }); - - await triggerUpdatePR({ - feedback: message.text, - snapshotId: state.snapshotId, - branch: state.branch, - repo: state.prs[0].repo, - callbackThreadId: thread.id, - }); - } + await handleFeedback(thread, message.text, state); }); } diff --git a/lib/coding-agent/handlers/registerHandlers.ts b/lib/coding-agent/handlers/registerHandlers.ts index 788e4e46..96f24748 100644 --- a/lib/coding-agent/handlers/registerHandlers.ts +++ b/lib/coding-agent/handlers/registerHandlers.ts @@ -1,6 +1,7 @@ import { codingAgentBot } from "../bot"; import { registerOnNewMention } from "./onNewMention"; import { registerOnSubscribedMessage } from "./onSubscribedMessage"; +import { registerOnMergeAction } from "./onMergeAction"; /** * Registers all coding agent event handlers on the bot singleton. @@ -8,3 +9,4 @@ import { registerOnSubscribedMessage } from "./onSubscribedMessage"; */ registerOnNewMention(codingAgentBot); registerOnSubscribedMessage(codingAgentBot); +registerOnMergeAction(codingAgentBot); diff --git a/lib/coding-agent/types.ts b/lib/coding-agent/types.ts index 28a583fe..8092e66b 100644 --- a/lib/coding-agent/types.ts +++ b/lib/coding-agent/types.ts @@ -3,7 +3,7 @@ * Stored in Redis via Chat SDK's state adapter. */ export interface CodingAgentThreadState { - status: "running" | "pr_created" | "updating" | "failed"; + status: "running" | "pr_created" | "updating" | "merged" | "failed" | "no_changes"; prompt: string; runId?: string; slackThreadId?: string; diff --git a/lib/coding-agent/validateCodingAgentCallback.ts b/lib/coding-agent/validateCodingAgentCallback.ts index eae9097e..57989f04 100644 --- a/lib/coding-agent/validateCodingAgentCallback.ts +++ b/lib/coding-agent/validateCodingAgentCallback.ts @@ -11,7 +11,7 @@ const codingAgentPRSchema = z.object({ export const codingAgentCallbackSchema = z.object({ threadId: z.string({ message: "threadId is required" }).min(1, "threadId cannot be empty"), - status: z.enum(["pr_created", "no_changes", "failed"]), + status: z.enum(["pr_created", "no_changes", "failed", "updated"]), branch: z.string().optional(), snapshotId: z.string().optional(), prs: z.array(codingAgentPRSchema).optional(), diff --git a/lib/coding-agent/validateEnv.ts b/lib/coding-agent/validateEnv.ts index 9e9a5af6..51e0a36c 100644 --- a/lib/coding-agent/validateEnv.ts +++ b/lib/coding-agent/validateEnv.ts @@ -1,6 +1,7 @@ const REQUIRED_ENV_VARS = [ "SLACK_BOT_TOKEN", "SLACK_SIGNING_SECRET", + "GITHUB_TOKEN", "REDIS_URL", "CODING_AGENT_CALLBACK_SECRET", ] as const;