diff --git a/examples/sandbox-vercel/src/actors/workflow/workflow-fixtures.ts b/examples/sandbox-vercel/src/actors/workflow/workflow-fixtures.ts index 1ce30c369c..8d7e06e951 100644 --- a/examples/sandbox-vercel/src/actors/workflow/workflow-fixtures.ts +++ b/examples/sandbox-vercel/src/actors/workflow/workflow-fixtures.ts @@ -55,12 +55,13 @@ export const workflowQueueActor = actor({ name: "queue", run: async (loopCtx) => { const actorLoopCtx = loopCtx as any; - const payload = await loopCtx.listen( + const message = await loopCtx.listen( "queue-wait", WORKFLOW_QUEUE_NAME, ); await loopCtx.step("store-message", async () => { - actorLoopCtx.state.received.push(payload); + actorLoopCtx.state.received.push(message.body); + await message.complete({ echo: message.body }); }); return Loop.continue(undefined); }, diff --git a/examples/sandbox/src/actors/workflow/workflow-fixtures.ts b/examples/sandbox/src/actors/workflow/workflow-fixtures.ts index 1ce30c369c..8d7e06e951 100644 --- a/examples/sandbox/src/actors/workflow/workflow-fixtures.ts +++ b/examples/sandbox/src/actors/workflow/workflow-fixtures.ts @@ -55,12 +55,13 @@ export const workflowQueueActor = actor({ name: "queue", run: async (loopCtx) => { const actorLoopCtx = loopCtx as any; - const payload = await loopCtx.listen( + const message = await loopCtx.listen( "queue-wait", WORKFLOW_QUEUE_NAME, ); await loopCtx.step("store-message", async () => { - actorLoopCtx.state.received.push(payload); + actorLoopCtx.state.received.push(message.body); + await message.complete({ echo: message.body }); }); return Loop.continue(undefined); }, diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts index 4dc84f68aa..38b5f89132 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/registry.ts @@ -63,6 +63,7 @@ import { uniqueVarActor, } from "./vars"; import { + workflowAccessActor, workflowCounterActor, workflowQueueActor, workflowSleepActor, @@ -151,6 +152,7 @@ export const registry = setup({ // From workflow.ts workflowCounterActor, workflowQueueActor, + workflowAccessActor, workflowSleepActor, // From actor-db-raw.ts dbActorRaw, diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts index ea027d3bd9..3038124b4b 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/workflow.ts @@ -1,7 +1,9 @@ import { Loop } from "@rivetkit/workflow-engine"; import { actor } from "@/actor/mod"; +import { db } from "@/db/mod"; import { WORKFLOW_GUARD_KV_KEY } from "@/workflow/constants"; import { workflow, workflowQueueName } from "@/workflow/mod"; +import type { registry } from "./registry"; const WORKFLOW_QUEUE_NAME = "workflow-default"; @@ -55,12 +57,13 @@ export const workflowQueueActor = actor({ name: "queue", run: async (loopCtx) => { const actorLoopCtx = loopCtx as any; - const payload = await loopCtx.listen( + const message = await loopCtx.listen( "queue-wait", WORKFLOW_QUEUE_NAME, ); await loopCtx.step("store-message", async () => { - actorLoopCtx.state.received.push(payload); + actorLoopCtx.state.received.push(message.body); + await message.complete({ echo: message.body }); }); return Loop.continue(undefined); }, @@ -68,6 +71,81 @@ export const workflowQueueActor = actor({ }), actions: { getMessages: (c) => c.state.received, + sendAndWait: async (c, payload: unknown) => { + const client = c.client(); + const handle = client.workflowQueueActor.getForId(c.actorId); + return await handle.queue[workflowQueueName(WORKFLOW_QUEUE_NAME)].send( + payload, + { wait: true, timeout: 1_000 }, + ); + }, + }, +}); + +export const workflowAccessActor = actor({ + db: db({ + onMigrate: async (rawDb) => { + await rawDb.execute(` + CREATE TABLE IF NOT EXISTS workflow_access_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at INTEGER NOT NULL + ) + `); + }, + }), + state: { + outsideDbError: null as string | null, + outsideClientError: null as string | null, + insideDbCount: 0, + insideClientAvailable: false, + }, + run: workflow(async (ctx) => { + await ctx.loop({ + name: "access", + run: async (loopCtx) => { + const actorLoopCtx = loopCtx as any; + let outsideDbError: string | null = null; + let outsideClientError: string | null = null; + + try { + // Accessing db outside a step should throw. + // biome-ignore lint/style/noUnusedExpressions: intentionally checking accessor. + actorLoopCtx.db; + } catch (error) { + outsideDbError = + error instanceof Error ? error.message : String(error); + } + + try { + actorLoopCtx.client(); + } catch (error) { + outsideClientError = + error instanceof Error ? error.message : String(error); + } + + await loopCtx.step("access-step", async () => { + await actorLoopCtx.db.execute( + `INSERT INTO workflow_access_log (created_at) VALUES (${Date.now()})`, + ); + const counts = (await actorLoopCtx.db.execute( + `SELECT COUNT(*) as count FROM workflow_access_log`, + )) as Array<{ count: number }>; + const client = actorLoopCtx.client(); + + actorLoopCtx.state.outsideDbError = outsideDbError; + actorLoopCtx.state.outsideClientError = outsideClientError; + actorLoopCtx.state.insideDbCount = counts[0]?.count ?? 0; + actorLoopCtx.state.insideClientAvailable = + typeof client.workflowQueueActor.getForId === "function"; + }); + + await loopCtx.sleep("idle", 25); + return Loop.continue(undefined); + }, + }); + }), + actions: { + getState: (c) => c.state, }, }); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts index 1fbaf5feb8..73847f8115 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts @@ -419,7 +419,10 @@ export class QueueManager { } /** Deletes messages matching the provided IDs. Returns the IDs that were removed. */ - async deleteMessagesById(ids: bigint[]): Promise { + async deleteMessagesById( + ids: bigint[], + options: { resolveWaiters?: boolean } = {}, + ): Promise { if (ids.length === 0) { return []; } @@ -431,10 +434,20 @@ export class QueueManager { if (toRemove.length === 0) { return []; } - await this.#removeMessages(toRemove, { resolveWaiters: true }); + await this.#removeMessages(toRemove, { + resolveWaiters: options.resolveWaiters ?? true, + }); return toRemove.map((entry) => entry.id); } + /** Completes a previously removed message by resolving its waiter, if one exists. */ + async completeById(messageId: bigint, response?: unknown): Promise { + this.#resolveCompletionWaiter(messageId, { + status: "completed", + response, + }); + } + async #drainMessages( nameSet: Set, count: number, diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts index cbce9c0c52..a572b1db11 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-workflow.ts @@ -23,12 +23,9 @@ export function runActorWorkflowTests(driverTestConfig: DriverTestConfig) { test("consumes queue messages via workflow listen", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); - const actor = client.workflowQueueActor.getOrCreate([ - "workflow-queue", - ]); + const actor = client.workflowQueueActor.getOrCreate(["workflow-queue"]); - const queueHandle = - actor.queue[workflowQueueName(WORKFLOW_QUEUE_NAME)]; + const queueHandle = actor.queue[workflowQueueName(WORKFLOW_QUEUE_NAME)]; await queueHandle.send({ hello: "world" }); await waitFor(driverTestConfig, 200); @@ -36,12 +33,45 @@ export function runActorWorkflowTests(driverTestConfig: DriverTestConfig) { expect(messages).toEqual([{ hello: "world" }]); }); - test("sleeps and resumes between ticks", async (c) => { + test("workflow listen supports completing wait sends", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); - const actor = client.workflowSleepActor.getOrCreate([ - "workflow-sleep", + const actor = client.workflowQueueActor.getOrCreate([ + "workflow-queue-wait", ]); + const result = await actor.sendAndWait({ value: 123 }); + expect(result).toEqual({ + status: "completed", + response: { echo: { value: 123 } }, + }); + }); + + test("db and client are step-only in workflow context", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + const actor = client.workflowAccessActor.getOrCreate([ + "workflow-access", + ]); + + let state = await actor.getState(); + for (let i = 0; i < 20 && state.insideDbCount === 0; i++) { + await waitFor(driverTestConfig, 50); + state = await actor.getState(); + } + + expect(state.outsideDbError).toBe( + "db is only available inside workflow steps", + ); + expect(state.outsideClientError).toBe( + "client is only available inside workflow steps", + ); + expect(state.insideDbCount).toBeGreaterThan(0); + expect(state.insideClientAvailable).toBe(true); + }); + + test("sleeps and resumes between ticks", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + const actor = client.workflowSleepActor.getOrCreate(["workflow-sleep"]); + const initial = await actor.getState(); await waitFor(driverTestConfig, 200); const next = await actor.getState(); diff --git a/rivetkit-typescript/packages/rivetkit/src/workflow/context.ts b/rivetkit-typescript/packages/rivetkit/src/workflow/context.ts index f3670889be..af39f3cc1b 100644 --- a/rivetkit-typescript/packages/rivetkit/src/workflow/context.ts +++ b/rivetkit-typescript/packages/rivetkit/src/workflow/context.ts @@ -1,5 +1,7 @@ import type { RunContext } from "@/actor/contexts/run"; -import type { AnyDatabaseProvider } from "@/actor/database"; +import type { Client } from "@/client/client"; +import type { Registry } from "@/registry"; +import type { AnyDatabaseProvider, InferDatabaseClient } from "@/actor/database"; import type { WorkflowContextInterface } from "@rivetkit/workflow-engine"; import type { BranchConfig, @@ -8,6 +10,7 @@ import type { LoopConfig, LoopResult, StepConfig, + WorkflowListenMessage, } from "@rivetkit/workflow-engine"; import { WORKFLOW_GUARD_KV_KEY } from "./constants"; @@ -42,27 +45,27 @@ export class ActorWorkflowContext< return this.#inner.abortSignal; } - async step( - nameOrConfig: string | Parameters[0], - run?: () => Promise, - ): Promise { + async step( + nameOrConfig: string | Parameters[0], + run?: () => Promise, + ): Promise { if (typeof nameOrConfig === "string") { if (!run) { throw new Error("Step run function missing"); } - return await this.#wrapActive(() => - this.#inner.step(nameOrConfig, () => - this.#withActorAccess(run), - ), - ); - } + return await this.#wrapActive(() => + this.#inner.step(nameOrConfig, () => + this.#withActorAccess(run), + ), + ); + } const stepConfig = nameOrConfig as StepConfig; const config: StepConfig = { ...stepConfig, run: () => this.#withActorAccess(stepConfig.run), }; return await this.#wrapActive(() => this.#inner.step(config)); - } + } async loop( name: string, @@ -103,7 +106,10 @@ export class ActorWorkflowContext< return this.#inner.sleepUntil(name, timestampMs); } - listen(name: string, messageName: string): Promise { + listen( + name: string, + messageName: string | string[], + ): Promise> { return this.#inner.listen(name, messageName); } @@ -212,6 +218,18 @@ export class ActorWorkflowContext< return this.#runCtx.vars as TVars extends never ? never : TVars; } + client>(): Client { + this.#ensureActorAccess("client"); + return this.#runCtx.client(); + } + + get db(): TDatabase extends never ? never : InferDatabaseClient { + this.#ensureActorAccess("db"); + return this.#runCtx.db as TDatabase extends never + ? never + : InferDatabaseClient; + } + get log() { return this.#runCtx.log; } diff --git a/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts b/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts index 4d5d57667e..7997b588cc 100644 --- a/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts @@ -29,6 +29,7 @@ function stripWorkflowKey(prefixed: Uint8Array): Uint8Array { class ActorWorkflowMessageDriver implements WorkflowMessageDriver { #actor: AnyActorInstance; #runCtx: RunContext; + #completionHandles = new Map Promise>(); constructor( actor: AnyActorInstance, @@ -42,16 +43,30 @@ class ActorWorkflowMessageDriver implements WorkflowMessageDriver { const queueMessages = await this.#runCtx.keepAwake( this.#actor.queueManager.getMessages(), ); + const now = Date.now(); const workflowMessages: Message[] = []; for (const queueMessage of queueMessages) { + if (queueMessage.inFlight || queueMessage.availableAt > now) { + continue; + } + const workflowName = stripWorkflowQueueName(queueMessage.name); if (!workflowName) continue; + const id = queueMessage.id.toString(); + this.#completionHandles.set(id, async (response?: unknown) => { + await this.#runCtx.keepAwake( + this.#actor.queueManager.completeById(queueMessage.id, response), + ); + }); workflowMessages.push({ - id: queueMessage.id.toString(), + id, name: workflowName, data: queueMessage.body, sentAt: queueMessage.createdAt, + complete: async (response?: unknown) => { + await this.completeMessage(id, response); + }, }); } @@ -88,10 +103,45 @@ class ActorWorkflowMessageDriver implements WorkflowMessageDriver { } const deleted = await this.#runCtx.keepAwake( - this.#actor.queueManager.deleteMessagesById(validIds), + this.#actor.queueManager.deleteMessagesById(validIds, { + resolveWaiters: false, + }), ); + + for (const id of deleted) { + const idString = id.toString(); + if (this.#completionHandles.has(idString)) { + continue; + } + this.#completionHandles.set(idString, async (response?: unknown) => { + await this.#runCtx.keepAwake( + this.#actor.queueManager.completeById(id, response), + ); + }); + } + return deleted.map((id) => id.toString()); } + + async completeMessage(messageId: string, response?: unknown): Promise { + const complete = this.#completionHandles.get(messageId); + if (complete) { + await complete(response); + this.#completionHandles.delete(messageId); + return; + } + + let parsedId: bigint; + try { + parsedId = BigInt(messageId); + } catch { + return; + } + + await this.#runCtx.keepAwake( + this.#actor.queueManager.completeById(parsedId, response), + ); + } } export class ActorWorkflowDriver implements EngineDriver { diff --git a/rivetkit-typescript/packages/workflow-engine/src/context.ts b/rivetkit-typescript/packages/workflow-engine/src/context.ts index 1fc0ae6cb4..3a7ecb9862 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/context.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/context.ts @@ -44,10 +44,12 @@ import type { Location, LoopConfig, LoopResult, + Message, RollbackContextInterface, StepConfig, Storage, WorkflowContextInterface, + WorkflowListenMessage, WorkflowMessageDriver, } from "./types.js"; import { sleep } from "./utils.js"; @@ -64,6 +66,8 @@ export const DEFAULT_LOOP_HISTORY_EVERY = 20; export const DEFAULT_LOOP_HISTORY_KEEP = 20; export const DEFAULT_STEP_TIMEOUT = 30000; // 30 seconds +const LISTEN_HISTORY_MESSAGE_MARKER = "__rivetWorkflowListenMessage"; + /** * Calculate backoff delay with exponential backoff. * Uses deterministic calculation (no jitter) for replay consistency. @@ -949,9 +953,24 @@ export class WorkflowContextImpl implements WorkflowContextInterface { // (SleepError/MessageWaitError), then on the next run, loadStorage() will // pick up the new message. This is intentional - no polling during execution. - async listen(name: string, messageName: string): Promise { - const messages = await this.listenN(name, messageName, 1); - return messages[0]; + async listen( + name: string, + messageName: string | string[], + ): Promise> { + this.assertNotInProgress(); + this.checkEvicted(); + + this.entryInProgress = true; + try { + const messages = await this.executeListenN(name, messageName, 1); + const message = messages[0]; + if (!message) { + throw new HistoryDivergedError("Expected message for listen()"); + } + return this.toListenMessage(message); + } finally { + this.entryInProgress = false; + } } async listenN( @@ -964,17 +983,24 @@ export class WorkflowContextImpl implements WorkflowContextInterface { this.entryInProgress = true; try { - return await this.executeListenN(name, messageName, limit); + const messages = await this.executeListenN(name, messageName, limit); + await Promise.all( + messages.map((message) => this.completeConsumedMessage(message)), + ); + return messages.map((message) => message.data as T); } finally { this.entryInProgress = false; } } - private async executeListenN( + private async executeListenN( name: string, - messageName: string, + messageName: string | string[], limit: number, - ): Promise { + ): Promise { + const messageNames = this.normalizeMessageNames(messageName); + const messageNameLabel = this.messageNamesLabel(messageNames); + // Check for duplicate name in current execution this.checkDuplicateName(name); @@ -995,7 +1021,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface { if (existingCount && existingCount.kind.type === "message") { // Replay: read all recorded messages const count = existingCount.kind.data.data as number; - const results: T[] = []; + const results: Message[] = []; for (let i = 0; i < count; i++) { const messageLocation = appendName( @@ -1014,7 +1040,12 @@ export class WorkflowContextImpl implements WorkflowContextInterface { existingMessage && existingMessage.kind.type === "message" ) { - results.push(existingMessage.kind.data.data as T); + results.push( + this.fromHistoryListenMessage( + existingMessage.kind.data.name, + existingMessage.kind.data.data, + ), + ); } } @@ -1025,7 +1056,7 @@ export class WorkflowContextImpl implements WorkflowContextInterface { const messages = await consumeMessages( this.storage, this.messageDriver, - messageName, + messageNames, limit, ); @@ -1039,7 +1070,10 @@ export class WorkflowContextImpl implements WorkflowContextInterface { ); const messageEntry = createEntry(messageLocation, { type: "message", - data: { name: messageName, data: messages[i].data }, + data: { + name: messages[i].name, + data: this.toHistoryListenMessage(messages[i]), + }, }); setEntry(this.storage, messageLocation, messageEntry); @@ -1050,17 +1084,117 @@ export class WorkflowContextImpl implements WorkflowContextInterface { // Record the count for replay const countEntry = createEntry(countLocation, { type: "message", - data: { name: `${messageName}:count`, data: messages.length }, + data: { + name: `${messageNameLabel}:count`, + data: messages.length, + }, }); setEntry(this.storage, countLocation, countEntry); await this.flushStorage(); - return messages.map((message) => message.data as T); + return messages; } // No messages found, throw to yield to scheduler - throw new MessageWaitError([messageName]); + throw new MessageWaitError(messageNames); + } + + private normalizeMessageNames(messageName: string | string[]): string[] { + const names = Array.isArray(messageName) ? messageName : [messageName]; + const deduped: string[] = []; + const seen = new Set(); + + for (const name of names) { + if (seen.has(name)) { + continue; + } + seen.add(name); + deduped.push(name); + } + + if (deduped.length === 0) { + throw new Error("listen() requires at least one message name"); + } + + return deduped; + } + + private messageNamesLabel(messageNames: string[]): string { + return messageNames.length === 1 + ? messageNames[0] + : messageNames.join("|"); + } + + private toListenMessage(message: Message): WorkflowListenMessage { + return { + id: message.id, + name: message.name, + body: message.data as T, + complete: async (response?: unknown) => { + if (message.complete) { + await message.complete(response); + return; + } + if (this.messageDriver.completeMessage) { + await this.messageDriver.completeMessage(message.id, response); + } + }, + }; + } + + private async completeConsumedMessage(message: Message): Promise { + if (message.complete) { + await message.complete(); + return; + } + if (message.id && this.messageDriver.completeMessage) { + await this.messageDriver.completeMessage(message.id); + } + } + + private toHistoryListenMessage(message: Message): unknown { + return { + [LISTEN_HISTORY_MESSAGE_MARKER]: 1, + id: message.id, + name: message.name, + body: message.data, + }; + } + + private fromHistoryListenMessage(name: string, value: unknown): Message { + if ( + typeof value === "object" && + value !== null && + (value as Record)[LISTEN_HISTORY_MESSAGE_MARKER] === 1 + ) { + const serialized = value as Record; + const id = + typeof serialized.id === "string" ? serialized.id : ""; + const serializedName = + typeof serialized.name === "string" ? serialized.name : name; + const complete = async (response?: unknown) => { + if (!id || !this.messageDriver.completeMessage) { + return; + } + await this.messageDriver.completeMessage(id, response); + }; + + return { + id, + name: serializedName, + data: serialized.body, + sentAt: 0, + complete, + }; + } + + return { + id: "", + name, + data: value, + sentAt: 0, + }; } async listenWithTimeout( @@ -1124,19 +1258,22 @@ export class WorkflowContextImpl implements WorkflowContextInterface { // Check for replay if (existingSleep && existingSleep.kind.type === "sleep") { const sleepData = existingSleep.kind.data; - if (sleepData.state === "completed") { return null; } if (sleepData.state === "interrupted") { - const existingMessage = - this.storage.history.entries.get(messageKey); + const existingMessage = this.storage.history.entries.get(messageKey); if ( existingMessage && existingMessage.kind.type === "message" ) { - return existingMessage.kind.data.data as T; + const replayedMessage = this.fromHistoryListenMessage( + existingMessage.kind.data.name, + existingMessage.kind.data.data, + ); + await this.completeConsumedMessage(replayedMessage); + return replayedMessage.data as T; } throw new HistoryDivergedError( "Expected message entry after interrupted sleep", @@ -1179,10 +1316,14 @@ export class WorkflowContextImpl implements WorkflowContextInterface { const messageEntry = createEntry(messageLocation, { type: "message", - data: { name: messageName, data: message.data }, + data: { + name: message.name, + data: this.toHistoryListenMessage(message), + }, }); setEntry(this.storage, messageLocation, messageEntry); await this.flushStorage(); + await this.completeConsumedMessage(message); return message.data as T; } @@ -1210,10 +1351,14 @@ export class WorkflowContextImpl implements WorkflowContextInterface { const messageEntry = createEntry(messageLocation, { type: "message", - data: { name: messageName, data: message.data }, + data: { + name: message.name, + data: this.toHistoryListenMessage(message), + }, }); setEntry(this.storage, messageLocation, messageEntry); await this.flushStorage(); + await this.completeConsumedMessage(message); return message.data as T; } @@ -1355,13 +1500,17 @@ export class WorkflowContextImpl implements WorkflowContextInterface { this.markVisited(messageKey); - const existingMessage = - this.storage.history.entries.get(messageKey); + const existingMessage = this.storage.history.entries.get(messageKey); if ( existingMessage && existingMessage.kind.type === "message" ) { - results.push(existingMessage.kind.data.data as T); + const replayedMessage = this.fromHistoryListenMessage( + existingMessage.kind.data.name, + existingMessage.kind.data.data, + ); + await this.completeConsumedMessage(replayedMessage); + results.push(replayedMessage.data as T); } } @@ -1401,10 +1550,14 @@ export class WorkflowContextImpl implements WorkflowContextInterface { ); const messageEntry = createEntry(messageLocation, { type: "message", - data: { name: messageName, data: message.data }, + data: { + name: message.name, + data: this.toHistoryListenMessage(message), + }, }); setEntry(this.storage, messageLocation, messageEntry); this.markVisited(locationToKey(this.storage, messageLocation)); + await this.completeConsumedMessage(message); results.push(message.data as T); } diff --git a/rivetkit-typescript/packages/workflow-engine/src/index.ts b/rivetkit-typescript/packages/workflow-engine/src/index.ts index 5e158f0016..e9f231a3c8 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/index.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/index.ts @@ -100,6 +100,7 @@ export type { WorkflowContextInterface, WorkflowFunction, WorkflowHandle, + WorkflowListenMessage, WorkflowMessageDriver, WorkflowResult, WorkflowRunMode, diff --git a/rivetkit-typescript/packages/workflow-engine/src/storage.ts b/rivetkit-typescript/packages/workflow-engine/src/storage.ts index c981e3e0b8..8c6fff7965 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/storage.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/storage.ts @@ -387,7 +387,7 @@ export async function addMessage( export async function consumeMessage( storage: Storage, messageDriver: WorkflowMessageDriver, - messageName: string, + messageName: string | string[], ): Promise { const messages = await consumeMessages( storage, @@ -409,15 +409,19 @@ export async function consumeMessage( export async function consumeMessages( storage: Storage, messageDriver: WorkflowMessageDriver, - messageName: string, + messageName: string | string[], limit: number, ): Promise { + const messageNameSet = new Set( + Array.isArray(messageName) ? messageName : [messageName], + ); + // Find all matching messages up to limit (don't modify memory yet) const toConsume: { message: Message; index: number }[] = []; let count = 0; for (let i = 0; i < storage.messages.length && count < limit; i++) { - if (storage.messages[i].name === messageName) { + if (messageNameSet.has(storage.messages[i].name)) { toConsume.push({ message: storage.messages[i], index: i }); count++; } diff --git a/rivetkit-typescript/packages/workflow-engine/src/types.ts b/rivetkit-typescript/packages/workflow-engine/src/types.ts index c21fd04821..9b2d918b35 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/types.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/types.ts @@ -195,6 +195,22 @@ export interface Message { name: string; data: unknown; sentAt: number; + /** + * Optional completion callback for queue-backed drivers. + * + * This is runtime-only and is not persisted. + */ + complete?: (response?: unknown) => Promise; +} + +/** + * Message handle returned by listen(). + */ +export interface WorkflowListenMessage { + id: string; + name: string; + body: T; + complete(response?: unknown): Promise; } /** @@ -264,6 +280,10 @@ export interface WorkflowMessageDriver { * Delete the specified messages and return the IDs that were successfully removed. */ deleteMessages(messageIds: string[]): Promise; + /** + * Optionally complete a previously consumed message with a response payload. + */ + completeMessage?(messageId: string, response?: unknown): Promise; } /** @@ -348,7 +368,10 @@ export interface WorkflowContextInterface { sleep(name: string, durationMs: number): Promise; sleepUntil(name: string, timestampMs: number): Promise; - listen(name: string, messageName: string): Promise; + listen( + name: string, + messageName: string | string[], + ): Promise>; listenN(name: string, messageName: string, limit: number): Promise; listenWithTimeout( name: string, diff --git a/rivetkit-typescript/packages/workflow-engine/tests/handle.test.ts b/rivetkit-typescript/packages/workflow-engine/tests/handle.test.ts index dad3433dc5..149da0baa0 100644 --- a/rivetkit-typescript/packages/workflow-engine/tests/handle.test.ts +++ b/rivetkit-typescript/packages/workflow-engine/tests/handle.test.ts @@ -18,7 +18,8 @@ for (const mode of modes) { it("should send messages via handle", async () => { const workflow = async (ctx: WorkflowContextInterface) => { - return await ctx.listen("wait", "message-name"); + const message = await ctx.listen("wait", "message-name"); + return message.body; }; const handle = runWorkflow("wf-1", workflow, undefined, driver, { diff --git a/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts b/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts index 1a5538a6cb..27da0e0467 100644 --- a/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts +++ b/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts @@ -6,6 +6,7 @@ import { runWorkflow, serializeMessage, type WorkflowContextInterface, + type WorkflowMessageDriver, } from "../src/testing.js"; function buildMessagePayload(name: string, data: string, id = generateId()) { @@ -34,7 +35,7 @@ for (const mode of modes) { "wait-message", "my-message", ); - return message; + return message.body; }; const handle = runWorkflow("wf-1", workflow, undefined, driver, { @@ -54,7 +55,50 @@ for (const mode of modes) { expect(result.output).toBe("payload"); }); - it("should consume pending messages", async () => { + it("should listen for any message in a name set", async () => { + const workflow = async (ctx: WorkflowContextInterface) => { + const message = await ctx.listen("wait-many", [ + "first", + "second", + ]); + return { name: message.name, body: message.body }; + }; + + const handle = runWorkflow("wf-1", workflow, undefined, driver, { + mode, + }); + + if (mode === "yield") { + const result1 = await handle.result; + expect(result1.state).toBe("sleeping"); + expect(result1.waitingForMessages).toEqual(["first", "second"]); + + await handle.message("second", "payload"); + const result2 = await runWorkflow( + "wf-1", + workflow, + undefined, + driver, + { mode }, + ).result; + expect(result2.state).toBe("completed"); + expect(result2.output).toEqual({ + name: "second", + body: "payload", + }); + return; + } + + await handle.message("second", "payload"); + const result = await handle.result; + expect(result.state).toBe("completed"); + expect(result.output).toEqual({ + name: "second", + body: "payload", + }); + }); + + it("should consume pending messages", async () => { const messageId = generateId(); await driver.set( buildMessageKey(messageId), @@ -68,7 +112,63 @@ for (const mode of modes) { "wait-message", "my-message", ); - return message; + return message.body; + }; + + const result = await runWorkflow( + "wf-1", + workflow, + undefined, + driver, + { + mode, + }, + ).result; + + expect(result.state).toBe("completed"); + expect(result.output).toBe("hello"); + }); + + it("listen should return a completable message handle", async () => { + const completions: Array<{ id: string; response?: unknown }> = []; + const pending = [ + buildMessagePayload("my-message", "hello", "msg-1") as { + id: string; + name: string; + data: unknown; + sentAt: number; + complete?: (response?: unknown) => Promise; + }, + ]; + + const messageDriver: WorkflowMessageDriver = { + async loadMessages() { + return pending.map((message) => ({ + ...message, + complete: async (response?: unknown) => { + completions.push({ id: message.id, response }); + }, + })); + }, + async addMessage(message) { + pending.push(message); + }, + async deleteMessages(messageIds) { + const deleted = new Set(messageIds); + const remaining = pending.filter( + (message) => !deleted.has(message.id), + ); + pending.length = 0; + pending.push(...remaining); + return messageIds; + }, + }; + driver.messageDriver = messageDriver; + + const workflow = async (ctx: WorkflowContextInterface) => { + const message = await ctx.listen("wait-message", "my-message"); + await message.complete({ ok: true }); + return message.body; }; const result = await runWorkflow( @@ -83,6 +183,7 @@ for (const mode of modes) { expect(result.state).toBe("completed"); expect(result.output).toBe("hello"); + expect(completions).toEqual([{ id: "msg-1", response: { ok: true } }]); }); it("should collect multiple messages with listenN", async () => { @@ -280,7 +381,8 @@ for (const mode of modes) { return "ready"; }); - return await ctx.listen("wait", "mid"); + const message = await ctx.listen("wait", "mid"); + return message.body; }; const handle = runWorkflow("wf-1", workflow, undefined, driver, {