diff --git a/bun.lock b/bun.lock index 9745ed2..d90443f 100644 --- a/bun.lock +++ b/bun.lock @@ -7,11 +7,7 @@ }, "packages/fluent-ai": { "name": "fluent-ai", - "version": "0.4.4", - "dependencies": { - "eventsource-parser": "^3.0.6", - "partial-json": "^0.1.7", - }, + "version": "0.4.7", "devDependencies": { "@types/bun": "1.3.0", "bun-plugin-dts": "^0.3.0", @@ -504,8 +500,6 @@ "etag": ["etag@1.8.1", "", {}, "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg=="], - "eventsource-parser": ["eventsource-parser@3.0.6", "", {}, "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg=="], - "exit-hook": ["exit-hook@2.2.1", "", {}, "sha512-eNTPlAD67BmP31LDINZ3U7HSF8l57TxOY2PmBJ1shpCvpnxBF93mWCE8YHBnXs8qiUZJc9WDcWIeC3a2HIAMfw=="], "express": ["express@4.21.2", "", { "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", "body-parser": "1.20.3", "content-disposition": "0.5.4", "content-type": "~1.0.4", "cookie": "0.7.1", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", "encodeurl": "~2.0.0", "escape-html": "~1.0.3", "etag": "~1.8.1", "finalhandler": "1.3.1", "fresh": "0.5.2", "http-errors": "2.0.0", "merge-descriptors": "1.0.3", "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", "safe-buffer": "5.2.1", "send": "0.19.0", "serve-static": "1.16.2", "setprototypeof": "1.2.0", "statuses": "2.0.1", "type-is": "~1.6.18", "utils-merge": "1.0.1", "vary": "~1.1.2" } }, "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA=="], @@ -664,8 +658,6 @@ "parseurl": ["parseurl@1.3.3", "", {}, "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ=="], - "partial-json": ["partial-json@0.1.7", "", {}, "sha512-Njv/59hHaokb/hRUjce3Hdv12wd60MtM9Z5Olmn+nehe0QDAsRtRbJPvJ0Z91TusF0SuZRIvnM+S4l6EIP8leA=="], - "path-key": ["path-key@3.1.1", "", {}, "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q=="], "path-scurry": ["path-scurry@1.11.1", "", { "dependencies": { "lru-cache": "^10.2.0", "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" } }, "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA=="], diff --git a/packages/fluent-ai/.gitignore b/packages/fluent-ai/.gitignore new file mode 100644 index 0000000..9b1c8b1 --- /dev/null +++ b/packages/fluent-ai/.gitignore @@ -0,0 +1 @@ +/dist diff --git a/packages/fluent-ai/examples/ollama-chat.ts b/packages/fluent-ai/examples/ollama-chat.ts index d8d945e..130bc83 100644 --- a/packages/fluent-ai/examples/ollama-chat.ts +++ b/packages/fluent-ai/examples/ollama-chat.ts @@ -5,13 +5,30 @@ const models = await ollama().models().run(); console.log(models); const response = await ollama() - .chat(models[0].name) + .chat(models[0].id) .messages([ { role: "user", - content: "What is the capital of France?", + text: "What is the capital of France?", }, ]) .run(); console.log(response); + +const streamResponse = await ollama() + .chat(models[0].id) + .messages([ + { + role: "user", + text: "What is the capital of Spain?", + }, + ]) + .stream() + .run(); + +for await (const chunk of streamResponse) { + if (chunk.message?.text) { + process.stdout.write(chunk.message.text); + } +} diff --git a/packages/fluent-ai/examples/openrouter-chat.ts b/packages/fluent-ai/examples/openrouter-chat.ts index ebdb4bc..7941652 100644 --- a/packages/fluent-ai/examples/openrouter-chat.ts +++ b/packages/fluent-ai/examples/openrouter-chat.ts @@ -7,8 +7,8 @@ const job: Job = { input: { model: "google/gemini-2.5-flash", messages: [ - { role: "system", content: "You are a helpful assistant." }, - { role: "user", content: "Hi" }, + { role: "system", text: "You are a helpful assistant." }, + { role: "user", text: "Hi" }, ], }, }; diff --git a/packages/fluent-ai/package.json b/packages/fluent-ai/package.json index 7816f07..59fafaa 100644 --- a/packages/fluent-ai/package.json +++ b/packages/fluent-ai/package.json @@ -19,10 +19,6 @@ "build": "bun run build.ts", "prepublishOnly": "rm -rf dist && bun run build" }, - "dependencies": { - "eventsource-parser": "^3.0.6", - "partial-json": "^0.1.7" - }, "keywords": [ "ai", "openai", diff --git a/packages/fluent-ai/src/agent/agent.ts b/packages/fluent-ai/src/agent/agent.ts index 47ec83c..0036254 100644 --- a/packages/fluent-ai/src/agent/agent.ts +++ b/packages/fluent-ai/src/agent/agent.ts @@ -1,11 +1,15 @@ import { z } from "zod"; -import { convertMessagesForChatCompletion } from "~/src/agent/message"; import { agentToolSchema, type AgentToolBuilder, type AgentTool, } from "~/src/agent/tool"; -import type { Message } from "~/src/job/schema"; +import type { + Message, + ToolMessage, + MessageChunk, + AssistantMessage, +} from "~/src/job/schema"; import type { ChatBuilder } from "~/src/builder/chat"; export const agentSchema = z.object({ @@ -14,7 +18,30 @@ export const agentSchema = z.object({ tools: z.array(agentToolSchema), }); -interface GenerateOptions { +interface ChunkEvent { + type: "chunk"; + chunk: { + text?: string; + reasoning?: string; + }; +} + +interface ToolEvent { + type: "tool"; + tool: { + name: string; + args: any; + result?: any; + error?: any; + }; +} + +interface MessageEvent { + type: "message"; + message: Message; +} + +export interface AgentGenerateOptions { maxSteps: number; } @@ -46,7 +73,7 @@ export class Agent { generate = async function* ( this: Agent, initialMessages: Message[], - options: GenerateOptions, + options: AgentGenerateOptions, context?: TContext, ) { const body = agentSchema.parse(this.body); @@ -62,12 +89,11 @@ export class Agent { typeof body.instructions === "function" ? body.instructions() : body.instructions; - const allMessages = initialMessages.concat(newMessages); - const convertedMessages = convertMessagesForChatCompletion(allMessages); - const messages = [{ role: "system", content: instructions }].concat( - convertedMessages as any, + const systemMessage = { role: "system", text: instructions }; + const messages = ([systemMessage] as Message[]).concat( + initialMessages, + newMessages, ); - // TODO: agent tool vs chat tool const tools = body.tools.map((tool) => ({ name: tool.name, description: tool.description, @@ -79,79 +105,67 @@ export class Agent { .run(); let totalText = ""; - for await (const chunk of result) { - const delta = chunk.raw.choices[0].delta; - - // TODO: tool calls with content?? - if (delta.tool_calls) { - // TODO: tool call with content - // TODO: tool call with input streaming - // TODO: support multiple tool calls - const toolCall = delta.tool_calls[0]; - const toolName = toolCall.function.name; - const input = JSON.parse(toolCall.function.arguments); // TODO: parsing error handling - - const agentTool = body.tools.find((t) => t.name === toolName); + for await (const chunk of result as AsyncIterable) { + if (chunk.toolCalls) { + const toolCall = chunk.toolCalls[0]; + const { name, arguments: args } = toolCall.function; + const agentTool = body.tools.find((t) => t.name === name); if (!agentTool) { - throw new Error(`Unknown tool: ${toolName}`); + throw new Error(`Unknown tool: ${name}`); } - const toolPart = { - type: "tool-" + toolName, - toolCallId: toolCall.id, - input: input, - }; - - yield { type: "tool-call-input", data: toolPart }; - - let output = null; - let outputError = null; + yield { type: "tool", tool: { name, args } }; + let result = null; + let error = null; try { - output = await agentTool.execute(input, context!); + result = await agentTool.execute(args, context!); } catch (err) { - outputError = (err as Error).message; + error = (err as Error).message; } - if (outputError) { - yield { - type: "tool-call-output", - data: { ...toolPart, outputError }, - }; - } else { - yield { type: "tool-call-output", data: { ...toolPart, output } }; - } + yield { + type: "tool", + tool: { name, args, result, error }, + } as ToolEvent; - const newMessage: Message = { + const newMessage: ToolMessage = { role: "tool", - parts: [ - { - type: `tool-${toolName}`, - toolCallId: toolCall.id, - input: input, - output: output, - outputError: outputError, - }, - ], + text: "", + content: { + callId: toolCall.id, + name: name, + args: args, + result: result, + error: error, + }, }; - yield { type: "message-created", data: newMessage }; newMessages.push(newMessage); - } else if (delta.content) { - const text = delta.content as string; - yield { type: "text-delta", data: { text } }; - totalText += text; + shouldBreak = false; + } else if (chunk.text || chunk.reasoning) { + yield { + type: "chunk", + chunk: { + text: chunk.text, + reasoning: chunk.reasoning, + }, + } as ChunkEvent; + + if (chunk.text) { + totalText += chunk.text; + } shouldBreak = true; } } if (totalText.trim()) { - const newMessage: Message = { + const newMessage: AssistantMessage = { role: "assistant", - parts: [{ type: "text", text: totalText.trim() }], + text: totalText, }; - yield { type: "message-created", data: newMessage }; + yield { type: "message", message: newMessage } as MessageEvent; newMessages.push(newMessage); shouldBreak = true; } diff --git a/packages/fluent-ai/src/agent/message.ts b/packages/fluent-ai/src/agent/message.ts deleted file mode 100644 index 3f9211b..0000000 --- a/packages/fluent-ai/src/agent/message.ts +++ /dev/null @@ -1,46 +0,0 @@ -import type { Message } from "~/src/job/schema"; - -export function convertMessagesForChatCompletion(messages: Message[]) { - let result = []; - - for (const message of messages) { - if (message.role === "tool") { - const part = message.parts[0]; // TODO: support multiple parts - result.push({ - role: "assistant", - content: null, - tool_calls: [ - { - id: part.toolCallId, - type: "function", - function: { - name: part.type.substring("tool-".length), - arguments: JSON.stringify(part.input), - }, - }, - ], - }); - - if (part.outputError) { - result.push({ - role: "tool", - tool_call_id: part.toolCallId, - content: JSON.stringify(part.outputError), - }); - } else if (part.output) { - result.push({ - role: "tool", - tool_call_id: part.toolCallId, - content: JSON.stringify(part.output), - }); - } - } else { - result.push({ - role: message.role, - content: message.parts, - }); - } - } - - return result; -} diff --git a/packages/fluent-ai/src/agent/repl.ts b/packages/fluent-ai/src/agent/repl.ts index bef560f..22ffbbb 100644 --- a/packages/fluent-ai/src/agent/repl.ts +++ b/packages/fluent-ai/src/agent/repl.ts @@ -1,26 +1,49 @@ import * as readline from "node:readline/promises"; import type { Agent } from "~/src/agent/agent"; -import type { Message, MessagePart } from "~/src/job/schema"; +import type { Message, UserMessage } from "~/src/job/schema"; -export async function agentReplInput(): Promise { +function newId() { + return String(Math.floor(Math.random() * 1_000_000_000)); +} + +export async function agentReplInput(): Promise { const rl = readline.createInterface({ input: process.stdin, output: process.stdout, }); - const userInput = await rl.question("\nYou: "); + const input = await rl.question("\nYou: "); rl.close(); - const newMessage = { - role: "user" as const, - parts: [ - { - type: "text", - text: userInput, - }, - ], - }; + return { id: newId(), role: "user", text: input }; +} - return newMessage; +export async function inspectAgentStream(stream: AsyncIterable) { + const newMessages = []; + for await (const event of stream) { + if (event.type === "chunk") { + if (event.chunk.reasoning) { + process.stdout.write("\x1b[90m"); + process.stdout.write(event.chunk.reasoning); + } else if (event.chunk.text) { + process.stdout.write("\x1b[0m"); + process.stdout.write(event.chunk.text); + } + } else if (event.type === "tool") { + process.stdout.write("\x1b[0m"); + if (event.tool.result) { + console.log( + `tool ${event.tool.name} with result: ${JSON.stringify(event.tool.result)}`, + ); + } else { + console.log( + `tool ${event.tool.name} with arguments: ${JSON.stringify(event.tool.args)}`, + ); + } + } else if (event.type === "message") { + newMessages.push(event.message); + } + } + return newMessages; } export async function agentRepl(agent: Agent) { @@ -29,24 +52,9 @@ export async function agentRepl(agent: Agent) { const userMessage = await agentReplInput(); allMessages = allMessages.concat([userMessage]); const stream = agent.generate(allMessages, { maxSteps: 8 }); - for await (const event of stream) { - if (event.type === "text-delta") { - process.stdout.write((event.data as any).text); - } else if (event.type === "tool-call-input") { - const toolPart = event.data as MessagePart; - console.log( - `\n[Calling tool ${toolPart.type} with input: ${JSON.stringify(toolPart.input)}]`, - ); - } else if (event.type === "tool-call-output") { - const toolPart = event.data as MessagePart; - console.log( - `\n[Calling tool ${toolPart.type} with output: ${JSON.stringify(toolPart.output || toolPart.outputError)}]`, - ); - } else if (event.type === "message-created") { - const data = event.data as Message; - allMessages = allMessages.concat([data]); - } - } + + const newMessages = await inspectAgentStream(stream); + allMessages = allMessages.concat(newMessages); console.log(); } } diff --git a/packages/fluent-ai/src/builder/chat.ts b/packages/fluent-ai/src/builder/chat.ts index 1771085..1e3a56f 100644 --- a/packages/fluent-ai/src/builder/chat.ts +++ b/packages/fluent-ai/src/builder/chat.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { type ChatJob, type ChatTool } from "~/src/job/schema"; +import { type ChatJob, type ChatTool, type Message } from "~/src/job/schema"; export class ChatBuilder { private provider: TProvider; @@ -14,12 +14,12 @@ export class ChatBuilder { this.input.model = model; } - messages(messages: Array<{ role: string; content: string }>): this { + messages(messages: Message[]) { this.input.messages = messages; return this; } - tool(tool: ChatTool): this { + tool(tool: ChatTool) { if (!this.input.tools) { this.input.tools = []; } @@ -27,7 +27,7 @@ export class ChatBuilder { return this; } - tools(tools: ChatTool[]): this { + tools(tools: ChatTool[]) { if (!this.input.tools) { this.input.tools = []; } @@ -55,35 +55,6 @@ export class ChatBuilder { } } -export function user(content: string) { - return { role: "user", content: content }; -} - -export function assistant(content: string) { - return { role: "assistant", content: content }; -} - -export function system(content: string) { - return { role: "system", content: content }; -} - -export function text(result: any) { - if (result.raw) { - if (result.raw.candidates) { - return result.raw.candidates[0].content.parts[0].text; - } - - if (result.raw.choices[0].message) { - return result.raw.choices[0].message.content; - } - - if (result.raw.choices[0].delta.content) { - return result.raw.choices[0].delta.content; - } - } - return ""; -} - class ChatToolBuilder { private body: Partial = {}; diff --git a/packages/fluent-ai/src/builder/ollama.ts b/packages/fluent-ai/src/builder/ollama.ts index 09389a4..651e686 100644 --- a/packages/fluent-ai/src/builder/ollama.ts +++ b/packages/fluent-ai/src/builder/ollama.ts @@ -1,6 +1,7 @@ import type { Options } from "~/src/job/schema"; import { ChatBuilder } from "~/src/builder/chat"; import { ModelsBuilder } from "~/src/builder/models"; +import { EmbeddingBuilder } from "~/src/builder/embedding"; import { runner } from "~/src/job/ollama"; export function ollama(options?: Options) { @@ -11,5 +12,8 @@ export function ollama(options?: Options) { models() { return new ModelsBuilder("ollama" as const, options, runner); }, + embedding(model: string) { + return new EmbeddingBuilder("ollama" as const, options, runner, model); + }, }; } diff --git a/packages/fluent-ai/src/job/http.ts b/packages/fluent-ai/src/job/http.ts index 3f1e8f4..27ff5ea 100644 --- a/packages/fluent-ai/src/job/http.ts +++ b/packages/fluent-ai/src/job/http.ts @@ -8,6 +8,14 @@ export async function createHTTPJob( ): Promise { try { const response = await fetch(request); + + if (!response.ok) { + console.error("HTTP Error Response:", await response.text()); + throw new Error( + `Server error: ${response.status} ${response.statusText}`, + ); + } + return await handleResponse(response); } catch (error) { if (error instanceof Error) { diff --git a/packages/fluent-ai/src/job/ollama.ts b/packages/fluent-ai/src/job/ollama.ts index 573f1f8..90cec6b 100644 --- a/packages/fluent-ai/src/job/ollama.ts +++ b/packages/fluent-ai/src/job/ollama.ts @@ -1,9 +1,12 @@ -import type { ChatJob, ModelsJob } from "~/src/job/schema"; +import type { + ChatJob, + ModelsJob, + EmbeddingJob, + Message, + ChatTool, +} from "~/src/job/schema"; import { createHTTPJob } from "~/src/job/http"; -import { - transformToolsToFunctions, - createStreamingGenerator, -} from "~/src/job/utils"; +import { createStreamingGenerator } from "~/src/job/utils"; const DEFAULT_BASE_URL = "http://localhost:11434"; @@ -11,10 +14,97 @@ function getBaseUrl(options?: ChatJob["options"]): string { return options?.baseUrl || process.env.OLLAMA_BASE_URL || DEFAULT_BASE_URL; } +function convertMessages(messages: Message[]) { + let result = []; + + for (const message of messages) { + if (message.role === "tool") { + result.push({ + role: "assistant", + content: null, + tool_calls: [ + { + id: message.content.callId, + type: "function", + function: { + name: message.content.name, + arguments: message.content.args, + }, + }, + ], + }); + + if (message.content.result) { + result.push({ + role: "tool", + tool_call_id: message.content.callId, + content: JSON.stringify(message.content.result), + }); + } else if (message.content.error) { + result.push({ + role: "tool", + tool_call_id: message.content.callId, + content: JSON.stringify(message.content.error), + }); + } + } else { + result.push({ role: message.role, content: message.text }); + } + } + + return result; +} + +function convertTools(tools?: ChatTool[]) { + return tools?.map((tool: ChatTool) => ({ + type: "function", + function: { + name: tool.name, + description: tool.description, + parameters: tool.input, + }, + })); +} + +interface OllamaChunk { + model: string; + createdAt: string; + message: { + role: string; + content: string; + thinking?: string; + tool_calls?: { + id: string; + function: { + index: number; + name: string; + arguments: any; + }; + }[]; + }; + done: boolean; + done_reason?: "stop"; + total_duration?: number; + load_duration?: number; + prompt_eval_count?: number; + prompt_eval_duration?: number; + eval_count?: number; + eval_duration?: number; +} + +function convertChunk(chunk: OllamaChunk) { + return { + text: chunk.message.content, + reasoning: chunk.message.thinking, + toolCalls: chunk.message.tool_calls, + }; +} + export const runner = { chat: async (input: ChatJob["input"], options?: ChatJob["options"]) => { const baseUrl = getBaseUrl(options); - const tools = transformToolsToFunctions(input.tools); + const tools = convertTools(input.tools); + const messages = convertMessages(input.messages); const request = new Request(`${baseUrl}/api/chat`, { method: "POST", @@ -23,7 +113,7 @@ export const runner = { }, body: JSON.stringify({ model: input.model, - messages: input.messages, + messages: messages, temperature: input.temperature, tools: tools, stream: input.stream ?? false, @@ -35,7 +125,7 @@ export const runner = { return createHTTPJob(request, async (response: Response) => { if (input.stream) { - return createStreamingGenerator(response); + return createStreamingGenerator(response, convertChunk); } const data = await response.json(); @@ -82,4 +172,30 @@ export const runner = { })); }); }, + + embedding: async ( + input: EmbeddingJob["input"], + options?: EmbeddingJob["options"], + ) => { + const baseUrl = getBaseUrl(options); + + const request = new Request(`${baseUrl}/api/embed`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: input.model, + input: input.input, + }), + }); + + return createHTTPJob(request, async (response: Response) => { + const data = await response.json(); + + return { + embeddings: data.embeddings, + }; + }); + }, }; diff --git a/packages/fluent-ai/src/job/openai.ts b/packages/fluent-ai/src/job/openai.ts index 5e618f7..38ec6a7 100644 --- a/packages/fluent-ai/src/job/openai.ts +++ b/packages/fluent-ai/src/job/openai.ts @@ -33,7 +33,15 @@ export const runner = { return createHTTPJob(request, async (response: Response) => { if (input.stream) { - return createStreamingGenerator(response); + return createStreamingGenerator(response, (chunk: any) => { + const choice = chunk.choices[0]; + const delta = choice.delta; + return { + text: delta.content, + toolCalls: delta.tool_calls, + usage: transformUsageData(chunk.usage), + }; + }); } const data = await response.json(); diff --git a/packages/fluent-ai/src/job/openrouter.ts b/packages/fluent-ai/src/job/openrouter.ts index ea7e584..cc5d303 100644 --- a/packages/fluent-ai/src/job/openrouter.ts +++ b/packages/fluent-ai/src/job/openrouter.ts @@ -33,7 +33,15 @@ export const runner = { return createHTTPJob(request, async (response: Response) => { if (input.stream) { - return createStreamingGenerator(response); + return createStreamingGenerator(response, (chunk: any) => { + const choice = chunk.choices[0]; + const delta = choice.delta; + return { + text: delta.content, + toolCalls: delta.tool_calls, + usage: transformUsageData(chunk.usage), + }; + }); } const data = await response.json(); diff --git a/packages/fluent-ai/src/job/schema.ts b/packages/fluent-ai/src/job/schema.ts index 4120800..93da860 100644 --- a/packages/fluent-ai/src/job/schema.ts +++ b/packages/fluent-ai/src/job/schema.ts @@ -1,31 +1,76 @@ import { z } from "zod"; -const chatToolSchema = z.object({ - name: z.string(), - description: z.string(), - input: z.any(), // TODO: should be valid json schema +const systemMessageSchema = z.object({ + role: z.literal("system"), + text: z.string(), }); -const messagePartSchema = z.object({ - type: z.string(), +// TODO: support attaching files, images, etc. +const userMessageSchema = z.object({ + id: z.string().optional(), + role: z.literal("user"), + text: z.string(), +}); + +const assistantMessageSchema = z.object({ + id: z.string().optional(), + role: z.literal("assistant"), text: z.string().optional(), - toolCallId: z.string().optional(), - input: z.any().optional(), - output: z.any().optional(), - outputError: z.any().optional(), + reasoning: z.string().optional(), }); -const messageSchema = z.object({ - role: z.enum(["system", "user", "assistant", "tool"]), - parts: z.array(messagePartSchema), +const toolMessageSchema = z.object({ id: z.string().optional(), - threadId: z.string().optional(), - createdAt: z.date().optional(), + role: z.literal("tool"), + text: z.string(), + content: z.object({ + callId: z.string(), + name: z.string(), + args: z.any().optional(), + result: z.any().optional(), + error: z.any().optional(), + }), +}); + +const messagesSchema = z.union([ + systemMessageSchema, + userMessageSchema, + assistantMessageSchema, + toolMessageSchema, +]); + +const messageChunkSchema = z.object({ + text: z.string().optional(), + reasoning: z.string().optional(), + toolCalls: z + .array( + z.object({ + id: z.string(), + function: z.object({ + name: z.string(), + arguments: z.any(), + }), + }), + ) + .optional(), +}); + +export type SystemMessage = z.infer; +export type UserMessage = z.infer; +export type AssistantMessage = z.infer; +export type ToolMessage = z.infer; +export type Message = z.infer; +export type MessageChunk = z.infer; + +const chatToolSchema = z.object({ + name: z.string(), + description: z.string(), + input: z.any(), // TODO: should be valid json schema }); const chatInputSchema = z.object({ model: z.string(), - messages: z.array(z.any()), // TODO: fix any + messages: z.array(messagesSchema), temperature: z.number().optional(), maxTokens: z.number().optional(), stream: z.boolean().optional(), @@ -125,7 +170,7 @@ export const modelsJobSchema = z.object({ export const embeddingJobSchema = z.object({ type: z.literal("embedding"), - provider: z.enum(["voyage"]), + provider: z.enum(["voyage", "ollama"]), options: optionsSchema.optional(), input: embeddingInputSchema, output: embeddingOutputSchema.optional(), @@ -138,8 +183,6 @@ export const jobSchema = z.union([ embeddingJobSchema, ]); -export type MessagePart = z.infer; -export type Message = z.infer; export type ChatTool = z.infer; export type Job = z.infer; export type ImageJob = z.infer; diff --git a/packages/fluent-ai/src/job/utils.ts b/packages/fluent-ai/src/job/utils.ts index 92ec3f0..32982ed 100644 --- a/packages/fluent-ai/src/job/utils.ts +++ b/packages/fluent-ai/src/job/utils.ts @@ -1,5 +1,4 @@ -import { EventSourceParserStream } from "eventsource-parser/stream"; -import type { ChatTool } from "~/src/job/schema"; +import type { ChatTool, MessageChunk } from "~/src/job/schema"; export function getApiKey( options: { apiKey?: string } | undefined, @@ -29,22 +28,39 @@ export function transformUsageData(usage?: any) { : undefined; } -export async function* createStreamingGenerator(response: Response) { - const eventStream = response - .body!.pipeThrough(new TextDecoderStream()) - .pipeThrough(new EventSourceParserStream()); - const reader = eventStream.getReader(); - - try { - while (true) { - const { done, value } = await reader.read(); - if (done || value.data === "[DONE]") { - break; +export async function* createStreamingGenerator( + response: Response, + convertChunk: any, +) { + const decoder = new TextDecoder("utf-8"); + const reader = response.body!.getReader(); + + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + const decoded = decoder.decode(value, { stream: true }); + const lines = decoded.split("\n"); + for (const line of lines) { + if (line.startsWith(":")) { + continue; + } + + let chunk = line.trim(); + if (chunk.startsWith("data: ")) { + chunk = chunk.replace("data: ", "").trim(); + } + + if (chunk === "[DONE]") { + return; + } + + if (chunk) { + const parsed = JSON.parse(chunk); // let parse error throw + yield convertChunk(parsed); } - const chunk = JSON.parse(value.data); - yield { raw: chunk }; } - } finally { - reader.releaseLock(); } } diff --git a/packages/fluent-ai/test/job.test.ts b/packages/fluent-ai/test/job.test.ts index 8400913..dc2fcf6 100644 --- a/packages/fluent-ai/test/job.test.ts +++ b/packages/fluent-ai/test/job.test.ts @@ -5,7 +5,6 @@ import { openai, fal, type Job, - user, voyage, } from "~/src/index"; import { Runner } from "~/src/job/runner"; @@ -16,7 +15,7 @@ test("chat job", () => { type: "chat", input: { model: "test-model", - messages: [{ role: "user", content: "hello" }], + messages: [{ role: "user", text: "hello" }], }, }; @@ -27,7 +26,7 @@ test("chat job", () => { expect(job).toEqual( openrouter() .chat("test-model") - .messages([user("hello")]) + .messages([{ role: "user", text: "hello" }]) .build(), ); }); @@ -96,7 +95,7 @@ test("runner", () => { type: "chat", input: { model: "test-model", - messages: [{ role: "user", content: "hello" }], + messages: [{ role: "user", text: "hello" }], }, };