allow llmClient to be optionally passed in (#352) (#364)
* allow llmClient to be optionally passed in (#352)

* feat: allow llmClient to be optionally passed in

* update: add Ollama client example from PR #349

* update: README and changeset

* lint

---------

Co-authored-by: Arihan Varanasi <[email protected]>
kamath and arihanv authored Jan 3, 2025
1 parent 89841fc commit 08907eb
Showing 8 changed files with 378 additions and 5 deletions.
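
In short, `llmClient` becomes an optional field on the `Stagehand` constructor params; when it is omitted, Stagehand keeps its built-in model handling. A minimal sketch of the new call site, assuming the relative import paths used by the example files added in this commit:

```ts
import { type ConstructorParams, type LogLine, Stagehand } from "../lib";
import { OllamaClient } from "./external_clients/ollama";

// llmClient is optional; omit it to keep Stagehand's default clients.
const config: ConstructorParams = {
  env: "BROWSERBASE",
  apiKey: process.env.BROWSERBASE_API_KEY,
  projectId: process.env.BROWSERBASE_PROJECT_ID,
  llmClient: new OllamaClient(
    (message: LogLine) => console.log(message.message), // logger
    false, // enableCaching
    undefined, // cache
    "llama3.2", // modelName
  ),
};

const stagehand = new Stagehand(config);
```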
5 changes: 5 additions & 0 deletions .changeset/spicy-singers-flow.md
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": minor
---

Exposed `llmClient` in the Stagehand constructor
1 change: 1 addition & 0 deletions README.md
@@ -150,6 +150,7 @@ This constructor is used to create an instance of Stagehand.
- `1`: SDK-level logging
- `2`: LLM-client level logging (most granular)
- `debugDom`: a `boolean` that draws bounding boxes around elements presented to the LLM during automation.
- `llmClient`: (optional) a custom `LLMClient` implementation.

- **Returns:**

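A custom `LLMClient` implementation is a subclass of the `LLMClient` base class, as the `OllamaClient` example added below illustrates: it calls `super(modelName)` and implements `createChatCompletion`. A rough skeleton along those lines (the class name `MyClient` is a placeholder, the import paths assume a file next to the Ollama example, and the base class may require more members than are shown here):

```ts
import {
  type ChatCompletionOptions,
  LLMClient,
} from "../../lib/llm/LLMClient";
import type { LogLine } from "../../types/log";
import type { AvailableModel } from "../../types/model";

// Placeholder skeleton; mirrors the shape of the OllamaClient example below.
export class MyClient extends LLMClient {
  public logger: (message: LogLine) => void;

  constructor(logger: (message: LogLine) => void, modelName: AvailableModel) {
    super(modelName);
    this.logger = logger;
  }

  async createChatCompletion<T>(
    options: ChatCompletionOptions,
    retries = 3,
  ): Promise<T> {
    // Call your provider here and map its reply to what Stagehand expects.
    throw new Error("not implemented");
  }
}
```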
48 changes: 48 additions & 0 deletions examples/external_client.ts
@@ -0,0 +1,48 @@
import { type ConstructorParams, type LogLine, Stagehand } from "../lib";
import { z } from "zod";
import { OllamaClient } from "./external_clients/ollama";

const StagehandConfig: ConstructorParams = {
  env: "BROWSERBASE",
  apiKey: process.env.BROWSERBASE_API_KEY,
  projectId: process.env.BROWSERBASE_PROJECT_ID,
  verbose: 1,
  llmClient: new OllamaClient(
    (message: LogLine) =>
      console.log(`[stagehand::${message.category}] ${message.message}`),
    false,
    undefined,
    "llama3.2",
  ),
  debugDom: true,
};

async function example() {
  const stagehand = new Stagehand(StagehandConfig);

  await stagehand.init();
  await stagehand.page.goto("https://news.ycombinator.com");

  const headlines = await stagehand.page.extract({
    instruction: "Extract only 3 stories from the Hacker News homepage.",
    schema: z.object({
      stories: z
        .array(
          z.object({
            title: z.string(),
            url: z.string(),
            points: z.number(),
          }),
        )
        .length(3),
    }),
  });

  console.log(headlines);

  await stagehand.close();
}

(async () => {
  await example();
})();
313 changes: 313 additions & 0 deletions examples/external_clients/ollama.ts
@@ -0,0 +1,313 @@
import OpenAI, { type ClientOptions } from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import type { LLMCache } from "../../lib/cache/LLMCache";
import { validateZodSchema } from "../../lib/utils";
import {
  type ChatCompletionOptions,
  type ChatMessage,
  LLMClient,
} from "../../lib/llm/LLMClient";
import type { LogLine } from "../../types/log";
import type { AvailableModel } from "../../types/model";
import type {
  ChatCompletion,
  ChatCompletionAssistantMessageParam,
  ChatCompletionContentPartImage,
  ChatCompletionContentPartText,
  ChatCompletionCreateParamsNonStreaming,
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionUserMessageParam,
} from "openai/resources/chat";

export class OllamaClient extends LLMClient {
  public type = "ollama" as const;
  private client: OpenAI;
  private cache: LLMCache | undefined;
  public logger: (message: LogLine) => void;
  private enableCaching: boolean;
  public clientOptions: ClientOptions;

  constructor(
    logger: (message: LogLine) => void,
    enableCaching = false,
    cache: LLMCache | undefined,
    modelName: "llama3.2",
    clientOptions?: ClientOptions,
  ) {
    super(modelName as AvailableModel);
    this.client = new OpenAI({
      ...clientOptions,
      baseURL: clientOptions?.baseURL || "http://localhost:11434/v1",
      apiKey: "ollama",
    });
    this.logger = logger;
    this.cache = cache;
    this.enableCaching = enableCaching;
    this.modelName = modelName as AvailableModel;
  }

  async createChatCompletion<T = ChatCompletion>(
    options: ChatCompletionOptions,
    retries = 3,
  ): Promise<T> {
    const { image, requestId, ...optionsWithoutImageAndRequestId } = options;

    // TODO: Implement vision support
    if (image) {
      throw new Error(
        "Image provided. Vision is not currently supported for Ollama",
      );
    }

    this.logger({
      category: "ollama",
      message: "creating chat completion",
      level: 1,
      auxiliary: {
        options: {
          value: JSON.stringify({
            ...optionsWithoutImageAndRequestId,
            requestId,
          }),
          type: "object",
        },
        modelName: {
          value: this.modelName,
          type: "string",
        },
      },
    });

    const cacheOptions = {
      model: this.modelName,
      messages: options.messages,
      temperature: options.temperature,
      top_p: options.top_p,
      frequency_penalty: options.frequency_penalty,
      presence_penalty: options.presence_penalty,
      image: image,
      response_model: options.response_model,
    };

    if (options.image) {
      const screenshotMessage: ChatMessage = {
        role: "user",
        content: [
          {
            type: "image_url",
            image_url: {
              url: `data:image/jpeg;base64,${options.image.buffer.toString("base64")}`,
            },
          },
          ...(options.image.description
            ? [{ type: "text", text: options.image.description }]
            : []),
        ],
      };

      options.messages.push(screenshotMessage);
    }

    if (this.enableCaching && this.cache) {
      const cachedResponse = await this.cache.get<T>(
        cacheOptions,
        options.requestId,
      );

      if (cachedResponse) {
        this.logger({
          category: "llm_cache",
          message: "LLM cache hit - returning cached response",
          level: 1,
          auxiliary: {
            requestId: {
              value: options.requestId,
              type: "string",
            },
            cachedResponse: {
              value: JSON.stringify(cachedResponse),
              type: "object",
            },
          },
        });
        return cachedResponse;
      }

      this.logger({
        category: "llm_cache",
        message: "LLM cache miss - no cached response found",
        level: 1,
        auxiliary: {
          requestId: {
            value: options.requestId,
            type: "string",
          },
        },
      });
    }

    let responseFormat = undefined;
    if (options.response_model) {
      responseFormat = zodResponseFormat(
        options.response_model.schema,
        options.response_model.name,
      );
    }

    /* eslint-disable */
    // Remove unsupported options
    const { response_model, ...ollamaOptions } = {
      ...optionsWithoutImageAndRequestId,
      model: this.modelName,
    };

    this.logger({
      category: "ollama",
      message: "creating chat completion",
      level: 1,
      auxiliary: {
        ollamaOptions: {
          value: JSON.stringify(ollamaOptions),
          type: "object",
        },
      },
    });

    // Normalize Stagehand chat messages into the OpenAI SDK's message param shapes.
    const formattedMessages: ChatCompletionMessageParam[] =
      options.messages.map((message) => {
        if (Array.isArray(message.content)) {
          const contentParts = message.content.map((content) => {
            if ("image_url" in content) {
              const imageContent: ChatCompletionContentPartImage = {
                image_url: {
                  url: content.image_url.url,
                },
                type: "image_url",
              };
              return imageContent;
            } else {
              const textContent: ChatCompletionContentPartText = {
                text: content.text,
                type: "text",
              };
              return textContent;
            }
          });

          if (message.role === "system") {
            const formattedMessage: ChatCompletionSystemMessageParam = {
              ...message,
              role: "system",
              content: contentParts.filter(
                (content): content is ChatCompletionContentPartText =>
                  content.type === "text",
              ),
            };
            return formattedMessage;
          } else if (message.role === "user") {
            const formattedMessage: ChatCompletionUserMessageParam = {
              ...message,
              role: "user",
              content: contentParts,
            };
            return formattedMessage;
          } else {
            const formattedMessage: ChatCompletionAssistantMessageParam = {
              ...message,
              role: "assistant",
              content: contentParts.filter(
                (content): content is ChatCompletionContentPartText =>
                  content.type === "text",
              ),
            };
            return formattedMessage;
          }
        }

        const formattedMessage: ChatCompletionUserMessageParam = {
          role: "user",
          content: message.content,
        };

        return formattedMessage;
      });

    const body: ChatCompletionCreateParamsNonStreaming = {
      ...ollamaOptions,
      model: this.modelName,
      messages: formattedMessages,
      response_format: responseFormat,
      stream: false,
      tools: options.tools?.filter((tool) => "function" in tool), // ensure only OpenAI-compatible tools are used
    };

    const response = await this.client.chat.completions.create(body);

    this.logger({
      category: "ollama",
      message: "response",
      level: 1,
      auxiliary: {
        response: {
          value: JSON.stringify(response),
          type: "object",
        },
        requestId: {
          value: requestId,
          type: "string",
        },
      },
    });

    if (options.response_model) {
      const extractedData = response.choices[0].message.content;
      const parsedData = JSON.parse(extractedData);

      if (!validateZodSchema(options.response_model.schema, parsedData)) {
        if (retries > 0) {
          return this.createChatCompletion(options, retries - 1);
        }

        throw new Error("Invalid response schema");
      }

      if (this.enableCaching) {
        this.cache.set(
          cacheOptions,
          {
            ...parsedData,
          },
          options.requestId,
        );
      }

      return parsedData;
    }

    if (this.enableCaching) {
      this.logger({
        category: "llm_cache",
        message: "caching response",
        level: 1,
        auxiliary: {
          requestId: {
            value: options.requestId,
            type: "string",
          },
          cacheOptions: {
            value: JSON.stringify(cacheOptions),
            type: "object",
          },
          response: {
            value: JSON.stringify(response),
            type: "object",
          },
        },
      });
      this.cache.set(cacheOptions, response, options.requestId);
    }

    return response as T;
  }
}
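
By default the client above talks to `http://localhost:11434/v1` using the dummy `"ollama"` API key. Because `clientOptions` is spread into the OpenAI SDK constructor and an explicitly passed `baseURL` takes precedence over that localhost default, a non-local Ollama host can presumably be targeted as sketched below (the host URL is a placeholder, and the import path assumes a file next to `ollama.ts`):

```ts
import type { LogLine } from "../../types/log";
import { OllamaClient } from "./ollama";

// A baseURL supplied via clientOptions wins over the localhost default in the constructor.
const remoteOllama = new OllamaClient(
  (message: LogLine) => console.log(message.message), // logger
  false, // enableCaching
  undefined, // cache
  "llama3.2", // modelName
  { baseURL: "http://my-ollama-host:11434/v1" }, // hypothetical remote host
);
```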