diff --git a/.github/workflows/node.js.yml b/.github/workflows/node.js.yml index 62001b2..4282327 100644 --- a/.github/workflows/node.js.yml +++ b/.github/workflows/node.js.yml @@ -15,10 +15,10 @@ jobs: cache-dependency-path: | example/package.json package.json - node-version: "19.x" + node-version: "20.x" cache: "npm" - run: node setup.cjs - run: npm test - run: npm run typecheck - run: npm run lint - - run: npx pkg-pr-new publish + - run: npx pkg-pr-new publish ./ ./playground diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ceab35..87d7ecc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## 0.1.15 alpha + +- You can request that `syncStreams` return aborted streamed messages, + if you want to show those in your UI. +- They will have `msg.streaming === false` if they were aborted. +- Fix: stream deletion is idempotent and cleanup is canceled if it's already deleted. + ## 0.1.14 - Expose delete functions for messages & threads on the Agent class @@ -39,7 +46,7 @@ ## 0.1.9 -- You can finish a stream asynchronously and have it abort the streaming. +- You can abort a stream asynchronously and have it stop writing deltas smoothly. - The timeout for streaming deltas with no sign of life has been increased to 10 minutes. - Delete stream deltas automatically 5 min after the stream finishes. diff --git a/example/convex/_generated/api.d.ts b/example/convex/_generated/api.d.ts index e143b2a..8cd71e1 100644 --- a/example/convex/_generated/api.d.ts +++ b/example/convex/_generated/api.d.ts @@ -1615,6 +1615,12 @@ export declare const components: { >; }; streams: { + abort: FunctionReference< + "mutation", + "internal", + { reason: string; streamId: string }, + null + >; addDelta: FunctionReference< "mutation", "internal", @@ -1770,13 +1776,18 @@ export declare const components: { list: FunctionReference< "query", "internal", - { threadId: string }, + { + startOrder?: number; + statuses?: Array<"streaming" | "finished" | "aborted">; + threadId: string; + }, Array<{ agentName?: string; model?: string; order: number; provider?: string; providerOptions?: Record>; + status: "streaming" | "finished" | "aborted"; stepOrder: number; streamId: string; userId?: string; diff --git a/examples/chat-basic/convex/_generated/api.d.ts b/examples/chat-basic/convex/_generated/api.d.ts index 2b16d6d..c27828c 100644 --- a/examples/chat-basic/convex/_generated/api.d.ts +++ b/examples/chat-basic/convex/_generated/api.d.ts @@ -1589,6 +1589,12 @@ export declare const components: { >; }; streams: { + abort: FunctionReference< + "mutation", + "internal", + { reason: string; streamId: string }, + null + >; addDelta: FunctionReference< "mutation", "internal", @@ -1744,13 +1750,18 @@ export declare const components: { list: FunctionReference< "query", "internal", - { threadId: string }, + { + startOrder?: number; + statuses?: Array<"streaming" | "finished" | "aborted">; + threadId: string; + }, Array<{ agentName?: string; model?: string; order: number; provider?: string; providerOptions?: Record>; + status: "streaming" | "finished" | "aborted"; stepOrder: number; streamId: string; userId?: string; diff --git a/examples/chat-basic/package-lock.json b/examples/chat-basic/package-lock.json index 7f0e951..cf14e0f 100644 --- a/examples/chat-basic/package-lock.json +++ b/examples/chat-basic/package-lock.json @@ -3220,26 +3220,35 @@ "link": true }, "node_modules/convex-helpers": { - "version": "0.1.78", - "resolved": "https://registry.npmjs.org/convex-helpers/-/convex-helpers-0.1.78.tgz", - "integrity": "sha512-furrk3yEtmpuAenMUsLOa+IPTncYhUolziDsfsNE2Vo8QYqJlumQlj7X739w/0LyuRksA5zuNnnpJC2Ss6vkVw==", + "version": "0.1.94", + "resolved": "https://registry.npmjs.org/convex-helpers/-/convex-helpers-0.1.94.tgz", + "integrity": "sha512-35o9TzEUdze3wGHxksk9Ynutw8ekxp/kbjs1dlfK6iZwPJeKbv9sqRXIyvPytTGvMmoi40fGm09f/z22gg6L+A==", + "license": "Apache-2.0", "peer": true, "bin": { "convex-helpers": "bin.cjs" }, "peerDependencies": { + "@standard-schema/spec": "^1.0.0", "convex": "^1.13.0", "hono": "^4.0.5", "react": "^17.0.2 || ^18.0.0 || ^19.0.0", + "typescript": "^5.5", "zod": "^3.22.4" }, "peerDependenciesMeta": { + "@standard-schema/spec": { + "optional": true + }, "hono": { "optional": true }, "react": { "optional": true }, + "typescript": { + "optional": true + }, "zod": { "optional": true } @@ -7627,7 +7636,7 @@ "version": "5.7.3", "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.3.tgz", "integrity": "sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==", - "dev": true, + "devOptional": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" diff --git a/examples/chat-streaming/convex/_generated/api.d.ts b/examples/chat-streaming/convex/_generated/api.d.ts index 519c4ac..00df529 100644 --- a/examples/chat-streaming/convex/_generated/api.d.ts +++ b/examples/chat-streaming/convex/_generated/api.d.ts @@ -1589,6 +1589,12 @@ export declare const components: { >; }; streams: { + abort: FunctionReference< + "mutation", + "internal", + { reason: string; streamId: string }, + null + >; addDelta: FunctionReference< "mutation", "internal", @@ -1744,13 +1750,18 @@ export declare const components: { list: FunctionReference< "query", "internal", - { threadId: string }, + { + startOrder?: number; + statuses?: Array<"streaming" | "finished" | "aborted">; + threadId: string; + }, Array<{ agentName?: string; model?: string; order: number; provider?: string; providerOptions?: Record>; + status: "streaming" | "finished" | "aborted"; stepOrder: number; streamId: string; userId?: string; diff --git a/examples/chat-streaming/package-lock.json b/examples/chat-streaming/package-lock.json index 7f0e951..cf14e0f 100644 --- a/examples/chat-streaming/package-lock.json +++ b/examples/chat-streaming/package-lock.json @@ -3220,26 +3220,35 @@ "link": true }, "node_modules/convex-helpers": { - "version": "0.1.78", - "resolved": "https://registry.npmjs.org/convex-helpers/-/convex-helpers-0.1.78.tgz", - "integrity": "sha512-furrk3yEtmpuAenMUsLOa+IPTncYhUolziDsfsNE2Vo8QYqJlumQlj7X739w/0LyuRksA5zuNnnpJC2Ss6vkVw==", + "version": "0.1.94", + "resolved": "https://registry.npmjs.org/convex-helpers/-/convex-helpers-0.1.94.tgz", + "integrity": "sha512-35o9TzEUdze3wGHxksk9Ynutw8ekxp/kbjs1dlfK6iZwPJeKbv9sqRXIyvPytTGvMmoi40fGm09f/z22gg6L+A==", + "license": "Apache-2.0", "peer": true, "bin": { "convex-helpers": "bin.cjs" }, "peerDependencies": { + "@standard-schema/spec": "^1.0.0", "convex": "^1.13.0", "hono": "^4.0.5", "react": "^17.0.2 || ^18.0.0 || ^19.0.0", + "typescript": "^5.5", "zod": "^3.22.4" }, "peerDependenciesMeta": { + "@standard-schema/spec": { + "optional": true + }, "hono": { "optional": true }, "react": { "optional": true }, + "typescript": { + "optional": true + }, "zod": { "optional": true } @@ -7627,7 +7636,7 @@ "version": "5.7.3", "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.3.tgz", "integrity": "sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==", - "dev": true, + "devOptional": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" diff --git a/examples/files-images/convex/_generated/api.d.ts b/examples/files-images/convex/_generated/api.d.ts index 7bcc51a..534b580 100644 --- a/examples/files-images/convex/_generated/api.d.ts +++ b/examples/files-images/convex/_generated/api.d.ts @@ -1591,6 +1591,12 @@ export declare const components: { >; }; streams: { + abort: FunctionReference< + "mutation", + "internal", + { reason: string; streamId: string }, + null + >; addDelta: FunctionReference< "mutation", "internal", @@ -1746,13 +1752,18 @@ export declare const components: { list: FunctionReference< "query", "internal", - { threadId: string }, + { + startOrder?: number; + statuses?: Array<"streaming" | "finished" | "aborted">; + threadId: string; + }, Array<{ agentName?: string; model?: string; order: number; provider?: string; providerOptions?: Record>; + status: "streaming" | "finished" | "aborted"; stepOrder: number; streamId: string; userId?: string; diff --git a/examples/files-images/package-lock.json b/examples/files-images/package-lock.json index f982362..f270fd1 100644 --- a/examples/files-images/package-lock.json +++ b/examples/files-images/package-lock.json @@ -3227,26 +3227,35 @@ "link": true }, "node_modules/convex-helpers": { - "version": "0.1.78", - "resolved": "https://registry.npmjs.org/convex-helpers/-/convex-helpers-0.1.78.tgz", - "integrity": "sha512-furrk3yEtmpuAenMUsLOa+IPTncYhUolziDsfsNE2Vo8QYqJlumQlj7X739w/0LyuRksA5zuNnnpJC2Ss6vkVw==", + "version": "0.1.94", + "resolved": "https://registry.npmjs.org/convex-helpers/-/convex-helpers-0.1.94.tgz", + "integrity": "sha512-35o9TzEUdze3wGHxksk9Ynutw8ekxp/kbjs1dlfK6iZwPJeKbv9sqRXIyvPytTGvMmoi40fGm09f/z22gg6L+A==", + "license": "Apache-2.0", "peer": true, "bin": { "convex-helpers": "bin.cjs" }, "peerDependencies": { + "@standard-schema/spec": "^1.0.0", "convex": "^1.13.0", "hono": "^4.0.5", "react": "^17.0.2 || ^18.0.0 || ^19.0.0", + "typescript": "^5.5", "zod": "^3.22.4" }, "peerDependenciesMeta": { + "@standard-schema/spec": { + "optional": true + }, "hono": { "optional": true }, "react": { "optional": true }, + "typescript": { + "optional": true + }, "zod": { "optional": true } @@ -7659,7 +7668,7 @@ "version": "5.7.3", "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.3.tgz", "integrity": "sha512-84MVSjMEHP+FQRPy3pX9sTVV/INIex71s9TL2Gm5FG/WG1SqXeKyZ0k7/blY/4FdOzI12CBy1vGc4og/eus0fw==", - "dev": true, + "devOptional": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" diff --git a/examples/rate-limiting/convex/_generated/api.d.ts b/examples/rate-limiting/convex/_generated/api.d.ts index f268979..57eaff3 100644 --- a/examples/rate-limiting/convex/_generated/api.d.ts +++ b/examples/rate-limiting/convex/_generated/api.d.ts @@ -1589,6 +1589,12 @@ export declare const components: { >; }; streams: { + abort: FunctionReference< + "mutation", + "internal", + { reason: string; streamId: string }, + null + >; addDelta: FunctionReference< "mutation", "internal", @@ -1744,13 +1750,18 @@ export declare const components: { list: FunctionReference< "query", "internal", - { threadId: string }, + { + startOrder?: number; + statuses?: Array<"streaming" | "finished" | "aborted">; + threadId: string; + }, Array<{ agentName?: string; model?: string; order: number; provider?: string; providerOptions?: Record>; + status: "streaming" | "finished" | "aborted"; stepOrder: number; streamId: string; userId?: string; diff --git a/playground/src/definePlaygroundAPI.ts b/playground/src/definePlaygroundAPI.ts index 8b3b1ac..8139d7d 100644 --- a/playground/src/definePlaygroundAPI.ts +++ b/playground/src/definePlaygroundAPI.ts @@ -6,6 +6,7 @@ import { GenericDataModel, GenericQueryCtx, ApiFromModules, + GenericActionCtx, } from "convex/server"; import { vMessageDoc, @@ -25,33 +26,32 @@ export type PlaygroundAPI = ApiFromModules<{ playground: ReturnType; }>["playground"]; +export type AgentsFn = ( + ctx: GenericActionCtx | GenericQueryCtx, + args: { userId: string | undefined; threadId: string | undefined } +) => Promise[]>; + // Playground API definition -export function definePlaygroundAPI( +export function definePlaygroundAPI( component: AgentComponent, { - agents, + agents: agentsOrFn, userNameLookup, }: { - agents: Agent[]; - userNameLookup?: ( + agents: Agent[] | AgentsFn; + userNameLookup?: ( ctx: GenericQueryCtx, userId: string ) => string | Promise; } ) { - // Map agent name to instance - const agentMap: Record> = Object.fromEntries( - agents.map((agent, i) => [ - agent.options.name ?? `Agent ${i} (missing 'name')`, - agent, - ]) - ); - - for (const agent of agents) { - if (!agent.options.name) { - throw new Error( - `Agent has no name (instructions: ${agent.options.instructions})` - ); + function validateAgents(agents: Agent[]) { + for (const agent of agents) { + if (!agent.options.name) { + console.warn( + `Agent has no name (instructions: ${agent.options.instructions})` + ); + } } } @@ -74,14 +74,34 @@ export function definePlaygroundAPI( returns: v.boolean(), }); + async function getAgents( + ctx: GenericActionCtx | GenericQueryCtx, + args: { userId: string | undefined; threadId: string | undefined } + ) { + const agents = Array.isArray(agentsOrFn) + ? agentsOrFn + : await agentsOrFn(ctx, args); + validateAgents(agents); + return agents.map((agent, i) => ({ + name: agent.options.name ?? `Agent ${i} (missing 'name')`, + agent, + })); + } + // List all agents const listAgents = queryGeneric({ args: { apiKey: v.string(), + userId: v.optional(v.string()), + threadId: v.optional(v.string()), }, handler: async (ctx, args) => { + const agents = await getAgents(ctx, { + userId: args.userId, + threadId: args.threadId, + }); await validateApiKey(ctx, args.apiKey); - const agents = Object.entries(agentMap).map(([name, agent]) => ({ + return agents.map(({ name, agent }) => ({ name, instructions: agent.options.instructions, contextOptions: agent.options.contextOptions, @@ -90,7 +110,6 @@ export function definePlaygroundAPI( maxRetries: agent.options.maxRetries, tools: agent.options.tools ? Object.keys(agent.options.tools) : [], })); - return agents; }, }); @@ -194,21 +213,25 @@ export function definePlaygroundAPI( const createThread = mutationGeneric({ args: { apiKey: v.string(), - agentName: v.string(), userId: v.string(), title: v.optional(v.string()), summary: v.optional(v.string()), + /** @deprecated Unused. */ + agentName: v.optional(v.string()), }, handler: async (ctx, args) => { + // if (args.agentName) { + // console.warn( + // "Upgrade to the latest version of @convex-dev/agent-playground" + // ); + // } await validateApiKey(ctx, args.apiKey); - const agent = agentMap[args.agentName]; - if (!agent) throw new Error(`Unknown agent: ${args.agentName}`); - const { threadId } = await agent.createThread(ctx, { + const { _id } = await ctx.runMutation(component.threads.createThread, { userId: args.userId, title: args.title, summary: args.summary, }); - return { threadId }; + return { threadId: _id }; }, returns: v.object({ threadId: v.string() }), }); @@ -228,7 +251,7 @@ export function definePlaygroundAPI( messages: v.optional(v.array(vMessage)), system: v.optional(v.string()), }, - handler: async (ctx, args) => { + handler: async (ctx: GenericActionCtx, args) => { const { apiKey, agentName, @@ -240,8 +263,13 @@ export function definePlaygroundAPI( ...rest } = args; await validateApiKey(ctx, apiKey); - const agent = agentMap[agentName]; - if (!agent) throw new Error(`Unknown agent: ${agentName}`); + const agents = await getAgents(ctx, { + userId: args.userId, + threadId: args.threadId, + }); + const namedAgent = agents.find(({ name }) => name === args.agentName); + if (!namedAgent) throw new Error(`Unknown agent: ${args.agentName}`); + const { agent } = namedAgent; const { thread } = await agent.continueThread(ctx, { threadId, userId }); const { messageId, text } = await thread.generateText( { ...rest, ...(system ? { system } : {}) }, @@ -267,8 +295,13 @@ export function definePlaygroundAPI( }, handler: async (ctx, args) => { await validateApiKey(ctx, args.apiKey); - const agent = agentMap[args.agentName]; - if (!agent) throw new Error(`Unknown agent: ${args.agentName}`); + const agents = await getAgents(ctx, { + userId: args.userId, + threadId: args.threadId, + }); + const namedAgent = agents.find(({ name }) => name === args.agentName); + if (!namedAgent) throw new Error(`Unknown agent: ${args.agentName}`); + const { agent } = namedAgent; const contextOptions = args.contextOptions ?? agent.options.contextOptions ?? diff --git a/playground/src/pages/Play.tsx b/playground/src/pages/Play.tsx index c49a497..1456da9 100644 --- a/playground/src/pages/Play.tsx +++ b/playground/src/pages/Play.tsx @@ -66,7 +66,11 @@ const Play = ({ apiKey, api }: PlayProps) => { } }, [messages.results, selectedMessageId]); - const agents = useQuery(api.listAgents, { apiKey }); + const agents = useQuery(api.listAgents, { + apiKey, + threadId: selectedThreadId, + userId: selectedUserId, + }); useEffect(() => { if (agents && agents.length > 0 && !selectedAgent) { setSelectedAgent(agents[0]); @@ -76,6 +80,18 @@ const Play = ({ apiKey, api }: PlayProps) => { if (agents[0].storageOptions) { setStorageOptions(agents[0].storageOptions); } + } else if (agents && selectedAgent) { + const newAgent = agents.find( + (agent) => agent.name === selectedAgent.name + ); + if (newAgent) { + if (JSON.stringify(selectedAgent) !== JSON.stringify(newAgent)) { + setSelectedAgent(newAgent); + } + } else { + // The selected agent is no longer in the list of agents, so clear it + setSelectedAgent(undefined); + } } }, [agents, selectedAgent]); @@ -106,7 +122,11 @@ const Play = ({ apiKey, api }: PlayProps) => { const handleSelectMessage = (messageId: string) => { setSelectedMessageId(messageId); const message = messages.results.find((m) => m._id === messageId); - if (message && message.agentName && selectedAgent !== message.agentName) { + if ( + message && + message.agentName && + selectedAgent?.name !== message.agentName + ) { const agent = agents?.find((a) => a.name === message.agentName); if (agent) { setSelectedAgent(agent); @@ -127,7 +147,7 @@ const Play = ({ apiKey, api }: PlayProps) => { agentName: selectedAgent.name, threadId: selectedThreadId, userId: selectedUserId, - messages: [selectedMessage.message], + messages: selectedMessage.message ? [selectedMessage.message] : [], contextOptions, beforeMessageId: selectedMessage._id, }); @@ -214,7 +234,7 @@ const Play = ({ apiKey, api }: PlayProps) => {
{ +export class Agent { constructor( public component: AgentComponent, public options: { @@ -598,6 +598,7 @@ export class Agent { error: (error.error as Error).message, }); } + // TODO: update the streamer to error state return args.onError?.(error); }, onStepFinish: async (step) => { @@ -611,7 +612,6 @@ export class Agent { promptMessageId: messageId, step, }); - // TODO: figure out pending/not await streamer?.finish(saved.messages); } if (this.options.rawRequestResponseHandler) { @@ -1027,6 +1027,8 @@ export class Agent { args: { threadId: string; streamArgs: StreamArgs | undefined; + // By default, only streaming messages are included. + includeStatuses?: ("streaming" | "finished" | "aborted")[]; } ): Promise { if (!args.streamArgs) return undefined; @@ -1035,6 +1037,8 @@ export class Agent { kind: "list", messages: await ctx.runQuery(this.component.streams.list, { threadId: args.threadId, + startOrder: args.streamArgs.startOrder, + statuses: args.includeStatuses, }), }; } else { diff --git a/src/client/streaming.ts b/src/client/streaming.ts index b72bb53..b464f41 100644 --- a/src/client/streaming.ts +++ b/src/client/streaming.ts @@ -96,7 +96,13 @@ export class DeltaStreamer { this.#nextStepOrder = (metadata.stepOrder ?? 0) + 1; this.abortController = new AbortController(); if (metadata.abortSignal) { - metadata.abortSignal.addEventListener("abort", () => { + metadata.abortSignal.addEventListener("abort", async () => { + if (this.streamId) { + await this.ctx.runMutation(this.component.streams.abort, { + streamId: this.streamId, + reason: "abortSignal", + }); + } this.abortController.abort(); }); } diff --git a/src/component/_generated/api.d.ts b/src/component/_generated/api.d.ts index 88fd8ea..6f68b01 100644 --- a/src/component/_generated/api.d.ts +++ b/src/component/_generated/api.d.ts @@ -1433,6 +1433,12 @@ export type Mounts = { >; }; streams: { + abort: FunctionReference< + "mutation", + "public", + { reason: string; streamId: string }, + null + >; addDelta: FunctionReference< "mutation", "public", @@ -1588,13 +1594,18 @@ export type Mounts = { list: FunctionReference< "query", "public", - { threadId: string }, + { + startOrder?: number; + statuses?: Array<"streaming" | "finished" | "aborted">; + threadId: string; + }, Array<{ agentName?: string; model?: string; order: number; provider?: string; providerOptions?: Record>; + status: "streaming" | "finished" | "aborted"; stepOrder: number; streamId: string; userId?: string; diff --git a/src/component/schema.ts b/src/component/schema.ts index ecf077f..a03970a 100644 --- a/src/component/schema.ts +++ b/src/component/schema.ts @@ -113,10 +113,11 @@ export const schema = defineSchema({ v.object({ kind: v.literal("finished"), endedAt: v.number(), + cleanupFnId: v.optional(v.id("_scheduled_functions")), }), v.object({ - kind: v.literal("error"), - error: v.string(), + kind: v.literal("aborted"), + reason: v.string(), }) ), }) diff --git a/src/component/streams.ts b/src/component/streams.ts index 096afb1..1fe10a6 100644 --- a/src/component/streams.ts +++ b/src/component/streams.ts @@ -6,7 +6,7 @@ import { vStreamMessage, } from "../validators.js"; import { api, internal } from "./_generated/api.js"; -import type { Id } from "./_generated/dataModel.js"; +import type { Doc, Id } from "./_generated/dataModel.js"; import { internalMutation, mutation, @@ -106,33 +106,87 @@ export const create = mutation({ export const list = query({ args: { threadId: v.id("threads"), + startOrder: v.optional(v.number()), + statuses: v.optional( + v.array( + v.union( + v.literal("streaming"), + v.literal("finished"), + v.literal("aborted") + ) + ) + ), }, returns: v.array(vStreamMessage), handler: async (ctx, args) => { - return ctx.db - .query("streamingMessages") - .withIndex("threadId_state_order_stepOrder", (q) => - q.eq("threadId", args.threadId).eq("state.kind", "streaming") - ) - .order("desc") - .take(100) - .then((msgs) => - msgs.map((m) => ({ - streamId: m._id, - ...pick(m, [ - "order", - "stepOrder", - "userId", - "agentName", - "model", - "provider", - "providerOptions", - ]), - })) + const statuses = args.statuses ?? ["streaming"]; + const messages = await mergedStream( + statuses.map((status) => + stream(ctx.db, schema) + .query("streamingMessages") + .withIndex("threadId_state_order_stepOrder", (q) => + q + .eq("threadId", args.threadId) + .eq("state.kind", status) + .gte("order", args.startOrder ?? 0) + ) + .order("desc") + ), + ["order", "stepOrder"] + ).take(100); + + return messages.map((m) => ({ + streamId: m._id, + status: m.state.kind, + ...pick(m, [ + "order", + "stepOrder", + "userId", + "agentName", + "model", + "provider", + "providerOptions", + ]), + })); + }, +}); + +export const abort = mutation({ + args: { + streamId: v.id("streamingMessages"), + reason: v.string(), + }, + returns: v.null(), + handler: async (ctx, args) => { + const stream = await ctx.db.get(args.streamId); + if (!stream) { + throw new Error(`Stream not found: ${args.streamId}`); + } + if (stream.state.kind !== "streaming") { + console.warn( + `Stream trying to abort but not currently streaming (${stream.state.kind}): ${args.streamId}` ); + return; + } + await cleanupTimeoutFn(ctx, stream); + await ctx.db.patch(args.streamId, { + state: { kind: "aborted", reason: args.reason }, + }); }, }); +async function cleanupTimeoutFn( + ctx: MutationCtx, + stream: Doc<"streamingMessages"> +) { + if (stream.state.kind === "streaming" && stream.state.timeoutFnId) { + const timeoutFn = await ctx.db.system.get(stream.state.timeoutFnId); + if (timeoutFn?.state.kind === "pending") { + await ctx.scheduler.cancel(stream.state.timeoutFnId); + } + } +} + export const finish = mutation({ args: { streamId: v.id("streamingMessages"), @@ -153,20 +207,15 @@ export const finish = mutation({ ); return; } - if (stream.state.timeoutFnId) { - const timeoutFn = await ctx.db.system.get(stream.state.timeoutFnId); - if (timeoutFn?.state.kind === "pending") { - await ctx.scheduler.cancel(stream.state.timeoutFnId); - } - } - await ctx.db.patch(args.streamId, { - state: { kind: "finished", endedAt: Date.now() }, - }); - await ctx.scheduler.runAfter( + await cleanupTimeoutFn(ctx, stream); + const cleanupFnId = await ctx.scheduler.runAfter( DELETE_STREAM_DELAY, api.streams.deleteStreamAsync, { streamId: args.streamId } ); + await ctx.db.patch(args.streamId, { + state: { kind: "finished", endedAt: Date.now(), cleanupFnId }, + }); }, }); @@ -223,8 +272,8 @@ export const timeoutStream = internalMutation({ } await ctx.db.patch(args.streamId, { state: { - kind: "finished", - endedAt: Date.now(), + kind: "aborted", + reason: "timeout", }, }); }, @@ -245,12 +294,9 @@ async function deletePageForStreamId( if (deltas.isDone) { const stream = await ctx.db.get(args.streamId); if (stream) { - const state = stream.state; - if (state.kind === "streaming" && state.timeoutFnId) { - const timeoutFn = await ctx.db.system.get(state.timeoutFnId); - if (timeoutFn?.state.kind === "pending") { - await ctx.scheduler.cancel(state.timeoutFnId); - } + await cleanupTimeoutFn(ctx, stream); + if (stream.state.kind === "finished" && stream.state.cleanupFnId) { + await ctx.scheduler.cancel(stream.state.cleanupFnId); } await ctx.db.delete(args.streamId); } diff --git a/src/react/deltas.ts b/src/react/deltas.ts index 22817b2..1b453ea 100644 --- a/src/react/deltas.ts +++ b/src/react/deltas.ts @@ -256,12 +256,14 @@ export function createStreamingMessage( ): MessageDoc { const { streamId, ...rest } = message; const metadata: MessageDoc = { + ...rest, _id: `${streamId}-${index}`, _creationTime: Date.now(), - status: "pending", + status: ( + { streaming: "pending", finished: "success", aborted: "failed" } as const + )[message.status], threadId, tool: false, - ...rest, }; switch (part.type) { case "text-delta": diff --git a/src/react/index.ts b/src/react/index.ts index f432856..f4b5239 100644 --- a/src/react/index.ts +++ b/src/react/index.ts @@ -1,5 +1,5 @@ "use client"; -import type { ErrorMessage } from "convex-helpers"; +import { omit, type ErrorMessage } from "convex-helpers"; import { type PaginatedQueryArgs, type UsePaginatedQueryResult, @@ -7,7 +7,7 @@ import { } from "convex/react"; import { usePaginatedQuery } from "convex-helpers/react"; import type { FunctionArgs } from "convex/server"; -import { useMemo, useState } from "react"; +import { useMemo, useRef, useState } from "react"; import type { MessageDoc } from "../client/index.js"; import type { SyncStreamsReturnValue } from "../client/types.js"; import type { StreamArgs } from "../validators.js"; @@ -103,14 +103,18 @@ export function useThreadMessages< ThreadMessagesArgs, ThreadMessagesResult >, - !options.stream ? "skip" : args + !options.stream || + args === "skip" || + paginated.status === "LoadingFirstPage" + ? "skip" + : { ...args, startOrder: paginated.results.at(-1)?.order } ); const merged = useMemo(() => { const streamListMessages = streamMessages?.map((m) => ({ ...m, - streaming: true, + streaming: !m.status || m.status === "streaming", })) ?? []; return { ...paginated, @@ -150,7 +154,7 @@ export function useStreamingThreadMessages< Query extends ThreadStreamQuery, >( query: Query, - args: ThreadMessagesArgs | "skip" + args: (ThreadMessagesArgs & { startOrder?: number }) | "skip" ): Array> | undefined { // Invariant: streamMessages[streamId] is comprised of all deltas up to the // cursor. There can be multiple messages in the same stream, e.g. for tool @@ -158,15 +162,23 @@ export function useStreamingThreadMessages< const [streams, setStreams] = useState< Array<{ streamId: string; cursor: number; messages: MessageDoc[] }> >([]); + const startOrderRef = useRef(0); + const queryArgs = args === "skip" ? args : omit(args, ["startOrder"]); + if (args !== "skip" && !startOrderRef.current && args.startOrder) { + startOrderRef.current = args.startOrder; + } // Get all the active streams const streamList = useQuery( query, - args === "skip" - ? args + queryArgs === "skip" + ? queryArgs : ({ - ...args, + ...queryArgs, paginationOpts: { cursor: null, numItems: 0 }, - streamArgs: { kind: "list" } as StreamArgs, + streamArgs: { + kind: "list", + startOrder: startOrderRef.current, + } as StreamArgs, } as FunctionArgs) ) as | { streams: Extract } @@ -186,10 +198,10 @@ export function useStreamingThreadMessages< // Get the deltas for all the active streams, if any. const cursorQuery = useQuery( query, - args === "skip" || !streamList + queryArgs === "skip" || !streamList ? ("skip" as const) : ({ - ...args, + ...queryArgs, paginationOpts: { cursor: null, numItems: 0 }, streamArgs: { kind: "deltas", cursors } as StreamArgs, } as FunctionArgs) diff --git a/src/react/optimisticallySendMessage.ts b/src/react/optimisticallySendMessage.ts index ea64fe2..de76b32 100644 --- a/src/react/optimisticallySendMessage.ts +++ b/src/react/optimisticallySendMessage.ts @@ -11,7 +11,7 @@ export function optimisticallySendMessage( ) => void { return (store, args) => { const queries = store.getAllQueries(query); - let maxOrder = 0; + let maxOrder = -1; let maxStepOrder = 0; for (const q of queries) { if (q.args?.threadId !== args.threadId) continue; diff --git a/src/validators.ts b/src/validators.ts index 7e95067..8df5d2a 100644 --- a/src/validators.ts +++ b/src/validators.ts @@ -460,6 +460,7 @@ export const vStreamArgs = v.optional( v.union( v.object({ kind: v.literal("list"), + startOrder: v.optional(v.number()), }), v.object({ kind: v.literal("deltas"), @@ -471,6 +472,11 @@ export type StreamArgs = Infer; export const vStreamMessage = v.object({ streamId: v.string(), + status: v.union( + v.literal("streaming"), + v.literal("finished"), + v.literal("aborted") + ), order: v.number(), stepOrder: v.number(), // metadata