diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3192d6c6ff..19d086d746 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -4281,6 +4281,9 @@ importers: '@rivetkit/workflow-engine': specifier: workspace:* version: link:../workflow-engine + '@standard-schema/spec': + specifier: ^1.0.0 + version: 1.0.0 cbor-x: specifier: ^1.6.0 version: 1.6.0 diff --git a/rivetkit-typescript/packages/rivetkit/package.json b/rivetkit-typescript/packages/rivetkit/package.json index b0a7195044..dccaf71807 100644 --- a/rivetkit-typescript/packages/rivetkit/package.json +++ b/rivetkit-typescript/packages/rivetkit/package.json @@ -217,6 +217,7 @@ "@rivetkit/sqlite-vfs": "workspace:*", "@rivetkit/traces": "workspace:*", "@rivetkit/virtual-websocket": "workspace:*", + "@standard-schema/spec": "^1.0.0", "cbor-x": "^1.6.0", "get-port": "^7.1.0", "hono": "^4.7.0", diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts index 6096342af5..1aa55f5984 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts @@ -20,6 +20,7 @@ import type { WebSocketContext, } from "./contexts"; import type { AnyDatabaseProvider } from "./database"; +import type { SchemaConfig } from "./schema"; export interface ActorTypes< TState, @@ -57,9 +58,10 @@ export interface RunInspectorConfig { const WorkflowInspectorConfigSchema = z.object({ getHistory: zFunction["getHistory"]>(), - onHistoryUpdated: zFunction< - NonNullable["onHistoryUpdated"]> - >().optional(), + onHistoryUpdated: + zFunction< + NonNullable["onHistoryUpdated"]> + >().optional(), }); const RunInspectorConfigSchema = z @@ -129,6 +131,8 @@ export const ActorConfigSchema = z onRequest: zFunction().optional(), onWebSocket: zFunction().optional(), actions: z.record(z.string(), zFunction()).default(() => ({})), + events: z.record(z.string(), z.any()).optional(), + queues: z.record(z.string(), z.any()).optional(), state: z.any().optional(), createState: zFunction().optional(), connState: z.any().optional(), @@ -219,11 +223,13 @@ type CreateState< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig, + TQueues extends SchemaConfig, > = | { state: TState } | { createState: ( - c: CreateContext, + c: CreateContext, input: TInput, ) => TState | Promise; } @@ -241,11 +247,20 @@ type CreateConnState< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig, + TQueues extends SchemaConfig, > = | { connState: TConnState } | { createConnState: ( - c: CreateConnStateContext, + c: CreateConnStateContext< + TState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, params: TConnParams, ) => TConnState | Promise; } @@ -264,6 +279,8 @@ type CreateVars< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig, + TQueues extends SchemaConfig, > = | { /** @@ -276,7 +293,13 @@ type CreateVars< * @experimental */ createVars: ( - c: CreateVarsContext, + c: CreateVarsContext< + TState, + TInput, + TDatabase, + TEvents, + TQueues + >, driverCtx: any, ) => TVars | Promise; } @@ -289,6 +312,8 @@ export interface Actions< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > { [Action: string]: ( c: ActionContext< @@ -297,7 +322,9 @@ export interface Actions< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, ...args: any[] ) => any; @@ -320,13 +347,17 @@ interface BaseActorConfig< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig, + TQueues extends SchemaConfig, TActions extends Actions< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, > { /** @@ -336,7 +367,7 @@ interface BaseActorConfig< * This is called before any other lifecycle hooks. */ onCreate?: ( - c: CreateContext, + c: CreateContext, input: TInput, ) => void | Promise; @@ -350,7 +381,9 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, ) => void | Promise; @@ -369,7 +402,9 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, ) => void | Promise; @@ -390,7 +425,9 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, ) => void | Promise; @@ -425,7 +462,9 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, ) => void | Promise) | RunConfig; @@ -448,7 +487,9 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, newState: TState, ) => void; @@ -464,7 +505,14 @@ interface BaseActorConfig< * @throws Throw an error to reject the connection */ onBeforeConnect?: ( - c: BeforeConnectContext, + c: BeforeConnectContext< + TState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, params: TConnParams, ) => void | Promise; @@ -484,9 +532,20 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues + >, + conn: Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues >, - conn: Conn, ) => void | Promise; /** @@ -505,9 +564,20 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues + >, + conn: Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues >, - conn: Conn, ) => void | Promise; /** @@ -529,7 +599,9 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, name: string, args: unknown[], @@ -554,7 +626,9 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, request: Request, ) => Response | Promise; @@ -576,12 +650,24 @@ interface BaseActorConfig< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, websocket: UniversalWebSocket, ) => void | Promise; actions?: TActions; + + /** + * Schema map for events broadcasted by this actor. + */ + events?: TEvents; + + /** + * Schema map for queue payloads sent by this actor. + */ + queues?: TQueues; } type ActorDatabaseConfig = @@ -603,9 +689,13 @@ export type ActorConfig< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > = Omit< z.infer, | "actions" + | "events" + | "queues" | "onCreate" | "onDestroy" | "onWake" @@ -633,11 +723,49 @@ export type ActorConfig< TVars, TInput, TDatabase, - Actions + TEvents, + TQueues, + Actions< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > + > & + CreateState< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > & + CreateConnState< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > & + CreateVars< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues > & - CreateState & - CreateConnState & - CreateVars & ActorDatabaseConfig; // See description on `ActorConfig` @@ -648,13 +776,17 @@ export type ActorConfigInput< TVars = undefined, TInput = undefined, TDatabase extends AnyDatabaseProvider = undefined, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, TActions extends Actions< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > = Record, > = { types?: ActorTypes< @@ -668,6 +800,8 @@ export type ActorConfigInput< } & Omit< z.input, | "actions" + | "events" + | "queues" | "onCreate" | "onDestroy" | "onWake" @@ -695,11 +829,40 @@ export type ActorConfigInput< TVars, TInput, TDatabase, + TEvents, + TQueues, TActions > & - CreateState & - CreateConnState & - CreateVars & + CreateState< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > & + CreateConnState< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > & + CreateVars< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > & ActorDatabaseConfig; // For testing type definitions: @@ -710,13 +873,17 @@ export function test< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig, + TQueues extends SchemaConfig, TActions extends Actions< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, >( input: ActorConfigInput< @@ -726,16 +893,29 @@ export function test< TVars, TInput, TDatabase, + TEvents, + TQueues, TActions >, -): ActorConfig { +): ActorConfig< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues +> { const config = ActorConfigSchema.parse(input) as ActorConfig< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >; return config; } @@ -959,6 +1139,14 @@ export const DocActorConfigSchema = z .describe( "Map of action name to handler function. Defaults to an empty object.", ), + events: z + .record(z.string(), z.unknown()) + .optional() + .describe("Map of event names to schemas."), + queues: z + .record(z.string(), z.unknown()) + .optional() + .describe("Map of queue names to schemas."), options: DocActorOptionsSchema.optional(), }) .describe("Actor configuration passed to the actor() function."); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index 7c7c7dcb9b..fdd70ea526 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -10,15 +10,21 @@ import { } from "@/schemas/client-protocol-zod/mod"; import { bufferToArrayBuffer } from "@/utils"; import type { AnyDatabaseProvider } from "../database"; -import { InternalError } from "../errors"; +import { EventPayloadInvalid, InternalError } from "../errors"; import type { ActorInstance } from "../instance/mod"; import { CachedSerializer } from "../protocol/serde"; +import { + type InferEventArgs, + type InferSchemaMap, + type SchemaConfig, + validateSchemaSync, +} from "../schema"; import type { ConnDriver } from "./driver"; import { type ConnDataInput, StateManager } from "./state-manager"; export type ConnId = string; -export type AnyConn = Conn; +export type AnyConn = Conn; export const CONN_CONNECTED_SYMBOL = Symbol("connected"); export const CONN_SPEAKS_RIVETKIT_SYMBOL = Symbol("speaksRivetKit"); @@ -34,10 +40,19 @@ export const CONN_SEND_MESSAGE_SYMBOL = Symbol("sendMessage"); * * @see {@link https://rivet.dev/docs/connections|Connection Documentation} */ -export class Conn { - #actor: ActorInstance; +export class Conn< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, +> { + #actor: ActorInstance; - get [CONN_ACTOR_SYMBOL](): ActorInstance { + get [CONN_ACTOR_SYMBOL](): ActorInstance { return this.#actor; } @@ -122,7 +137,7 @@ export class Conn { * @protected */ constructor( - actor: ActorInstance, + actor: ActorInstance, data: ConnDataInput, ) { this.#actor = actor; @@ -159,6 +174,14 @@ export class Conn { * @param args - The arguments for the event. * @see {@link https://rivet.dev/docs/events|Events Documentation} */ + send( + eventName: K, + ...args: InferEventArgs[K]> + ): void; + send( + eventName: keyof E extends never ? string : never, + ...args: unknown[] + ): void; send(eventName: string, ...args: unknown[]) { this.#assertConnected(); if (!this[CONN_SPEAKS_RIVETKIT_SYMBOL]) { @@ -168,11 +191,27 @@ export class Conn { connType: this[CONN_DRIVER_SYMBOL]?.type, }); } + + const payload = args.length === 1 ? args[0] : args; + const result = validateSchemaSync( + this.#actor.config.events, + eventName as keyof E & string, + payload, + ); + if (!result.success) { + throw new EventPayloadInvalid(eventName, result.issues); + } + const eventArgs = + args.length === 1 + ? [result.data] + : Array.isArray(result.data) + ? (result.data as unknown[]) + : args; this.#actor.emitTraceEvent("message.send", { "rivet.event.name": eventName, "rivet.conn.id": this.id, }); - const eventData = { name: eventName, args }; + const eventData = { name: eventName, args: eventArgs }; this[CONN_SEND_MESSAGE_SYMBOL]( new CachedSerializer( eventData, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts index 63156343c3..753cabd59f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts @@ -39,7 +39,7 @@ export type ConnData = * Handles automatic state change detection for connection-specific state. */ export class StateManager { - #conn: Conn; + #conn: Conn; /** * Data representing this connection. @@ -50,7 +50,7 @@ export class StateManager { #data!: ConnData; constructor( - conn: Conn, + conn: Conn, data: ConnDataInput, ) { this.#conn = conn; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts index 5c1cccc022..b37508db74 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts @@ -2,6 +2,7 @@ import type { Conn } from "../conn/mod"; import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; import type { ActorInstance } from "../instance/mod"; +import type { SchemaConfig } from "../schema"; import { ConnContext } from "./base/conn"; /** @@ -14,26 +15,33 @@ export class ActionContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ConnContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} /** * Extracts the ActionContext type from an ActorDefinition. */ -export type ActionContextOf = AD extends ActorDefinition< - infer S, - infer CP, - infer CS, - infer V, - infer I, - infer DB extends AnyDatabaseProvider, - any -> - ? ActionContext - : never; +export type ActionContextOf = + AD extends ActorDefinition< + infer S, + infer CP, + infer CS, + infer V, + infer I, + infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, + any + > + ? ActionContext + : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts index 4aba6dcc5a..c89aad6744 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts @@ -5,14 +5,21 @@ import type { Registry } from "@/registry"; import type { Conn, ConnId } from "../../conn/mod"; import type { AnyDatabaseProvider, InferDatabaseClient } from "../../database"; import type { ActorDefinition, AnyActorDefinition } from "../../definition"; +import * as errors from "../../errors"; +import { ActorKv } from "../../instance/kv"; import type { ActorInstance, AnyActorInstance, SaveStateOptions, } from "../../instance/mod"; -import { ActorKv } from "../../instance/kv"; import { ActorQueue } from "../../instance/queue"; import type { Schedule } from "../../schedule"; +import { + type InferEventArgs, + type InferSchemaMap, + type SchemaConfig, + validateSchemaSync, +} from "../../schema"; export const ACTOR_CONTEXT_INTERNAL_SYMBOL = Symbol.for( "rivetkit.actorContextInternal", @@ -28,6 +35,8 @@ export class ActorContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > { [ACTOR_CONTEXT_INTERNAL_SYMBOL]!: AnyActorInstance; #actor: ActorInstance< @@ -36,11 +45,22 @@ export class ActorContext< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >; #kv: ActorKv | undefined; #queue: - | ActorQueue + | ActorQueue< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > | undefined; constructor( @@ -50,7 +70,9 @@ export class ActorContext< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, ) { this.#actor = actor; @@ -93,9 +115,36 @@ export class ActorContext< * @param name - The name of the event. * @param args - The arguments to send with the event. */ - broadcast>(name: string, ...args: Args): void { + broadcast( + name: K, + ...args: InferEventArgs[K]> + ): void; + broadcast( + name: keyof TEvents extends never ? string : never, + ...args: Array + ): void; + broadcast(name: string, ...args: Array): void { + const payload = args.length === 1 ? args[0] : args; + const result = validateSchemaSync( + this.#actor.config.events, + name as keyof TEvents & string, + payload, + ); + if (!result.success) { + throw new errors.EventPayloadInvalid(name, result.issues); + } + if (args.length === 1) { + this.#actor.eventManager.broadcast(name, result.data); + return; + } + if (Array.isArray(result.data)) { + this.#actor.eventManager.broadcast( + name, + ...(result.data as unknown[]), + ); + return; + } this.#actor.eventManager.broadcast(name, ...args); - return; } /** @@ -114,7 +163,9 @@ export class ActorContext< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > { if (!this.#queue) { this.#queue = new ActorQueue( @@ -165,7 +216,16 @@ export class ActorContext< */ get conns(): Map< ConnId, - Conn + Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > > { return this.#actor.conns; } @@ -258,7 +318,9 @@ export type ActorContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? ActorContext + ? ActorContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts index a5f5b002df..f22d10e4e5 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts @@ -1,9 +1,7 @@ import type { AnyDatabaseProvider } from "../../database"; -import type { - ActorDefinition, - AnyActorDefinition, -} from "../../definition"; +import type { ActorDefinition, AnyActorDefinition } from "../../definition"; import type { ActorInstance } from "../../instance/mod"; +import type { SchemaConfig } from "../../schema"; import { ActorContext } from "./actor"; /** @@ -15,7 +13,18 @@ export abstract class ConnInitContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, -> extends ActorContext { + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, +> extends ActorContext< + TState, + never, + never, + TVars, + TInput, + TDatabase, + TEvents, + TQueues +> { /** * The incoming request that initiated the connection. * May be undefined for connections initiated without a direct HTTP request. @@ -26,7 +35,16 @@ export abstract class ConnInitContext< * @internal */ constructor( - actor: ActorInstance, + actor: ActorInstance< + TState, + any, + any, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, request: Request | undefined, ) { super(actor as any); @@ -42,7 +60,9 @@ export type ConnInitContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? ConnInitContext + ? ConnInitContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts index 06be2482c3..24f9858612 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts @@ -1,10 +1,8 @@ import type { Conn } from "../../conn/mod"; import type { AnyDatabaseProvider } from "../../database"; -import type { - ActorDefinition, - AnyActorDefinition, -} from "../../definition"; +import type { ActorDefinition, AnyActorDefinition } from "../../definition"; import type { ActorInstance } from "../../instance/mod"; +import type { SchemaConfig } from "../../schema"; import { ActorContext } from "./actor"; /** @@ -18,13 +16,17 @@ export abstract class ConnContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > { /** * @internal @@ -36,7 +38,9 @@ export abstract class ConnContext< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, public readonly conn: Conn< TState, @@ -44,21 +48,26 @@ export abstract class ConnContext< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >, ) { super(actor); } } -export type ConnContextOf = AD extends ActorDefinition< - infer S, - infer CP, - infer CS, - infer V, - infer I, - infer DB extends AnyDatabaseProvider, - any -> - ? ConnContext - : never; +export type ConnContextOf = + AD extends ActorDefinition< + infer S, + infer CP, + infer CS, + infer V, + infer I, + infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, + any + > + ? ConnContext + : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-action-response.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-action-response.ts index 309ba14bdb..3cb8551a22 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-action-response.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-action-response.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -12,13 +13,17 @@ export class BeforeActionResponseContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} export type BeforeActionResponseContextOf = @@ -29,7 +34,9 @@ export type BeforeActionResponseContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? BeforeActionResponseContext + ? BeforeActionResponseContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-connect.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-connect.ts index 4a54b761f2..43213c91d7 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-connect.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/before-connect.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ConnInitContext } from "./base/conn-init"; /** @@ -10,7 +11,9 @@ export class BeforeConnectContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, -> extends ConnInitContext {} + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, +> extends ConnInitContext {} export type BeforeConnectContextOf = AD extends ActorDefinition< @@ -20,7 +23,9 @@ export type BeforeConnectContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? BeforeConnectContext + ? BeforeConnectContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/connect.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/connect.ts index 2c1472ac92..47f5eb570e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/connect.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/connect.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ConnContext } from "./base/conn"; /** @@ -12,13 +13,17 @@ export class ConnectContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ConnContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} export type ConnectContextOf = @@ -29,7 +34,9 @@ export type ConnectContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? ConnectContext + ? ConnectContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-conn-state.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-conn-state.ts index 9953be1470..504c84e57c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-conn-state.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-conn-state.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ConnInitContext } from "./base/conn-init"; /** @@ -11,7 +12,9 @@ export class CreateConnStateContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, -> extends ConnInitContext {} + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, +> extends ConnInitContext {} export type CreateConnStateContextOf = AD extends ActorDefinition< @@ -21,7 +24,9 @@ export type CreateConnStateContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? CreateConnStateContext + ? CreateConnStateContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-vars.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-vars.ts index ef611cefb6..482f3b71ff 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-vars.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-vars.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -9,8 +10,18 @@ export class CreateVarsContext< TState, TInput, TDatabase extends AnyDatabaseProvider, -> extends ActorContext {} - + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, +> extends ActorContext< + TState, + never, + never, + never, + TInput, + TDatabase, + TEvents, + TQueues +> {} export type CreateVarsContextOf = AD extends ActorDefinition< @@ -20,7 +31,9 @@ export type CreateVarsContextOf = any, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? CreateVarsContext + ? CreateVarsContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create.ts index 17f860781d..a43090dba6 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -9,8 +10,18 @@ export class CreateContext< TState, TInput, TDatabase extends AnyDatabaseProvider, -> extends ActorContext {} - + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, +> extends ActorContext< + TState, + never, + never, + never, + TInput, + TDatabase, + TEvents, + TQueues +> {} export type CreateContextOf = AD extends ActorDefinition< @@ -20,7 +31,9 @@ export type CreateContextOf = any, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? CreateContext + ? CreateContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/destroy.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/destroy.ts index c6304e1a84..bc4a82c2d0 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/destroy.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/destroy.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -12,16 +13,19 @@ export class DestroyContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} - export type DestroyContextOf = AD extends ActorDefinition< infer S, @@ -30,7 +34,9 @@ export type DestroyContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? DestroyContext + ? DestroyContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/disconnect.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/disconnect.ts index 42743546f6..a1237e484f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/disconnect.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/disconnect.ts @@ -1,6 +1,7 @@ import type { Conn } from "../conn/mod"; -import type { ActorDefinition, AnyActorDefinition } from "../definition"; import type { AnyDatabaseProvider } from "../database"; +import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -13,16 +14,19 @@ export class DisconnectContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} - export type DisconnectContextOf = AD extends ActorDefinition< infer S, @@ -31,7 +35,9 @@ export type DisconnectContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? DisconnectContext + ? DisconnectContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts index 632da2407c..7604d6c079 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts @@ -1,7 +1,8 @@ import type { Conn } from "../conn/mod"; -import type { ActorDefinition, AnyActorDefinition } from "../definition"; import type { AnyDatabaseProvider } from "../database"; +import type { ActorDefinition, AnyActorDefinition } from "../definition"; import type { ActorInstance } from "../instance/mod"; +import type { SchemaConfig } from "../schema"; import { ConnContext } from "./base/conn"; /** @@ -14,13 +15,17 @@ export class RequestContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ConnContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > { /** * The incoming HTTP request. @@ -38,9 +43,20 @@ export class RequestContext< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues + >, + conn: Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues >, - conn: Conn, request?: Request, ) { super(actor, conn); @@ -48,7 +64,6 @@ export class RequestContext< } } - export type RequestContextOf = AD extends ActorDefinition< infer S, @@ -57,7 +72,9 @@ export type RequestContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? RequestContext + ? RequestContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/run.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/run.ts index 63aa2a3ad1..3af4dc07eb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/run.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/run.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -17,13 +18,17 @@ export class RunContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} export type RunContextOf = @@ -34,7 +39,9 @@ export type RunContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? RunContext + ? RunContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/sleep.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/sleep.ts index 1614d1927c..659ec4f672 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/sleep.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/sleep.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -12,16 +13,19 @@ export class SleepContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} - export type SleepContextOf = AD extends ActorDefinition< infer S, @@ -30,7 +34,9 @@ export type SleepContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? SleepContext + ? SleepContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/state-change.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/state-change.ts index d323c25839..d3635e9b41 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/state-change.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/state-change.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -12,16 +13,19 @@ export class StateChangeContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} - export type StateChangeContextOf = AD extends ActorDefinition< infer S, @@ -30,7 +34,9 @@ export type StateChangeContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? StateChangeContext + ? StateChangeContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/wake.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/wake.ts index b2c4402a2e..26886ebadb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/wake.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/wake.ts @@ -1,5 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; +import type { SchemaConfig } from "../schema"; import { ActorContext } from "./base/actor"; /** @@ -12,16 +13,19 @@ export class WakeContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ActorContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > {} - export type WakeContextOf = AD extends ActorDefinition< infer S, @@ -30,7 +34,9 @@ export type WakeContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? WakeContext + ? WakeContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts index b274953cd9..215f3f8651 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts @@ -1,7 +1,8 @@ import type { Conn } from "../conn/mod"; -import type { ActorDefinition, AnyActorDefinition } from "../definition"; import type { AnyDatabaseProvider } from "../database"; +import type { ActorDefinition, AnyActorDefinition } from "../definition"; import type { ActorInstance } from "../instance/mod"; +import type { SchemaConfig } from "../schema"; import { ConnContext } from "./base/conn"; /** @@ -14,13 +15,17 @@ export class WebSocketContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, > extends ConnContext< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues > { /** * The incoming HTTP request that initiated the WebSocket upgrade. @@ -38,9 +43,20 @@ export class WebSocketContext< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues + >, + conn: Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues >, - conn: Conn, request?: Request, ) { super(actor, conn); @@ -48,7 +64,6 @@ export class WebSocketContext< } } - export type WebSocketContextOf = AD extends ActorDefinition< infer S, @@ -57,7 +72,9 @@ export type WebSocketContextOf = infer V, infer I, infer DB extends AnyDatabaseProvider, + infer E extends SchemaConfig, + infer Q extends SchemaConfig, any > - ? WebSocketContext + ? WebSocketContext : never; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts index 53a2fc3111..f4f1a52b8f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts @@ -1,7 +1,10 @@ import type { RegistryConfig } from "@/registry/config"; +import { DeepMutable } from "@/utils"; import type { Actions, ActorConfig } from "./config"; +import type { ActionContextOf, ActorContext } from "./contexts"; import type { AnyDatabaseProvider } from "./database"; import { ActorInstance } from "./instance/mod"; +import type { SchemaConfig } from "./schema"; export type AnyActorDefinition = ActorDefinition< any, @@ -10,6 +13,8 @@ export type AnyActorDefinition = ActorDefinition< any, any, any, + any, + any, any >; @@ -20,19 +25,30 @@ export class ActorDefinition< V, I, DB extends AnyDatabaseProvider, - R extends Actions, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, + R extends Actions = Actions< + S, + CP, + CS, + V, + I, + DB, + E, + Q + >, > { - #config: ActorConfig; + #config: ActorConfig; - constructor(config: ActorConfig) { + constructor(config: ActorConfig) { this.#config = config; } - get config(): ActorConfig { + get config(): ActorConfig { return this.#config; } - instantiate(): ActorInstance { + instantiate(): ActorInstance { return new ActorInstance(this.#config); } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts b/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts index 6291566e24..086be1a62c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts @@ -221,35 +221,24 @@ export class QueueMessageInvalid extends ActorError { } } -export class QueueCompleteNotAllowed extends ActorError { - constructor() { +export class EventPayloadInvalid extends ActorError { + constructor(name: string, issues?: unknown[]) { super( - "queue", - "complete_not_allowed", - "Queue message completion is only allowed when wait is enabled.", - { public: true }, + "event", + "invalid_payload", + `Event payload failed validation for '${name}'.`, + { public: true, metadata: { name, issues } }, ); } } -export class QueueMessagePending extends ActorError { - constructor() { +export class QueuePayloadInvalid extends ActorError { + constructor(name: string, issues?: unknown[]) { super( "queue", - "message_pending", - "Queue message is already pending completion.", - { public: true }, - ); - } -} - -export class QueueAlreadyCompleted extends ActorError { - constructor() { - super( - "queue", - "already_completed", - "Queue message has already been completed.", - { public: true }, + "invalid_payload", + `Queue payload failed validation for '${name}'.`, + { public: true, metadata: { name, issues } }, ); } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts index fb0aa16af7..c07ad604c2 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts @@ -30,6 +30,7 @@ import { } from "../contexts"; import type { AnyDatabaseProvider } from "../database"; import { CachedSerializer } from "../protocol/serde"; +import type { SchemaConfig } from "../schema"; import { deadline } from "../utils"; import { makeConnKey } from "./keys"; import type { ActorInstance } from "./mod"; @@ -44,22 +45,24 @@ export class ConnectionManager< V, I, DB extends AnyDatabaseProvider, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, > { - #actor: ActorInstance; - #connections = new Map>(); + #actor: ActorInstance; + #connections = new Map>(); /** Connections that have had their state changed and need to be persisted. */ #connsWithPersistChanged = new Set(); - constructor(actor: ActorInstance) { + constructor(actor: ActorInstance) { this.#actor = actor; } - get connections(): Map> { + get connections(): Map> { return this.#connections; } - getConnForId(id: string): Conn | undefined { + getConnForId(id: string): Conn | undefined { return this.#connections.get(id); } @@ -71,7 +74,7 @@ export class ConnectionManager< this.#connsWithPersistChanged.clear(); } - markConnWithPersistChanged(conn: Conn) { + markConnWithPersistChanged(conn: Conn) { invariant( conn.isHibernatable, "cannot mark non-hibernatable conn for persist", @@ -100,7 +103,7 @@ export class ConnectionManager< requestHeaders: Record | undefined, isHibernatable: boolean, isRestoringHibernatable: boolean, - ): Promise> { + ): Promise> { this.#actor.assertReady(); // TODO: Add back @@ -169,7 +172,7 @@ export class ConnectionManager< } // Create connection instance - const conn = new Conn(this.#actor, connData); + const conn = new Conn(this.#actor, connData); conn[CONN_DRIVER_SYMBOL] = driver; return conn; @@ -183,7 +186,7 @@ export class ConnectionManager< * be messed up and cause race conditions that can drop WebSocket messages. * So all async work in prepareConn. */ - connectConn(conn: Conn) { + connectConn(conn: Conn) { invariant(!this.#connections.has(conn.id), "conn already connected"); this.#connections.set(conn.id, conn); @@ -236,7 +239,9 @@ export class ConnectionManager< } } - #reconnectHibernatableConn(driver: ConnDriver): Conn { + #reconnectHibernatableConn( + driver: ConnDriver, + ): Conn { invariant(driver.hibernatable, "missing requestIdBuf"); const existingConn = this.findHibernatableConn( driver.hibernatable.gatewayId, @@ -271,7 +276,7 @@ export class ConnectionManager< return existingConn; } - #disconnectExistingDriver(conn: Conn) { + #disconnectExistingDriver(conn: Conn) { const driver = conn[CONN_DRIVER_SYMBOL]; if (driver?.disconnect) { driver.disconnect( @@ -287,7 +292,7 @@ export class ConnectionManager< * * This is called by `Conn.disconnect`. This should not call `Conn.disconnect.` */ - async connDisconnected(conn: Conn) { + async connDisconnected(conn: Conn) { // Remove from tracking this.#connections.delete(conn.id); @@ -400,7 +405,7 @@ export class ConnectionManager< request: Request | undefined, requestPath: string | undefined, requestHeaders: Record | undefined, - ): Promise> { + ): Promise> { const conn = await this.prepareConn( driver, params, @@ -422,7 +427,7 @@ export class ConnectionManager< restoreConnections(connections: PersistedConn[]) { for (const connPersist of connections) { // Create connection instance - const conn = new Conn(this.#actor, { + const conn = new Conn(this.#actor, { hibernatable: connPersist, }); this.#connections.set(conn.id, conn); @@ -448,7 +453,7 @@ export class ConnectionManager< findHibernatableConn( gatewayIdBuf: ArrayBuffer, requestIdBuf: ArrayBuffer, - ): Conn | undefined { + ): Conn | undefined { return Array.from(this.#connections.values()).find((conn) => { const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; const h = connStateManager.hibernatableDataRaw; @@ -471,8 +476,7 @@ export class ConnectionManager< "actor.createConnState", undefined, () => { - const dataOrPromise = - createConnState!(ctx, params); + const dataOrPromise = createConnState!(ctx, params); if (dataOrPromise instanceof Promise) { return deadline( dataOrPromise, @@ -491,7 +495,7 @@ export class ConnectionManager< ); } - #callOnConnect(conn: Conn) { + #callOnConnect(conn: Conn) { const attributes = { "rivet.conn.id": conn.id, "rivet.conn.type": conn[CONN_DRIVER_SYMBOL]?.type, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index 079fd6eace..fd59658cfb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -18,17 +18,30 @@ import { import type { AnyDatabaseProvider } from "../database"; import * as errors from "../errors"; import { CachedSerializer } from "../protocol/serde"; +import type { SchemaConfig } from "../schema"; import type { ActorInstance } from "./mod"; /** * Manages event subscriptions and broadcasting for actor instances. * Handles subscription tracking and efficient message distribution to connected clients. */ -export class EventManager { - #actor: ActorInstance; - #subscriptionIndex = new Map>>(); +export class EventManager< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, +> { + #actor: ActorInstance; + #subscriptionIndex = new Map< + string, + Set> + >(); - constructor(actor: ActorInstance) { + constructor(actor: ActorInstance) { this.#actor = actor; } @@ -43,7 +56,7 @@ export class EventManager { */ addSubscription( eventName: string, - connection: Conn, + connection: Conn, fromPersist: boolean, ) { // Check if already subscribed @@ -94,7 +107,7 @@ export class EventManager { */ removeSubscription( eventName: string, - connection: Conn, + connection: Conn, fromRemoveConn: boolean, ) { // Check if subscription exists @@ -241,7 +254,7 @@ export class EventManager { */ getSubscribers( eventName: string, - ): Set> | undefined { + ): Set> | undefined { return this.#subscriptionIndex.get(eventName); } @@ -264,7 +277,7 @@ export class EventManager { * * @param connection - The connection to clear subscriptions for */ - clearConnectionSubscriptions(connection: Conn) { + clearConnectionSubscriptions(connection: Conn) { for (const eventName of [...connection.subscriptions.values()]) { this.removeSubscription(eventName, connection, true); } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index 83fe4e0da7..094762f27a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -1,11 +1,11 @@ -import invariant from "invariant"; +import type { OtlpExportTraceServiceRequestJson } from "@rivetkit/traces"; import { createTraces, type SpanHandle, type SpanStatusInput, type Traces, } from "@rivetkit/traces"; -import type { OtlpExportTraceServiceRequestJson } from "@rivetkit/traces"; +import invariant from "invariant"; import type { ActorKey } from "@/actor/mod"; import type { Client } from "@/client/client"; import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log"; @@ -18,7 +18,7 @@ import { CONN_VERSIONED, } from "@/schemas/actor-persist/versioned"; import { EXTRA_ERROR_LOG } from "@/utils"; -import { getRunFunction, type ActorConfig } from "../config"; +import { type ActorConfig, getRunFunction } from "../config"; import type { ConnDriver } from "../conn/driver"; import { createHttpDriver } from "../conn/drivers/http"; import { @@ -44,6 +44,7 @@ import * as errors from "../errors"; import { serializeActorKey } from "../keys"; import { processMessage } from "../protocol/old"; import { Schedule } from "../schedule"; +import type { SchemaConfig } from "../schema"; import { assertUnreachable, DeadlineError, @@ -53,7 +54,6 @@ import { import { ConnectionManager } from "./connection-manager"; import { EventManager } from "./event-manager"; import { KEYS } from "./keys"; -import { ActorTracesDriver } from "./traces-driver"; import { convertActorFromBarePersisted, type PersistedActor, @@ -61,6 +61,7 @@ import { import { QueueManager } from "./queue-manager"; import { ScheduleManager } from "./schedule-manager"; import { type SaveStateOptions, StateManager } from "./state-manager"; +import { ActorTracesDriver } from "./traces-driver"; export type { SaveStateOptions }; @@ -74,28 +75,46 @@ enum CanSleep { } /** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */ -export type AnyActorInstance = ActorInstance; +export type AnyActorInstance = ActorInstance< + any, + any, + any, + any, + any, + any, + any, + any +>; export type ExtractActorState = - A extends ActorInstance + A extends ActorInstance ? State : never; export type ExtractActorConnParams = - A extends ActorInstance + A extends ActorInstance ? ConnParams : never; export type ExtractActorConnState = - A extends ActorInstance + A extends ActorInstance ? ConnState : never; // MARK: - Main ActorInstance Class -export class ActorInstance { +export class ActorInstance< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, +> { // MARK: - Core Properties - actorContext: ActorContext; - #config: ActorConfig; + actorContext: ActorContext; + #config: ActorConfig; driver!: ActorDriver; #inlineClient!: Client>; #actorId!: string; @@ -105,15 +124,15 @@ export class ActorInstance { #region!: string; // MARK: - Managers - connectionManager!: ConnectionManager; + connectionManager!: ConnectionManager; - stateManager!: StateManager; + stateManager!: StateManager; - eventManager!: EventManager; + eventManager!: EventManager; - #scheduleManager!: ScheduleManager; + #scheduleManager!: ScheduleManager; - queueManager!: QueueManager; + queueManager!: QueueManager; // MARK: - Logging #log!: Logger; @@ -155,17 +174,15 @@ export class ActorInstance { // MARK: - Inspector #inspectorToken?: string; - #inspector!: ActorInspector; + #inspector = new ActorInspector(this); // MARK: - Tracing #traces!: Traces; // MARK: - Constructor - constructor(config: ActorConfig) { + constructor(config: ActorConfig) { this.#config = config; - this.#inspector = new ActorInspector(this); this.actorContext = new ActorContext(this); - this.#inspector = new ActorInspector(this); } // MARK: - Public Getters @@ -230,14 +247,8 @@ export class ActorInstance { }); } - endTraceSpan( - handle: SpanHandle, - status?: SpanStatusInput, - ): void { - this.#traces.endSpan( - handle, - status ? { status } : undefined, - ); + endTraceSpan(handle: SpanHandle, status?: SpanStatusInput): void { + this.#traces.endSpan(handle, status ? { status } : undefined); } async runInTraceSpan( @@ -248,8 +259,7 @@ export class ActorInstance { const span = this.startTraceSpan(name, attributes); try { const result = this.#traces.withSpan(span, fn); - const resolved = - result instanceof Promise ? await result : result; + const resolved = result instanceof Promise ? await result : result; this.#traces.endSpan(span, { status: { code: "OK" }, }); @@ -280,7 +290,7 @@ export class ActorInstance { }); } - get conns(): Map> { + get conns(): Map> { return this.connectionManager.connections; } @@ -296,7 +306,7 @@ export class ActorInstance { return Object.keys(this.#config.actions ?? {}); } - get config(): ActorConfig { + get config(): ActorConfig { return this.#config; } @@ -580,7 +590,7 @@ export class ActorInstance { val: { eventName: string; subscribe: boolean }; }; }, - conn: Conn, + conn: Conn, ) { await processMessage(message, this, conn, { onExecuteAction: async (ctx, name, args) => { @@ -597,7 +607,7 @@ export class ActorInstance { // MARK: - Action Execution async executeAction( - ctx: ActionContext, + ctx: ActionContext, actionName: string, args: unknown[], ): Promise { @@ -619,15 +629,9 @@ export class ActorInstance { throw new errors.ActionNotFound(actionName); } - this.#activeKeepAwakeCount++; - this.resetSleepTimer(); - - const actionSpan = this.startTraceSpan( - `actor.action.${actionName}`, - { - "rivet.action.name": actionName, - }, - ); + const actionSpan = this.startTraceSpan(`actor.action.${actionName}`, { + "rivet.action.name": actionName, + }); let spanEnded = false; try { @@ -645,12 +649,9 @@ export class ActorInstance { ); let output: unknown; - const maybeThenable = outputOrPromise as { - then?: (onfulfilled?: unknown, onrejected?: unknown) => unknown; - }; - if (maybeThenable && typeof maybeThenable.then === "function") { + if (outputOrPromise instanceof Promise) { output = await deadline( - Promise.resolve(outputOrPromise), + outputOrPromise, this.#config.options.actionTimeout, ); } else { @@ -710,22 +711,13 @@ export class ActorInstance { status: { code: "OK" }, }); } - this.#activeKeepAwakeCount--; - if (this.#activeKeepAwakeCount < 0) { - this.#activeKeepAwakeCount = 0; - this.#rLog.warn({ - msg: "active keep awake count went below 0, this is a RivetKit bug", - ...EXTRA_ERROR_LOG, - }); - } - this.resetSleepTimer(); this.stateManager.savePersistThrottled(); } } // MARK: - HTTP/WebSocket Handlers async handleRawRequest( - conn: Conn, + conn: Conn, request: Request, ): Promise { this.assertReady(); @@ -742,10 +734,10 @@ export class ActorInstance { "http.url": request.url, "rivet.conn.id": conn.id, }, - async () => { - try { - const ctx = new RequestContext(this, conn, request); - const response = await onRequest(ctx, request); + async () => { + try { + const ctx = new RequestContext(this, conn, request); + const response = await onRequest(ctx, request); if (!response) { throw new errors.InvalidRequestHandlerResponse(); } @@ -764,7 +756,7 @@ export class ActorInstance { } handleRawWebSocket( - conn: Conn, + conn: Conn, websocket: UniversalWebSocket, request?: Request, ) { @@ -950,7 +942,10 @@ export class ActorInstance { const [value] = args; if (typeof value === "string") { message = value; - } else if (typeof value === "number" || typeof value === "boolean") { + } else if ( + typeof value === "number" || + typeof value === "boolean" + ) { message = String(value); } else if (value && typeof value === "object") { const maybeMsg = (value as { msg?: unknown }).msg; @@ -1100,19 +1095,23 @@ export class ActorInstance { let vars: V | undefined; if ("createVars" in this.#config) { const createVars = this.#config.createVars; - vars = await this.runInTraceSpan("actor.createVars", undefined, () => { - const dataOrPromise = createVars!( - this.actorContext as any, - this.driver.getContext(this.#actorId), - ); - if (dataOrPromise instanceof Promise) { - return deadline( - dataOrPromise, - this.#config.options.createVarsTimeout, + vars = await this.runInTraceSpan( + "actor.createVars", + undefined, + () => { + const dataOrPromise = createVars!( + this.actorContext as any, + this.driver.getContext(this.#actorId), ); - } - return dataOrPromise; - }); + if (dataOrPromise instanceof Promise) { + return deadline( + dataOrPromise, + this.#config.options.createVarsTimeout, + ); + } + return dataOrPromise; + }, + ); } else if ("vars" in this.#config) { vars = structuredClone(this.#config.vars); } else { @@ -1138,15 +1137,19 @@ export class ActorInstance { const onSleep = this.#config.onSleep; try { this.#rLog.debug({ msg: "calling onSleep" }); - await this.runInTraceSpan("actor.onSleep", undefined, async () => { - const result = onSleep(this.actorContext); - if (result instanceof Promise) { - await deadline( - result, - this.#config.options.onSleepTimeout, - ); - } - }); + await this.runInTraceSpan( + "actor.onSleep", + undefined, + async () => { + const result = onSleep(this.actorContext); + if (result instanceof Promise) { + await deadline( + result, + this.#config.options.onSleepTimeout, + ); + } + }, + ); this.#rLog.debug({ msg: "onSleep completed" }); } catch (error) { if (error instanceof DeadlineError) { @@ -1162,34 +1165,23 @@ export class ActorInstance { } async #callOnDestroy() { - // Clean up database first - if ("db" in this.#config && this.#config.db && this.#db) { - try { - this.#rLog.debug({ msg: "cleaning up database" }); - await this.#config.db.onDestroy?.(this.#db); - this.#rLog.debug({ msg: "database cleanup completed" }); - } catch (error) { - this.#rLog.error({ - msg: "error cleaning up database", - error: stringifyError(error), - }); - } - } - - // Then call user's onDestroy if (this.#config.onDestroy) { const onDestroy = this.#config.onDestroy; try { this.#rLog.debug({ msg: "calling onDestroy" }); - await this.runInTraceSpan("actor.onDestroy", undefined, async () => { - const result = onDestroy(this.actorContext); - if (result instanceof Promise) { - await deadline( - result, - this.#config.options.onDestroyTimeout, - ); - } - }); + await this.runInTraceSpan( + "actor.onDestroy", + undefined, + async () => { + const result = onDestroy(this.actorContext); + if (result instanceof Promise) { + await deadline( + result, + this.#config.options.onDestroyTimeout, + ); + } + }, + ); this.#rLog.debug({ msg: "onDestroy completed" }); } catch (error) { if (error instanceof DeadlineError) { @@ -1284,48 +1276,13 @@ export class ActorInstance { async #setupDatabase() { if ("db" in this.#config && this.#config.db) { - try { - const client = await this.#config.db.createClient({ - actorId: this.#actorId, - overrideRawDatabaseClient: this.driver.overrideRawDatabaseClient - ? () => this.driver.overrideRawDatabaseClient!(this.#actorId) - : undefined, - overrideDrizzleDatabaseClient: this.driver - .overrideDrizzleDatabaseClient - ? () => this.driver.overrideDrizzleDatabaseClient!(this.#actorId) - : undefined, - kv: { - batchPut: (entries) => - this.driver.kvBatchPut(this.#actorId, entries), - batchGet: (keys) => this.driver.kvBatchGet(this.#actorId, keys), - batchDelete: (keys) => - this.driver.kvBatchDelete(this.#actorId, keys), - }, - sqliteVfs: this.driver.sqliteVfs, - }); - this.#rLog.info({ msg: "database migration starting" }); - await this.#config.db.onMigrate?.(client); - this.#rLog.info({ msg: "database migration complete" }); - this.#db = client; - } catch (error) { - // Ensure error is properly formatted - if (error instanceof Error) { - this.#rLog.error({ - msg: "database setup failed", - error: stringifyError(error), - }); - throw error; - } - const wrappedError = new Error( - `Database setup failed: ${String(error)}`, - ); - this.#rLog.error({ - msg: "database setup failed with non-Error object", - error: String(error), - errorType: typeof error, - }); - throw wrappedError; - } + const client = await this.#config.db.createClient({ + getDatabase: () => this.driver.getDatabase(this.#actorId), + }); + this.#rLog.info({ msg: "database migration starting" }); + await this.#config.db.onMigrate?.(client); + this.#rLog.info({ msg: "database migration complete" }); + this.#db = client; } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts index 73847f8115..8ae360d1ec 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts @@ -6,10 +6,11 @@ import { QUEUE_METADATA_VERSIONED, } from "@/schemas/actor-persist/versioned"; import { promiseWithResolvers } from "@/utils"; -import { loggerWithoutContext } from "@/actor/log"; import type { AnyDatabaseProvider } from "../database"; import type { ActorDriver } from "../driver"; import * as errors from "../errors"; +import type { SchemaConfig } from "../schema"; +import { validateSchema } from "../schema"; import { decodeQueueMessageKey, KEYS, makeQueueMessageKey } from "./keys"; import type { ActorInstance } from "./mod"; @@ -18,15 +19,6 @@ export interface QueueMessage { name: string; body: unknown; createdAt: number; - failureCount: number; - availableAt: number; - inFlight: boolean; - inFlightAt?: number; -} - -export interface QueueCompletionResult { - status: "completed" | "timedOut"; - response?: unknown; } interface QueueMetadata { @@ -34,38 +26,16 @@ interface QueueMetadata { size: number; } -interface EnqueueOptions { - deferWaiters?: boolean; -} - interface QueueWaiter { id: string; nameSet: Set; count: number; - wait: boolean; resolve: (messages: QueueMessage[]) => void; reject: (error: Error) => void; signal?: AbortSignal; timeoutHandle?: ReturnType; } -interface QueueNameWaiter { - id: string; - nameSet: Set; - resolve: () => void; - reject: (error: Error) => void; - signal?: AbortSignal; - abortHandler?: () => void; -} - -interface QueueCompletionWaiter { - id: string; - messageId: bigint; - resolve: (result: QueueCompletionResult) => void; - reject: (error: Error) => void; - timeoutHandle?: ReturnType; -} - interface MessageListener { nameSet: Set; resolve: () => void; @@ -80,25 +50,24 @@ const DEFAULT_METADATA: QueueMetadata = { size: 0, }; -const PENDING_WARNING_MS = 30_000; -const BACKOFF_INITIAL_MS = 1_000; -const BACKOFF_MAX_MS = 5 * 60_000; - -export class QueueManager { - #actor: ActorInstance; +export class QueueManager< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, +> { + #actor: ActorInstance; #driver: ActorDriver; #waiters = new Map(); - #nameWaiters = new Map(); - #completionWaiters = new Map(); #metadata: QueueMetadata = { ...DEFAULT_METADATA }; - #pendingMessageId: bigint | undefined; - #pendingWarningHandle: ReturnType | undefined; - #redeliveryTimeout: ReturnType | undefined; - #redeliveryAt: number | undefined; #messageListeners = new Set(); constructor( - actor: ActorInstance, + actor: ActorInstance, driver: ActorDriver, ) { this.#actor = actor; @@ -137,18 +106,22 @@ export class QueueManager { await this.#rebuildMetadata(); } this.#actor.inspector.updateQueueSize(this.#metadata.size); - - await this.#recoverInFlightMessages(); } /** Adds a message to the queue with the given name and body. */ - async enqueue( - name: string, - body: unknown, - options: EnqueueOptions = {}, - ): Promise { + async enqueue(name: string, body: unknown): Promise { this.#actor.assertReady(); + const validation = await validateSchema( + this.#actor.config.queues, + name as keyof Q & string, + body, + ); + if (!validation.success) { + throw new errors.QueuePayloadInvalid(name, validation.issues); + } + const validatedBody = validation.data; + const sizeLimit = this.#actor.config.options.maxQueueSize; if (this.#metadata.size >= sizeLimit) { throw new errors.QueueFull(sizeLimit); @@ -156,7 +129,7 @@ export class QueueManager { let invalidPath = ""; if ( - !isCborSerializable(body, (path) => { + !isCborSerializable(validatedBody, (path) => { invalidPath = path; }) ) { @@ -164,18 +137,13 @@ export class QueueManager { } const createdAt = Date.now(); - const bodyCborBuffer = cbor.encode(body); - const availableAt = createdAt; + const bodyCborBuffer = cbor.encode(validatedBody); const encodedMessage = QUEUE_MESSAGE_VERSIONED.serializeWithEmbeddedVersion( { name, body: new Uint8Array(bodyCborBuffer).buffer as ArrayBuffer, createdAt: BigInt(createdAt), - failureCount: 0, - availableAt: BigInt(availableAt), - inFlight: false, - inFlightAt: null, }, ACTOR_PERSIST_CURRENT_VERSION, ); @@ -206,68 +174,40 @@ export class QueueManager { const message: QueueMessage = { id, name, - body, + body: validatedBody, createdAt, - failureCount: 0, - availableAt, - inFlight: false, - inFlightAt: undefined, }; this.#actor.resetSleepTimer(); - if (!options.deferWaiters) { - await this.#maybeResolveWaiters(); - } + await this.#maybeResolveWaiters(); this.#notifyMessageListeners(name); return message; } - async enqueueAndWait( - name: string, - body: unknown, - timeout?: number, - ): Promise { - const message = await this.enqueue(name, body, { - deferWaiters: true, - }); - const completionPromise = this.waitForCompletion(message.id, timeout); - await this.#maybeResolveWaiters(); - return await completionPromise; - } - /** Receives messages from the queue matching the given names. Waits until messages are available or timeout is reached. */ async receive( names: string[], count: number, timeout?: number, abortSignal?: AbortSignal, - wait: boolean = false, ): Promise { this.#actor.assertReady(); - if (this.#pendingMessageId !== undefined) { - throw new errors.QueueMessagePending(); - } const limitedCount = Math.max(1, count); const nameSet = new Set(names); - const immediate = await this.#drainMessages( - nameSet, - limitedCount, - wait, - ); + const immediate = await this.#drainMessages(nameSet, limitedCount); if (immediate.length > 0 || timeout === 0) { return timeout === 0 && immediate.length === 0 ? [] : immediate; } const { promise, resolve, reject } = - promiseWithResolvers((reason) => loggerWithoutContext().warn({ msg: "unhandled queue message waiter rejection", reason })); + promiseWithResolvers(); const waiterId = crypto.randomUUID(); const waiter: QueueWaiter = { id: waiterId, nameSet, count: limitedCount, - wait, resolve, reject, signal: abortSignal, @@ -313,54 +253,6 @@ export class QueueManager { return promise; } - /** Waits for a specific queue message to complete. */ - async waitForCompletion( - messageId: bigint, - timeout?: number, - ): Promise { - const { promise, resolve, reject } = - promiseWithResolvers((reason) => loggerWithoutContext().warn({ msg: "unhandled queue completion waiter rejection", reason })); - const waiterId = crypto.randomUUID(); - - const waiter: QueueCompletionWaiter = { - id: waiterId, - messageId, - resolve, - reject, - }; - - if (timeout !== undefined) { - waiter.timeoutHandle = setTimeout(() => { - this.#completionWaiters.delete(messageId); - resolve({ status: "timedOut" }); - }, timeout); - } - - this.#completionWaiters.set(messageId, waiter); - return promise; - } - - /** Completes a pending message and optionally responds to any waiter. */ - async complete(message: QueueMessage, response?: unknown): Promise { - if (this.#pendingMessageId !== message.id) { - throw new errors.QueueAlreadyCompleted(); - } - this.#pendingMessageId = undefined; - if (this.#pendingWarningHandle) { - clearTimeout(this.#pendingWarningHandle); - this.#pendingWarningHandle = undefined; - } - - await this.#removeMessages([message], { resolveWaiters: false }); - this.#resolveCompletionWaiter(message.id, { - status: "completed", - response, - }); - - await this.#maybeResolveWaiters(); - } - - /** Waits for messages with any of the specified names to appear in the queue. */ async waitForNames( names: string[], abortSignal?: AbortSignal, @@ -419,10 +311,7 @@ export class QueueManager { } /** Deletes messages matching the provided IDs. Returns the IDs that were removed. */ - async deleteMessagesById( - ids: bigint[], - options: { resolveWaiters?: boolean } = {}, - ): Promise { + async deleteMessagesById(ids: bigint[]): Promise { if (ids.length === 0) { return []; } @@ -434,52 +323,26 @@ export class QueueManager { if (toRemove.length === 0) { return []; } - await this.#removeMessages(toRemove, { - resolveWaiters: options.resolveWaiters ?? true, - }); + await this.#removeMessages(toRemove); return toRemove.map((entry) => entry.id); } - /** Completes a previously removed message by resolving its waiter, if one exists. */ - async completeById(messageId: bigint, response?: unknown): Promise { - this.#resolveCompletionWaiter(messageId, { - status: "completed", - response, - }); - } - async #drainMessages( nameSet: Set, count: number, - wait: boolean, ): Promise { if (this.#metadata.size === 0) { return []; } - const now = Date.now(); const entries = await this.#loadQueueMessages(); - const matched = entries.filter( - (entry) => nameSet.has(entry.name) && !entry.inFlight, - ); + const matched = entries.filter((entry) => nameSet.has(entry.name)); if (matched.length === 0) { return []; } - const eligible = matched.filter((entry) => entry.availableAt <= now); - if (eligible.length === 0) { - this.#scheduleRedelivery(matched); - return []; - } - - const selected = eligible.slice(0, wait ? 1 : count); - if (wait) { - await this.#markMessageInFlight(selected[0], now); - return [selected[0]]; - } - - await this.#removeMessages(selected, { resolveWaiters: true }); - - // Emit trace events for received messages + const selected = matched.slice(0, count); + await this.#removeMessages(selected); + const now = Date.now(); for (const message of selected) { this.#actor.emitTraceEvent("queue.message.receive", { "rivet.queue.name": message.name, @@ -488,7 +351,6 @@ export class QueueManager { "rivet.queue.latency_ms": now - message.createdAt, }); } - return selected; } @@ -506,31 +368,11 @@ export class QueueManager { value, ); const body = cbor.decode(new Uint8Array(decodedPayload.body)); - const failureCount = - decodedPayload.failureCount !== undefined && - decodedPayload.failureCount !== null - ? Number(decodedPayload.failureCount) - : 0; - const availableAt = - decodedPayload.availableAt !== undefined && - decodedPayload.availableAt !== null - ? Number(decodedPayload.availableAt) - : Number(decodedPayload.createdAt); - const inFlight = decodedPayload.inFlight ?? false; - const inFlightAt = - decodedPayload.inFlightAt !== undefined && - decodedPayload.inFlightAt !== null - ? Number(decodedPayload.inFlightAt) - : undefined; decoded.push({ id: messageId, name: decodedPayload.name, body, createdAt: Number(decodedPayload.createdAt), - failureCount, - availableAt, - inFlight, - inFlightAt, }); } catch (error) { this.#actor.rLog.error({ @@ -567,10 +409,7 @@ export class QueueManager { } } - async #removeMessages( - messages: QueueMessage[], - options: { resolveWaiters: boolean }, - ): Promise { + async #removeMessages(messages: QueueMessage[]): Promise { if (messages.length === 0) { return; } @@ -590,65 +429,10 @@ export class QueueManager { ]); this.#actor.inspector.updateQueueSize(this.#metadata.size); - - if (options.resolveWaiters) { - for (const message of messages) { - this.#resolveCompletionWaiter(message.id, { - status: "completed", - response: undefined, - }); - } - } } async #maybeResolveWaiters() { - if (this.#pendingMessageId !== undefined) { - return; - } - if (this.#redeliveryTimeout) { - clearTimeout(this.#redeliveryTimeout); - this.#redeliveryTimeout = undefined; - this.#redeliveryAt = undefined; - } - const hasReceiveWaiters = this.#waiters.size > 0; - const hasNameWaiters = this.#nameWaiters.size > 0; - if (!hasReceiveWaiters && !hasNameWaiters) { - return; - } - - if (hasNameWaiters) { - const entries = await this.#loadQueueMessages(); - const now = Date.now(); - const nameWaiters = [...this.#nameWaiters.values()]; - for (const waiter of nameWaiters) { - if (waiter.signal?.aborted) { - this.#nameWaiters.delete(waiter.id); - waiter.reject(new errors.ActorAborted()); - continue; - } - - const hasMatch = entries.some( - (message) => - waiter.nameSet.has(message.name) && - !message.inFlight && - message.availableAt <= now, - ); - if (!hasMatch) { - continue; - } - - this.#nameWaiters.delete(waiter.id); - if (waiter.abortHandler) { - waiter.signal?.removeEventListener( - "abort", - waiter.abortHandler, - ); - } - waiter.resolve(); - } - } - - if (!hasReceiveWaiters) { + if (this.#waiters.size === 0) { return; } const pending = [...this.#waiters.values()]; @@ -662,7 +446,6 @@ export class QueueManager { const messages = await this.#drainMessages( waiter.nameSet, waiter.count, - waiter.wait, ); if (messages.length === 0) { continue; @@ -672,9 +455,6 @@ export class QueueManager { clearTimeout(waiter.timeoutHandle); } waiter.resolve(messages); - if (waiter.wait) { - break; - } } } @@ -706,167 +486,6 @@ export class QueueManager { this.#actor.inspector.updateQueueSize(this.#metadata.size); } - async #markMessageInFlight( - message: QueueMessage, - now: number, - ): Promise { - if (message.inFlight) { - throw new errors.QueueMessagePending(); - } - - message.inFlight = true; - message.inFlightAt = now; - - await this.#persistMessage(message); - - this.#pendingMessageId = message.id; - this.#pendingWarningHandle = setTimeout(() => { - if (this.#pendingMessageId === message.id) { - this.#actor.rLog.warn({ - msg: "queue message pending for over 30s", - messageId: message.id.toString(), - name: message.name, - }); - } - }, PENDING_WARNING_MS); - } - - async #persistMessage(message: QueueMessage): Promise { - const bodyCborBuffer = cbor.encode(message.body); - const encodedMessage = - QUEUE_MESSAGE_VERSIONED.serializeWithEmbeddedVersion( - { - name: message.name, - body: new Uint8Array(bodyCborBuffer).buffer as ArrayBuffer, - createdAt: BigInt(message.createdAt), - failureCount: message.failureCount, - availableAt: BigInt(message.availableAt), - inFlight: message.inFlight, - inFlightAt: - message.inFlightAt !== undefined - ? BigInt(message.inFlightAt) - : null, - }, - ACTOR_PERSIST_CURRENT_VERSION, - ); - - await this.#driver.kvBatchPut(this.#actor.id, [ - [makeQueueMessageKey(message.id), encodedMessage], - ]); - } - - async #recoverInFlightMessages(): Promise { - const entries = await this.#driver.kvListPrefix( - this.#actor.id, - KEYS.QUEUE_PREFIX, - ); - - const updates: [Uint8Array, Uint8Array][] = []; - const now = Date.now(); - - for (const [key, value] of entries) { - try { - const messageId = decodeQueueMessageKey(key); - const decodedPayload = - QUEUE_MESSAGE_VERSIONED.deserializeWithEmbeddedVersion( - value, - ); - const inFlight = decodedPayload.inFlight ?? false; - if (!inFlight) { - continue; - } - - const failureCount = - (decodedPayload.failureCount !== undefined && - decodedPayload.failureCount !== null - ? Number(decodedPayload.failureCount) - : 0) + 1; - const availableAt = now + this.#computeBackoffMs(failureCount); - - const updatedMessage = - QUEUE_MESSAGE_VERSIONED.serializeWithEmbeddedVersion( - { - name: decodedPayload.name, - body: decodedPayload.body, - createdAt: decodedPayload.createdAt, - failureCount, - availableAt: BigInt(availableAt), - inFlight: false, - inFlightAt: null, - }, - ACTOR_PERSIST_CURRENT_VERSION, - ); - - updates.push([key, updatedMessage]); - - this.#actor.rLog.warn({ - msg: "recovering in-flight queue message", - messageId: messageId.toString(), - failureCount, - availableAt, - }); - } catch (error) { - this.#actor.rLog.error({ - msg: "failed to recover in-flight queue message", - error, - }); - } - } - - if (updates.length > 0) { - await this.#driver.kvBatchPut(this.#actor.id, updates); - } - } - - #scheduleRedelivery(messages: QueueMessage[]): void { - if (messages.length === 0) { - return; - } - const nextAvailableAt = messages.reduce((min, message) => { - return message.availableAt < min ? message.availableAt : min; - }, messages[0].availableAt); - - if ( - this.#redeliveryAt !== undefined && - this.#redeliveryAt <= nextAvailableAt - ) { - return; - } - - if (this.#redeliveryTimeout) { - clearTimeout(this.#redeliveryTimeout); - } - - const delay = Math.max(0, nextAvailableAt - Date.now()); - this.#redeliveryAt = nextAvailableAt; - this.#redeliveryTimeout = setTimeout(() => { - this.#redeliveryTimeout = undefined; - this.#redeliveryAt = undefined; - void this.#maybeResolveWaiters(); - }, delay); - } - - #resolveCompletionWaiter( - messageId: bigint, - result: QueueCompletionResult, - ): void { - const waiter = this.#completionWaiters.get(messageId); - if (!waiter) { - return; - } - this.#completionWaiters.delete(messageId); - if (waiter.timeoutHandle) { - clearTimeout(waiter.timeoutHandle); - } - waiter.resolve(result); - } - - #computeBackoffMs(failureCount: number): number { - const exp = Math.max(0, failureCount - 1); - const delay = Math.min(BACKOFF_MAX_MS, BACKOFF_INITIAL_MS * 2 ** exp); - return delay; - } - #serializeMetadata(): Uint8Array { return QUEUE_METADATA_VERSIONED.serializeWithEmbeddedVersion( { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue.ts index 638d18846c..bf320b6a8e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue.ts @@ -1,6 +1,6 @@ import type { AnyDatabaseProvider } from "../database"; -import * as errors from "../errors"; -import type { QueueManager, QueueMessage as QueueMessageRecord } from "./queue-manager"; +import type { InferSchemaMap, SchemaConfig } from "../schema"; +import type { QueueManager, QueueMessage } from "./queue-manager"; /** Options for receiving messages from the queue. */ export interface QueueReceiveOptions { @@ -8,8 +8,6 @@ export interface QueueReceiveOptions { count?: number; /** Timeout in milliseconds to wait for messages. Waits indefinitely if not specified. */ timeout?: number; - /** When true, message must be manually completed. */ - wait?: boolean; } /** Request object for receiving messages from the queue. */ @@ -18,13 +16,26 @@ export interface QueueReceiveRequest extends QueueReceiveOptions { name: string | string[]; } +export type QueueMessageOf = Omit & { + body: Body; +}; + /** User-facing queue interface exposed on ActorContext. */ -export class ActorQueue { - #queueManager: QueueManager; +export class ActorQueue< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, +> { + #queueManager: QueueManager; #abortSignal: AbortSignal; constructor( - queueManager: QueueManager, + queueManager: QueueManager, abortSignal: AbortSignal, ) { this.#queueManager = queueManager; @@ -32,17 +43,35 @@ export class ActorQueue { } /** Receives the next message from a single queue. Returns undefined if no message available. */ + next( + name: K, + opts?: QueueReceiveOptions, + ): Promise[K]> | undefined>; next( - name: string, + name: keyof TQueues extends never ? string : never, opts?: QueueReceiveOptions, ): Promise; /** Receives messages from multiple queues. Returns messages matching any of the queue names. */ + next( + name: K[], + opts?: QueueReceiveOptions, + ): Promise[K]>> | undefined>; next( - name: string[], + name: keyof TQueues extends never ? string[] : never, opts?: QueueReceiveOptions, ): Promise; /** Receives messages using a request object for full control over options. */ - next(request: QueueReceiveRequest): Promise; + next( + request: QueueReceiveRequest & { name: K }, + ): Promise[K]> | undefined>; + next( + request: QueueReceiveRequest & { name: K[] }, + ): Promise[K]>> | undefined>; + next( + request: QueueReceiveRequest & { + name: keyof TQueues extends never ? string | string[] : never; + }, + ): Promise; async next( nameOrRequest: string | string[] | QueueReceiveRequest, opts: QueueReceiveOptions = {}, @@ -62,51 +91,29 @@ export class ActorQueue { count, mergedOptions.timeout, this.#abortSignal, - mergedOptions.wait ?? false, ); if (Array.isArray(request.name)) { - return messages?.map((message) => - this.#toQueueMessage(message, mergedOptions.wait ?? false), - ); + return messages; } if (!messages || messages.length === 0) { return undefined; } - return this.#toQueueMessage(messages[0], mergedOptions.wait ?? false); - } - - #toQueueMessage( - message: QueueMessageRecord, - wait: boolean, - ): QueueMessage { - const base: QueueMessage = { - id: message.id.toString(), - name: message.name, - body: message.body, - complete: async (data?: unknown) => { - if (!wait) { - throw new errors.QueueCompleteNotAllowed(); - } - await this.#queueManager.complete(message, data); - }, - }; - - return base; + return messages[0]; } /** Sends a message to the specified queue. */ + send( + name: K, + body: InferSchemaMap[K], + ): Promise; + send( + name: keyof TQueues extends never ? string : never, + body: unknown, + ): Promise; async send(name: string, body: unknown): Promise { - const message = await this.#queueManager.enqueue(name, body); - return this.#toQueueMessage(message, false); + return await this.#queueManager.enqueue(name, body); } } - -export interface QueueMessage { - name: string; - body: T; - id: string; - complete(data?: unknown): Promise; -} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts index 638827fcac..ada9e089e1 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts @@ -6,6 +6,7 @@ import { } from "@/utils"; import type { AnyDatabaseProvider } from "../database"; import type { ActorDriver } from "../driver"; +import type { SchemaConfig } from "../schema"; import type { ActorInstance } from "./mod"; import type { PersistedScheduleEvent } from "./persisted"; @@ -13,15 +14,24 @@ import type { PersistedScheduleEvent } from "./persisted"; * Manages scheduled events and alarms for actor instances. * Handles event scheduling, alarm triggers, and automatic event execution. */ -export class ScheduleManager { - #actor: ActorInstance; +export class ScheduleManager< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, +> { + #actor: ActorInstance; #actorDriver: ActorDriver; #alarmWriteQueue = new SinglePromiseQueue(); #config: any; // ActorConfig type #persist: any; // Reference to PersistedActor constructor( - actor: ActorInstance, + actor: ActorInstance, actorDriver: ActorDriver, config: any, ) { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts index 38bd6fb2b9..245887c8af 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts @@ -12,6 +12,7 @@ import { type AnyConn, CONN_STATE_MANAGER_SYMBOL } from "../conn/mod"; import { convertConnToBarePersistedConn } from "../conn/persisted"; import type { ActorDriver } from "../driver"; import * as errors from "../errors"; +import type { SchemaConfig } from "../schema"; import { isConnStatePath, isStatePath } from "../utils"; import { KEYS, makeConnKey } from "./keys"; import type { ActorInstance } from "./mod"; @@ -36,8 +37,15 @@ export interface SaveStateOptions { * Manages actor state persistence, proxying, and synchronization. * Handles automatic state change detection and throttled persistence to KV storage. */ -export class StateManager { - #actor: ActorInstance; +export class StateManager< + S, + CP, + CS, + I, + E extends SchemaConfig = Record, + Q extends SchemaConfig = Record, +> { + #actor: ActorInstance; #actorDriver: ActorDriver; // State tracking @@ -58,7 +66,7 @@ export class StateManager { #stateSaveInterval: number; constructor( - actor: ActorInstance, + actor: ActorInstance, actorDriver: ActorDriver, config: any, ) { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts index d9906fbc9a..39550f669a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts @@ -7,6 +7,7 @@ import { } from "./config"; import type { AnyDatabaseProvider } from "./database"; import { ActorDefinition } from "./definition"; +import type { SchemaConfig } from "./schema"; export function actor< TState, @@ -15,13 +16,26 @@ export function actor< TVars, TInput, TDatabase extends AnyDatabaseProvider, + TEvents extends SchemaConfig = Record, + TQueues extends SchemaConfig = Record, TActions extends Actions< TState, TConnParams, TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues + > = Actions< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues >, >( input: ActorConfigInput< @@ -31,6 +45,8 @@ export function actor< TVars, TInput, TDatabase, + TEvents, + TQueues, TActions >, ): ActorDefinition< @@ -40,6 +56,8 @@ export function actor< TVars, TInput, TDatabase, + TEvents, + TQueues, TActions > { const config = ActorConfigSchema.parse(input) as ActorConfig< @@ -48,7 +66,9 @@ export function actor< TConnState, TVars, TInput, - TDatabase + TDatabase, + TEvents, + TQueues >; return new ActorDefinition(config); } @@ -84,3 +104,4 @@ export { createActorRouter, } from "./router"; export { routeWebSocket } from "./router-websocket-endpoints"; +export { type Raw, raw } from "./schema"; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts index 5a6654f684..cd0cf8556d 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts @@ -29,6 +29,7 @@ import { import { CONN_SEND_MESSAGE_SYMBOL, type Conn } from "../conn/mod"; import { ActionContext } from "../contexts"; import type { ActorInstance } from "../instance/mod"; +import type { SchemaConfig } from "../schema"; interface MessageEventOpts { encoding: Encoding; @@ -139,19 +140,21 @@ export interface ProcessMessageHandler< V, I, DB extends AnyDatabaseProvider, + E extends SchemaConfig, + Q extends SchemaConfig, > { onExecuteAction?: ( - ctx: ActionContext, + ctx: ActionContext, name: string, args: unknown[], ) => Promise; onSubscribe?: ( eventName: string, - conn: Conn, + conn: Conn, ) => Promise; onUnsubscribe?: ( eventName: string, - conn: Conn, + conn: Conn, ) => Promise; } @@ -162,6 +165,8 @@ export async function processMessage< V, I, DB extends AnyDatabaseProvider, + E extends SchemaConfig, + Q extends SchemaConfig, >( message: { body: @@ -174,9 +179,9 @@ export async function processMessage< val: { eventName: string; subscribe: boolean }; }; }, - actor: ActorInstance, - conn: Conn, - handler: ProcessMessageHandler, + actor: ActorInstance, + conn: Conn, + handler: ProcessMessageHandler, ) { let actionId: bigint | undefined; let actionName: string | undefined; @@ -199,7 +204,10 @@ export async function processMessage< actionName: name, }); - const ctx = new ActionContext(actor, conn); + const ctx = new ActionContext( + actor, + conn, + ); // Process the action request and wait for the result // This will wait for async actions to complete diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/schema.ts b/rivetkit-typescript/packages/rivetkit/src/actor/schema.ts new file mode 100644 index 0000000000..a54ed2438e --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/schema.ts @@ -0,0 +1,99 @@ +import type { StandardSchemaV1 } from "@standard-schema/spec"; +import { Unsupported } from "./errors"; + +export const RAW_MARKER = Symbol.for("rivetkit.raw"); + +export type Raw = { + [RAW_MARKER]: true; + _type: T; +}; + +export function raw(): Raw { + return { [RAW_MARKER]: true } as Raw; +} + +export type Schema = StandardSchemaV1 | Raw; + +export type SchemaConfig = Record; + +export type InferSchema = + T extends StandardSchemaV1 + ? O + : T extends Raw + ? R + : never; + +export type InferSchemaMap = { + [K in keyof T]: InferSchema; +}; + +export type InferEventArgs = T extends readonly unknown[] + ? number extends T["length"] + ? [T] + : T + : [T]; + +export type ValidationResult = + | { success: true; data: T } + | { success: false; issues: unknown[] }; + +export function isStandardSchema(value: unknown): value is StandardSchemaV1 { + return typeof value === "object" && value !== null && "~standard" in value; +} + +export function isRaw(value: unknown): value is Raw { + return typeof value === "object" && value !== null && RAW_MARKER in value; +} + +export async function validateSchema( + schemas: T | undefined, + key: keyof T & string, + data: unknown, +): Promise[typeof key]>> { + const schema = schemas?.[key]; + + if (!schema || isRaw(schema)) { + return { success: true, data: data as InferSchemaMap[typeof key] }; + } + + if (isStandardSchema(schema)) { + const result = await schema["~standard"].validate(data); + if ("issues" in result) { + return { success: false, issues: result.issues }; + } + return { + success: true, + data: result.value as InferSchemaMap[typeof key], + }; + } + + return { success: true, data: data as InferSchemaMap[typeof key] }; +} + +export function validateSchemaSync( + schemas: T | undefined, + key: keyof T & string, + data: unknown, +): ValidationResult[typeof key]> { + const schema = schemas?.[key]; + + if (!schema || isRaw(schema)) { + return { success: true, data: data as InferSchemaMap[typeof key] }; + } + + if (isStandardSchema(schema)) { + const result = schema["~standard"].validate(data); + if (result && typeof (result as Promise).then === "function") { + throw new Unsupported("async schema validation"); + } + if ("issues" in result) { + return { success: false, issues: result.issues }; + } + return { + success: true, + data: result.value as InferSchemaMap[typeof key], + }; + } + + return { success: true, data: data as InferSchemaMap[typeof key] }; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-common.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-common.ts index 64284a7c0f..3a71d4cab2 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-common.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-common.ts @@ -21,7 +21,7 @@ export type ActorActionFunction< */ export type ActorDefinitionActions = // biome-ignore lint/suspicious/noExplicitAny: safe to use any here - AD extends ActorDefinition + AD extends ActorDefinition ? { [K in keyof R]: R[K] extends ( ...args: infer Args diff --git a/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts b/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts index e3ff8c3b08..a4c7e71ac6 100644 --- a/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts +++ b/rivetkit-typescript/packages/rivetkit/src/registry/config/index.ts @@ -1,9 +1,10 @@ -import { z } from "zod/v4"; -import type { ActorDefinition, AnyActorDefinition } from "@/actor/definition"; +import { z } from "zod"; import { getRunMetadata } from "@/actor/config"; +import type { ActorDefinition, AnyActorDefinition } from "@/actor/definition"; import { type Logger, LogLevelSchema } from "@/common/log"; import { ENGINE_ENDPOINT } from "@/engine-process/constants"; import { InspectorConfigSchema } from "@/inspector/config"; +import { DeepReadonly } from "@/utils"; import { tryParseEndpoint } from "@/utils/endpoint-parser"; import { getRivetEndpoint, @@ -16,13 +17,12 @@ import { import { type DriverConfig, DriverConfigSchema } from "./driver"; import { RunnerConfigSchema } from "./runner"; import { ServerlessConfigSchema } from "./serverless"; -import { DeepReadonly } from "@/utils"; export { DriverConfigSchema, type DriverConfig }; export const ActorsSchema = z.record( z.string(), - z.custom>(), + z.custom>(), ); export type RegistryActors = z.infer; @@ -302,9 +302,22 @@ export function buildActorNames( export const DocInspectorConfigSchema = z .object({ - enabled: z.boolean().optional().describe("Whether to enable the Rivet Inspector. Defaults to true in development mode."), - token: z.string().optional().describe("Token used to access the Inspector."), - defaultEndpoint: z.string().optional().describe("Default RivetKit server endpoint for Rivet Inspector to connect to."), + enabled: z + .boolean() + .optional() + .describe( + "Whether to enable the Rivet Inspector. Defaults to true in development mode.", + ), + token: z + .string() + .optional() + .describe("Token used to access the Inspector."), + defaultEndpoint: z + .string() + .optional() + .describe( + "Default RivetKit server endpoint for Rivet Inspector to connect to.", + ), }) .optional() .describe("Inspector configuration for debugging and development."); @@ -312,54 +325,184 @@ export const DocInspectorConfigSchema = z export const DocConfigureRunnerPoolSchema = z .object({ name: z.string().optional().describe("Name of the runner pool."), - url: z.string().describe("URL of the serverless platform to configure runners."), - headers: z.record(z.string(), z.string()).optional().describe("Headers to include in requests to the serverless platform."), - maxRunners: z.number().optional().describe("Maximum number of runners in the pool."), - minRunners: z.number().optional().describe("Minimum number of runners to keep warm."), - requestLifespan: z.number().optional().describe("Maximum lifespan of a request in milliseconds."), - runnersMargin: z.number().optional().describe("Buffer margin for scaling runners."), - slotsPerRunner: z.number().optional().describe("Number of actor slots per runner."), - metadata: z.record(z.string(), z.unknown()).optional().describe("Additional metadata to pass to the serverless platform."), - metadataPollInterval: z.number().optional().describe("Interval in milliseconds between metadata polls from the engine. Defaults to 10000 milliseconds (10 seconds)."), + url: z + .string() + .describe("URL of the serverless platform to configure runners."), + headers: z + .record(z.string(), z.string()) + .optional() + .describe( + "Headers to include in requests to the serverless platform.", + ), + maxRunners: z + .number() + .optional() + .describe("Maximum number of runners in the pool."), + minRunners: z + .number() + .optional() + .describe("Minimum number of runners to keep warm."), + requestLifespan: z + .number() + .optional() + .describe("Maximum lifespan of a request in milliseconds."), + runnersMargin: z + .number() + .optional() + .describe("Buffer margin for scaling runners."), + slotsPerRunner: z + .number() + .optional() + .describe("Number of actor slots per runner."), + metadata: z + .record(z.string(), z.unknown()) + .optional() + .describe( + "Additional metadata to pass to the serverless platform.", + ), + metadataPollInterval: z + .number() + .optional() + .describe( + "Interval in milliseconds between metadata polls from the engine. Defaults to 10000 milliseconds (10 seconds).", + ), }) .optional(); -export const DocServerlessConfigSchema = z.object({ - spawnEngine: z.boolean().optional().describe("Downloads and starts the full Rust engine process. Auto-enabled in development mode when no endpoint is provided. Default: false"), - engineVersion: z.string().optional().describe("Version of the engine to download. Defaults to the current RivetKit version."), - configureRunnerPool: DocConfigureRunnerPoolSchema.describe("Automatically configure serverless runners in the engine."), - basePath: z.string().optional().describe("Base path for serverless API routes. Default: '/api/rivet'"), - publicEndpoint: z.string().optional().describe("The endpoint that clients should connect to. Supports URL auth syntax: https://namespace:token@api.rivet.dev"), - publicToken: z.string().optional().describe("Token that clients should use when connecting via the public endpoint."), -}).describe("Configuration for serverless deployment mode."); - -export const DocRunnerConfigSchema = z.object({ - totalSlots: z.number().optional().describe("Total number of actor slots available. Default: 100000"), - runnerName: z.string().optional().describe("Name of this runner. Default: 'default'"), - runnerKey: z.string().optional().describe("Authentication key for the runner."), - version: z.number().optional().describe("Version number of this runner. Default: 1"), -}).describe("Configuration for runner mode."); +export const DocServerlessConfigSchema = z + .object({ + spawnEngine: z + .boolean() + .optional() + .describe( + "Downloads and starts the full Rust engine process. Auto-enabled in development mode when no endpoint is provided. Default: false", + ), + engineVersion: z + .string() + .optional() + .describe( + "Version of the engine to download. Defaults to the current RivetKit version.", + ), + configureRunnerPool: DocConfigureRunnerPoolSchema.describe( + "Automatically configure serverless runners in the engine.", + ), + basePath: z + .string() + .optional() + .describe( + "Base path for serverless API routes. Default: '/api/rivet'", + ), + publicEndpoint: z + .string() + .optional() + .describe( + "The endpoint that clients should connect to. Supports URL auth syntax: https://namespace:token@api.rivet.dev", + ), + publicToken: z + .string() + .optional() + .describe( + "Token that clients should use when connecting via the public endpoint.", + ), + }) + .describe("Configuration for serverless deployment mode."); + +export const DocRunnerConfigSchema = z + .object({ + totalSlots: z + .number() + .optional() + .describe("Total number of actor slots available. Default: 100000"), + runnerName: z + .string() + .optional() + .describe("Name of this runner. Default: 'default'"), + runnerKey: z + .string() + .optional() + .describe("Authentication key for the runner."), + version: z + .number() + .optional() + .describe("Version number of this runner. Default: 1"), + }) + .describe("Configuration for runner mode."); export const DocRegistryConfigSchema = z .object({ - use: z.record(z.string(), z.unknown()).describe("Actor definitions. Keys are actor names, values are actor definitions."), - storagePath: z.string().optional().describe("Storage path for RivetKit file-system state when using the default driver. Can also be set via RIVETKIT_STORAGE_PATH."), - maxIncomingMessageSize: z.number().optional().describe("Maximum size of incoming WebSocket messages in bytes. Default: 65536"), - maxOutgoingMessageSize: z.number().optional().describe("Maximum size of outgoing WebSocket messages in bytes. Default: 1048576"), - noWelcome: z.boolean().optional().describe("Disable the welcome message on startup. Default: false"), + use: z + .record(z.string(), z.unknown()) + .describe( + "Actor definitions. Keys are actor names, values are actor definitions.", + ), + storagePath: z + .string() + .optional() + .describe( + "Storage path for RivetKit file-system state when using the default driver. Can also be set via RIVETKIT_STORAGE_PATH.", + ), + maxIncomingMessageSize: z + .number() + .optional() + .describe( + "Maximum size of incoming WebSocket messages in bytes. Default: 65536", + ), + maxOutgoingMessageSize: z + .number() + .optional() + .describe( + "Maximum size of outgoing WebSocket messages in bytes. Default: 1048576", + ), + noWelcome: z + .boolean() + .optional() + .describe("Disable the welcome message on startup. Default: false"), logging: z .object({ - level: LogLevelSchema.optional().describe("Log level for RivetKit. Default: 'warn'"), + level: LogLevelSchema.optional().describe( + "Log level for RivetKit. Default: 'warn'", + ), }) .optional() .describe("Logging configuration."), - endpoint: z.string().optional().describe("Endpoint URL to connect to Rivet Engine. Supports URL auth syntax: https://namespace:token@api.rivet.dev. Can also be set via RIVET_ENDPOINT environment variable."), - token: z.string().optional().describe("Authentication token for Rivet Engine. Can also be set via RIVET_TOKEN environment variable."), - namespace: z.string().optional().describe("Namespace to use. Default: 'default'. Can also be set via RIVET_NAMESPACE environment variable."), - headers: z.record(z.string(), z.string()).optional().describe("Additional headers to include in requests to Rivet Engine."), - serveManager: z.boolean().optional().describe("Whether to start the local manager server. Auto-determined based on endpoint and NODE_ENV if not specified."), - managerBasePath: z.string().optional().describe("Base path for the manager API. Default: '/'"), - managerPort: z.number().optional().describe("Port to run the manager on. Default: 6420"), + endpoint: z + .string() + .optional() + .describe( + "Endpoint URL to connect to Rivet Engine. Supports URL auth syntax: https://namespace:token@api.rivet.dev. Can also be set via RIVET_ENDPOINT environment variable.", + ), + token: z + .string() + .optional() + .describe( + "Authentication token for Rivet Engine. Can also be set via RIVET_TOKEN environment variable.", + ), + namespace: z + .string() + .optional() + .describe( + "Namespace to use. Default: 'default'. Can also be set via RIVET_NAMESPACE environment variable.", + ), + headers: z + .record(z.string(), z.string()) + .optional() + .describe( + "Additional headers to include in requests to Rivet Engine.", + ), + serveManager: z + .boolean() + .optional() + .describe( + "Whether to start the local manager server. Auto-determined based on endpoint and NODE_ENV if not specified.", + ), + managerBasePath: z + .string() + .optional() + .describe("Base path for the manager API. Default: '/'"), + managerPort: z + .number() + .optional() + .describe("Port to run the manager on. Default: 6420"), inspector: DocInspectorConfigSchema, serverless: DocServerlessConfigSchema.optional(), runner: DocRunnerConfigSchema.optional(), diff --git a/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts b/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts index f469e77e48..c1108ecc9f 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts @@ -42,6 +42,8 @@ describe("ActorDefinition", () => { TestVars, TestInput, TestDatabase, + Record, + Record, TestActions >; @@ -55,7 +57,9 @@ describe("ActorDefinition", () => { TestConnState, TestVars, TestInput, - TestDatabase + TestDatabase, + Record, + Record > >(); @@ -73,7 +77,9 @@ describe("ActorDefinition", () => { TestConnState, TestVars, TestInput, - TestDatabase + TestDatabase, + Record, + Record > >(); });