Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import { actor, event, queue } from "rivetkit";

interface AccessControlConnParams {
allowRequest?: boolean;
allowWebSocket?: boolean;
invalidCanInvokeReturn?: boolean;
}

export const accessControlActor = actor({
state: {
lastCanInvokeConnId: "",
},
events: {
allowedEvent: event<{ value: string }>(),
blockedEvent: event<{ value: string }>(),
},
queues: {
allowedQueue: queue<{ value: string }>(),
blockedQueue: queue<{ value: string }>(),
},
canInvoke: (c, invoke) => {
c.state.lastCanInvokeConnId = c.conn.id;
const params = c.conn.params as AccessControlConnParams | undefined;
if (params?.invalidCanInvokeReturn) {
return undefined as never;
}

if (invoke.kind === "action") {
if (invoke.name.startsWith("allowed")) {
return true;
}
return false;
}

if (invoke.kind === "queue") {
if (invoke.name === "allowedQueue") {
return true;
}
return false;
}

if (invoke.kind === "subscribe") {
if (invoke.name === "allowedEvent") {
return true;
}
return false;
}

if (invoke.kind === "request") {
if (params?.allowRequest === true) {
return true;
}
return false;
}

if (invoke.kind === "websocket") {
if (params?.allowWebSocket === true) {
return true;
}
return false;
}

return false;
},
onRequest(_c, request) {
const url = new URL(request.url);
if (url.pathname === "/status") {
return Response.json({ ok: true });
}
return new Response("Not Found", { status: 404 });
},
onWebSocket(_c, websocket) {
websocket.send(JSON.stringify({ type: "welcome" }));
},
actions: {
allowedAction: (_c, value: string) => {
return `allowed:${value}`;
},
blockedAction: () => {
return "blocked";
},
allowedGetLastCanInvokeConnId: (c) => {
return c.state.lastCanInvokeConnId;
},
allowedReceiveQueue: async (c) => {
const [message] = await c.queue.tryNext({
names: ["allowedQueue"],
});
return message?.body ?? null;
},
allowedBroadcastAllowedEvent: (c, value: string) => {
c.broadcast("allowedEvent", { value });
},
},
});
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { setup } from "rivetkit";
import { accessControlActor } from "./access-control";

import { inputActor } from "./action-inputs";
import {
Expand Down Expand Up @@ -160,5 +161,7 @@ export const registry = setup({
dbActorDrizzle,
// From stateless.ts
statelessActor,
// From access-control.ts
accessControlActor,
},
});
87 changes: 87 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type {
ActorContext,
BeforeActionResponseContext,
BeforeConnectContext,
ConnContext,
ConnectContext,
CreateConnStateContext,
CreateContext,
Expand Down Expand Up @@ -184,6 +185,7 @@ export const ActorConfigSchema = z
run: zRunHandler,
onStateChange: zFunction().optional(),
onBeforeConnect: zFunction().optional(),
canInvoke: zFunction().optional(),
onConnect: zFunction().optional(),
onDisconnect: zFunction().optional(),
onBeforeActionResponse: zFunction().optional(),
Expand Down Expand Up @@ -399,6 +401,60 @@ export interface Actions<
*/
export type AuthIntent = "get" | "create" | "connect" | "action" | "message";

type CanInvokeActionName<TActions> = keyof TActions extends never
? string
: keyof TActions & string;

type CanInvokeSubscribeName<TEvents extends EventSchemaConfig> =
keyof TEvents extends never ? string : keyof TEvents & string;

type CanInvokeQueueName<TQueues extends QueueSchemaConfig> =
keyof TQueues extends never ? string : keyof TQueues & string;

export type CanInvokeTarget<
TActions,
TEvents extends EventSchemaConfig,
TQueues extends QueueSchemaConfig,
> =
| {
kind: "action";
name: CanInvokeActionName<TActions>;
}
| {
kind: "subscribe";
name: CanInvokeSubscribeName<TEvents>;
}
| {
kind: "queue";
name: CanInvokeQueueName<TQueues>;
}
| {
kind: "request";
}
| {
kind: "websocket";
};

export type AnyCanInvokeTarget =
| {
kind: "action";
name: string;
}
| {
kind: "subscribe";
name: string;
}
| {
kind: "queue";
name: string;
}
| {
kind: "request";
}
| {
kind: "websocket";
};

interface BaseActorConfig<
TState,
TConnParams,
Expand Down Expand Up @@ -584,6 +640,29 @@ interface BaseActorConfig<
params: TConnParams,
) => void | Promise<void>;

