diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts index 4b9b840a39..6078e38d92 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts @@ -70,6 +70,7 @@ import { } from "./vars"; import { workflowAccessActor, + workflowBatchJoinActor, workflowCounterActor, workflowQueueActor, workflowSleepActor, @@ -166,6 +167,7 @@ export const registry = setup({ workflowAccessActor, workflowSleepActor, workflowStopTeardownActor, + workflowBatchJoinActor, // From actor-db-raw.ts dbActorRaw, // From actor-db-drizzle.ts diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts index 28b6637d5f..cf8a61a573 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts @@ -226,10 +226,64 @@ async function updateWorkflowAccessState( typeof client.workflowQueueActor.getForId === "function"; } +const WORKFLOW_BATCH_QUEUE_NAME = "batch-requests"; + +export const workflowBatchJoinActor = actor({ + state: { + processedRows: [] as number[], + processedCells: [] as string[], + requestsCompleted: 0, + }, + queues: { + [WORKFLOW_BATCH_QUEUE_NAME]: queue<{ rowIds: number[] }>(), + }, + run: workflow(async (ctx) => { + await ctx.loop("request-loop", async (loopCtx) => { + const request = await loopCtx.queue.next("wait-request", { + names: [WORKFLOW_BATCH_QUEUE_NAME], + }); + + const rowIds = request.body.rowIds; + + // Fan out all rows in a single join. + const branches = Object.fromEntries( + rowIds.map((rowId, i) => [ + `row-${i}`, + { + run: async (branchCtx: WorkflowLoopContextOf) => { + await branchCtx.step(`cell-a-${rowId}`, async () => { + branchCtx.state.processedCells.push(`a-${rowId}`); + }); + await branchCtx.step(`cell-b-${rowId}`, async () => { + branchCtx.state.processedCells.push(`b-${rowId}`); + }); + await branchCtx.step(`cell-c-${rowId}`, async () => { + branchCtx.state.processedRows.push(rowId); + branchCtx.state.processedCells.push(`c-${rowId}`); + }); + }, + }, + ]), + ); + + await loopCtx.join("process-rows", branches); + + await loopCtx.step("request-done", async () => { + loopCtx.state.requestsCompleted += 1; + }); + + return Loop.continue(undefined); + }); + }), + actions: { + getState: (c) => c.state, + }, +}); + function incrementWorkflowSleepTick( ctx: WorkflowLoopContextOf, ): void { ctx.state.ticks += 1; } -export { WORKFLOW_QUEUE_NAME }; +export { WORKFLOW_QUEUE_NAME, WORKFLOW_BATCH_QUEUE_NAME }; diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts index 419899e9c8..e1c9d9a68c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts @@ -1,5 +1,6 @@ import { describe, expect, test } from "vitest"; import { + WORKFLOW_BATCH_QUEUE_NAME, WORKFLOW_QUEUE_NAME, } from "../../../fixtures/driver-test-suite/workflow"; import type { DriverTestConfig } from "../mod"; @@ -110,6 +111,54 @@ export function runActorWorkflowTests(driverTestConfig: DriverTestConfig) { }, ); + test("join fans out rows in parallel inside loop", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + const actor = client.workflowBatchJoinActor.getOrCreate([ + "workflow-batch-join", + ]); + + await actor.send(WORKFLOW_BATCH_QUEUE_NAME, { + rowIds: [1, 2, 3, 4], + }); + + let state = await actor.getState(); + for (let i = 0; i < 50; i++) { + if (state.requestsCompleted >= 1) break; + await waitFor(driverTestConfig, 100); + state = await actor.getState(); + } + + expect(state.requestsCompleted).toBe(1); + expect(state.processedRows.sort()).toEqual([1, 2, 3, 4]); + expect(state.processedCells.length).toBe(12); + }); + + test("join handles sequential queue requests", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + const actor = client.workflowBatchJoinActor.getOrCreate([ + "workflow-batch-join-sequential", + ]); + + await actor.send(WORKFLOW_BATCH_QUEUE_NAME, { + rowIds: [1, 2], + }); + await actor.send(WORKFLOW_BATCH_QUEUE_NAME, { + rowIds: [3, 4], + }); + + let state = await actor.getState(); + for (let i = 0; i < 50; i++) { + if (state.requestsCompleted >= 2) break; + await waitFor(driverTestConfig, 100); + state = await actor.getState(); + } + + expect(state.requestsCompleted).toBeGreaterThanOrEqual(2); + expect(state.processedRows).toEqual( + expect.arrayContaining([1, 2, 3, 4]), + ); + }); + // NOTE: Test for workflow persistence across actor sleep is complex because // calling c.sleep() during a workflow prevents clean shutdown. The workflow // persistence is implicitly tested by the "sleeps and resumes between ticks" diff --git a/rivetkit-typescript/packages/workflow-engine/src/context.ts b/rivetkit-typescript/packages/workflow-engine/src/context.ts index 3cd294593a..aca1cbadd1 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/context.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/context.ts @@ -209,6 +209,31 @@ export class WorkflowContextImpl implements WorkflowContextInterface { this.visitedKeys.add(key); } + /** + * Merge visited keys from a child branch context into this context. + * This ensures that entries validated by nested branches are also + * recognized as visited by the parent scope's validateComplete. + */ + mergeVisitedKeys(child: WorkflowContextImpl): void { + for (const key of child.visitedKeys) { + this.visitedKeys.add(key); + } + } + + /** + * Mark all history entries under a location prefix as visited. + * Used when replaying completed branches that are skipped during + * re-execution so their child entries don't trigger validateComplete errors. + */ + private markAllEntriesVisited(location: Location): void { + const prefix = locationToKey(this.storage, location); + for (const key of this.storage.history.entries.keys()) { + if (key.startsWith(prefix + "/") || key === prefix) { + this.visitedKeys.add(key); + } + } + } + /** * Check if a name has already been used at the current location in this execution. * Throws HistoryDivergedError if duplicate detected. @@ -737,6 +762,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface { // Validate branch completed cleanly branchCtx.validateComplete(); + this.mergeVisitedKeys(branchCtx); if ("break" in result && result.break) { // Loop complete @@ -1490,24 +1516,27 @@ export class WorkflowContextImpl implements WorkflowContextInterface { async ([branchName, config]) => { const branchStatus = joinData.branches[branchName]; + const branchLocation = appendName( + this.storage, + location, + branchName, + ); + // Already completed if (branchStatus.status === "completed") { + this.markAllEntriesVisited(branchLocation); results[branchName] = branchStatus.output; return; } // Already failed if (branchStatus.status === "failed") { + this.markAllEntriesVisited(branchLocation); errors[branchName] = new Error(branchStatus.error); return; } // Execute branch - const branchLocation = appendName( - this.storage, - location, - branchName, - ); const branchCtx = this.createBranch(branchLocation); branchStatus.status = "running"; @@ -1516,6 +1545,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface { try { const output = await config.run(branchCtx); branchCtx.validateComplete(); + this.mergeVisitedKeys(branchCtx); branchStatus.status = "completed"; branchStatus.output = output; @@ -1705,6 +1735,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface { winnerValue = output; branchCtx.validateComplete(); + this.mergeVisitedKeys(branchCtx); branchStatus.status = "completed"; branchStatus.output = output;