/**
* Called before inbound invocations are processed.
*
* Return `true` to allow and `false` to deny.
* Returning any non-boolean value throws an error.
*
* This hook runs for inbound actions, queue sends, subscriptions,
* raw HTTP requests, and raw WebSocket connections.
*/
canInvoke?: (
c: ConnContext<
TState,
TConnParams,
TConnState,
TVars,
TInput,
TDatabase,
TEvents,
TQueues
>,
invoke: CanInvokeTarget<TActions, TEvents, TQueues>,
) => boolean | Promise<boolean>;

/**
* Called when a client successfully connects to the actor.
*
Expand Down Expand Up @@ -771,6 +850,7 @@ export type ActorConfig<
| "run"
| "onStateChange"
| "onBeforeConnect"
| "canInvoke"
| "onConnect"
| "onDisconnect"
| "onBeforeActionResponse"
Expand Down Expand Up @@ -877,6 +957,7 @@ export type ActorConfigInput<
| "run"
| "onStateChange"
| "onBeforeConnect"
| "canInvoke"
| "onConnect"
| "onDisconnect"
| "onBeforeActionResponse"
Expand Down Expand Up @@ -1175,6 +1256,12 @@ export const DocActorConfigSchema = z
.describe(
"Called before a client connects. Throw an error to reject the connection.",
),
canInvoke: z
.unknown()
.optional()
.describe(
"Called before inbound invocation entrypoints. Return true to allow or false to deny.",
),
onConnect: z
.unknown()
.optional()
Expand Down
11 changes: 11 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,17 @@ export class InvalidRequestHandlerResponse extends ActorError {
}
}

export class InvalidCanInvokeResponse extends ActorError {
constructor() {
super(
"handler",
"invalid_can_invoke_response",
"Actor's canInvoke hook must return a boolean value.",
);
this.statusCode = 500;
}
}

// Manager-specific errors
export class MissingActorHeader extends ActorError {
constructor() {
Expand Down
44 changes: 41 additions & 3 deletions rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ import {
CONN_VERSIONED,
} from "@/schemas/actor-persist/versioned";
import { EXTRA_ERROR_LOG } from "@/utils";
import { type ActorConfig, getRunFunction } from "../config";
import {
type AnyCanInvokeTarget,
type ActorConfig,
getRunFunction,
} from "../config";
import type { ConnDriver } from "../conn/driver";
import { createHttpDriver } from "../conn/drivers/http";
import {
Expand All @@ -32,8 +36,9 @@ import {
type PersistedConn,
} from "../conn/persisted";
import {
type ActionContext,
ActionContext,
ActorContext,
type ConnContext,
RequestContext,
WebSocketContext,
} from "../contexts";
Expand Down Expand Up @@ -609,13 +614,43 @@ export class ActorInstance<
});
}

async assertCanInvoke(
ctx: ConnContext<S, CP, CS, V, I, DB, E, Q>,
invoke: AnyCanInvokeTarget,
): Promise<void> {
const canInvoke = this.#config.canInvoke;
if (!canInvoke) {
return;
}

const result = await canInvoke(ctx, invoke);
if (typeof result !== "boolean") {
throw new errors.InvalidCanInvokeResponse();
}
if (!result) {
throw new errors.Forbidden();
}
}

async assertCanInvokeWebSocket(
conn: Conn<S, CP, CS, V, I, DB, E, Q>,
): Promise<void> {
await this.assertCanInvoke(new ActionContext(this, conn), {
kind: "websocket",
});
}

// MARK: - Action Execution
async executeAction(
ctx: ActionContext<S, CP, CS, V, I, DB, E, Q>,
actionName: string,
args: unknown[],
): Promise<unknown> {
this.assertReady();
await this.assertCanInvoke(ctx, {
kind: "action",
name: actionName,
});

const actions = this.#config.actions ?? {};
if (!(actionName in actions)) {
Expand Down Expand Up @@ -739,8 +774,11 @@ export class ActorInstance<
"rivet.conn.id": conn.id,
},
async () => {
const ctx = new RequestContext(this, conn, request);
try {
const ctx = new RequestContext(this, conn, request);
await this.assertCanInvoke(ctx, {
kind: "request",
});
const response = await onRequest(ctx, request);
if (!response) {
throw new errors.InvalidRequestHandlerResponse();
Expand Down
10 changes: 10 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,16 @@ export async function processMessage<
});

if (subscribe) {
await actor.assertCanInvoke(
new ActionContext<S, CP, CS, V, I, DB, E, Q>(
actor,
conn,
),
{
kind: "subscribe",
name: eventName,
},
);
await handler.onSubscribe(eventName, conn);
} else {
await handler.onUnsubscribe(eventName, conn);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ export async function handleQueueSend(
status: "completed",
};
try {
const ctx = new ActionContext(actor, conn);
await actor.assertCanInvoke(ctx, {
kind: "queue",
name,
});

if (request.wait) {
result = await actor.queueManager.enqueueAndWait(
name,
Expand Down
Loading
Loading