35 changes: 33 additions & 2 deletions packages/openai-adapters/src/apis/Gemini.ts
@@ -10,6 +10,7 @@ import {
Completion,
CompletionCreateParamsNonStreaming,
CompletionCreateParamsStreaming,
CompletionUsage,
CreateEmbeddingResponse,
EmbeddingCreateParams,
Model,
@@ -22,6 +23,7 @@ import {
chatChunkFromDelta,
customFetch,
embedding,
usageChatChunk,
} from "../util.js";
import {
convertOpenAIToolToGeminiFunction,
@@ -37,6 +39,11 @@ import {
RerankCreateParams,
} from "./base.js";

type UsageInfo = Pick<
CompletionUsage,
"total_tokens" | "completion_tokens" | "prompt_tokens"
>;

export class GeminiApi implements BaseLlmApi {
apiBase: string = "https://generativelanguage.googleapis.com/v1beta/";

@@ -240,14 +247,20 @@ export class GeminiApi implements BaseLlmApi {
signal: AbortSignal,
): Promise<ChatCompletion> {
let completion = "";
let usage: UsageInfo | undefined = undefined;
for await (const chunk of this.chatCompletionStream(
{
...body,
stream: true,
},
signal,
)) {
completion += chunk.choices[0].delta.content;
if (chunk.choices.length > 0) {
completion += chunk.choices[0].delta.content || "";
}
if (chunk.usage) {
usage = chunk.usage;
}
}
return {
id: "",
@@ -266,12 +279,13 @@
},
},
],
usage: undefined,
usage,
};
}

async *handleStreamResponse(response: any, model: string) {
let buffer = "";
let usage: UsageInfo | undefined = undefined;
for await (const chunk of streamResponse(response as any)) {
buffer += chunk;
if (buffer.startsWith("[")) {
@@ -300,6 +314,15 @@
throw new Error(data.error.message);
}

// Check for usage metadata
if (data.usageMetadata) {
usage = {
prompt_tokens: data.usageMetadata.promptTokenCount || 0,
completion_tokens: data.usageMetadata.candidatesTokenCount || 0,
total_tokens: data.usageMetadata.totalTokenCount || 0,
};
}

// When max tokens is reached, Gemini will sometimes return content with no parts, even though that doesn't match the API spec
const contentParts = data?.candidates?.[0]?.content?.parts;
if (contentParts) {
@@ -338,6 +361,14 @@
buffer = "";
}
}

// Emit usage at the end if we have it
if (usage) {
yield usageChatChunk({
model,
usage,
});
}
}

async *chatCompletionStream(
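The `usageChatChunk` helper that Gemini.ts (and Relace.ts below) now import from `../util.js` is not shown in this diff. A minimal sketch of what such a helper might look like, assuming it mirrors OpenAI's trailing usage-only chunk shape (empty `choices`, populated `usage`); the function body and defaults here are an assumption, not the repository's actual implementation:

```typescript
import type {
  ChatCompletionChunk,
  CompletionUsage,
} from "openai/resources/index.mjs";

// Sketch only: the real util.js helper is not part of this diff.
export function usageChatChunk(options: {
  model: string;
  usage: Pick<
    CompletionUsage,
    "prompt_tokens" | "completion_tokens" | "total_tokens"
  >;
}): ChatCompletionChunk {
  return {
    id: "",
    object: "chat.completion.chunk",
    created: Math.floor(Date.now() / 1000),
    model: options.model,
    // An empty choices array marks this as a usage-only trailing chunk,
    // matching what OpenAI emits for stream_options: { include_usage: true }.
    choices: [],
    usage: options.usage,
  };
}
```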
5 changes: 5 additions & 0 deletions packages/openai-adapters/src/apis/OpenAI.ts
@@ -35,6 +35,11 @@ export class OpenAIApi implements BaseLlmApi {
});
}
modifyChatBody<T extends ChatCompletionCreateParams>(body: T): T {
// Add stream_options to include usage in streaming responses
if (body.stream) {
(body as any).stream_options = { include_usage: true };
}

// o-series models - only apply for official OpenAI API
const isOfficialOpenAIAPI = this.apiBase === "https://api.openai.com/v1/";
if (isOfficialOpenAIAPI) {
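For context on why the adapters now guard on `chunk.choices.length > 0`: when `stream_options: { include_usage: true }` is set, OpenAI-compatible endpoints send one final chunk whose `choices` array is empty and whose `usage` field is populated. A consumer-side sketch against the official `openai` SDK (illustrative, not part of this PR):

```typescript
import OpenAI from "openai";

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function streamWithUsage(): Promise<string> {
  const stream = await client.chat.completions.create({
    model: "gpt-4o",
    messages: [{ role: "user", content: "hello" }],
    stream: true,
    stream_options: { include_usage: true },
  });

  let content = "";
  for await (const chunk of stream) {
    // Content chunks carry deltas; the trailing usage-only chunk has no choices.
    if (chunk.choices.length > 0) {
      content += chunk.choices[0].delta.content ?? "";
    }
    if (chunk.usage) {
      console.log(
        "prompt:", chunk.usage.prompt_tokens,
        "completion:", chunk.usage.completion_tokens,
      );
    }
  }
  return content;
}
```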
31 changes: 29 additions & 2 deletions packages/openai-adapters/src/apis/Relace.ts
@@ -2,6 +2,7 @@ import {
Completion,
CompletionCreateParamsNonStreaming,
CompletionCreateParamsStreaming,
CompletionUsage,
} from "openai/resources/completions.mjs";
import {
CreateEmbeddingResponse,
@@ -16,14 +17,24 @@ import {
import { Model } from "openai/resources/models.mjs";
import { z } from "zod";
import { OpenAIConfigSchema } from "../types.js";
import { chatChunk, chatCompletion, customFetch } from "../util.js";
import {
chatChunk,
chatCompletion,
customFetch,
usageChatChunk,
} from "../util.js";
import {
BaseLlmApi,
CreateRerankResponse,
FimCreateParamsStreaming,
RerankCreateParams,
} from "./base.js";

type UsageInfo = Pick<
CompletionUsage,
"total_tokens" | "completion_tokens" | "prompt_tokens"
>;

// Relace only supports apply through a /v1/apply endpoint
export class RelaceApi implements BaseLlmApi {
private apiBase = "https://instantapply.endpoint.relace.run/v1/";
@@ -41,6 +52,7 @@ export class RelaceApi implements BaseLlmApi {
signal: AbortSignal,
): Promise<ChatCompletion> {
let content = "";
let usage: UsageInfo | undefined = undefined;

// Convert the non-streaming params to streaming params
const streamingBody: ChatCompletionCreateParamsStreaming = {
@@ -52,12 +64,18 @@
streamingBody,
signal,
)) {
content += chunk.choices[0]?.delta?.content || "";
if (chunk.choices.length > 0) {
content += chunk.choices[0]?.delta?.content || "";
}
if (chunk.usage) {
usage = chunk.usage;
}
}

return chatCompletion({
content,
model: body.model,
usage,
});
}

@@ -115,6 +133,15 @@
content: mergedCode,
model: body.model,
});

yield usageChatChunk({
model: body.model,
usage: {
prompt_tokens: result.usage.prompt_tokens || 0,
completion_tokens: result.usage.completion_tokens || 0,
total_tokens: result.usage.total_tokens,
},
});
}

completionNonStream(
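Both `chatCompletionNonStream` implementations in this PR (Gemini.ts above and Relace.ts here) now follow the same aggregation pattern: concatenate deltas from chunks that have choices, and remember the last `usage` payload seen. A sketch of that pattern as a standalone helper; the name `collectStream` and its factoring are illustrative, not something the PR adds:

```typescript
import type {
  ChatCompletionChunk,
  CompletionUsage,
} from "openai/resources/index.mjs";

type UsageInfo = Pick<
  CompletionUsage,
  "total_tokens" | "completion_tokens" | "prompt_tokens"
>;

// Illustrative helper: fold a chunk stream into accumulated text plus the
// last usage payload observed (normally the trailing usage-only chunk).
async function collectStream(
  stream: AsyncIterable<ChatCompletionChunk>,
): Promise<{ content: string; usage: UsageInfo | undefined }> {
  let content = "";
  let usage: UsageInfo | undefined = undefined;
  for await (const chunk of stream) {
    if (chunk.choices.length > 0) {
      content += chunk.choices[0].delta.content || "";
    }
    if (chunk.usage) {
      usage = chunk.usage;
    }
  }
  return { content, usage };
}
```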
1 change: 1 addition & 0 deletions packages/openai-adapters/src/test/adapter-test-utils.ts
@@ -258,6 +258,7 @@ export const createAdapterTests = (testConfig: AdapterTestConfig) => {
model: "gpt-4",
messages: [{ role: "user", content: "hello" }],
stream: true,
stream_options: { include_usage: true },
...testConfig.customBodyOptions,
},
},
65 changes: 58 additions & 7 deletions packages/openai-adapters/src/test/main.test.ts
@@ -51,6 +51,20 @@ const TESTS: Omit<ModelConfig & { options?: TestConfigOptions }, "name">[] = [
model: "gpt-4o",
apiKey: process.env.OPENAI_API_KEY!,
roles: ["chat"],
options: {
skipTools: false,
expectUsage: true,
},
},
{
provider: "openai",
model: "gpt-4o-mini",
apiKey: process.env.OPENAI_API_KEY!,
roles: ["chat"],
options: {
skipTools: false,
expectUsage: true,
},
},
{
provider: "anthropic",
Expand All @@ -67,14 +81,31 @@ const TESTS: Omit<ModelConfig & { options?: TestConfigOptions }, "name">[] = [
model: "gemini-1.5-flash-latest",
apiKey: process.env.GEMINI_API_KEY!,
roles: ["chat"],
options: {
skipTools: false,
expectUsage: true,
},
},
{
provider: "gemini",
model: "gemini-2.5-flash",
apiKey: process.env.GEMINI_API_KEY!,
roles: ["chat"],
options: {
skipTools: false,
expectUsage: true,
},
},
{
provider: "mistral",
model: "codestral",
apiKey: process.env.MISTRAL_API_KEY!,
roles: ["chat"],
options: {
skipTools: false,
expectUsage: true,
},
},
// {
// provider: "mistral",
// model: "codestral-latest",
// apiKey: process.env.MISTRAL_API_KEY!,
// apiBase: "https://api.mistral.ai/v1",
// roles: ["autocomplete"],
// },
// {
// provider: "deepseek",
// model: "deepseek-coder",
@@ -111,6 +142,26 @@ const TESTS: Omit<ModelConfig & { options?: TestConfigOptions }, "name">[] = [
apiKey: process.env.VOYAGE_API_KEY!,
roles: ["rerank"],
},
{
provider: "relace",
model: "instant-apply",
apiKey: process.env.RELACE_API_KEY!,
roles: ["chat"],
options: {
skipTools: true,
expectUsage: true,
},
},
{
provider: "inception",
model: "mercury-coder",
apiKey: process.env.INCEPTION_API_KEY!,
roles: ["chat"],
options: {
skipTools: false,
expectUsage: true,
},
},
// {
// provider: "cohere",
// model: "rerank-v3.5",
14 changes: 11 additions & 3 deletions packages/openai-adapters/src/test/util.ts
@@ -106,9 +106,17 @@ export function testChat(
expect(usage).toBeDefined();
expect(usage!.completion_tokens).toBeGreaterThan(0);
expect(usage!.prompt_tokens).toBeGreaterThan(0);
expect(usage!.total_tokens).toEqual(
usage!.prompt_tokens + usage!.completion_tokens,
);
// Gemini 2.5 (and some 2.0) models report thinking tokens, so total_tokens >= prompt + completion
// Other models should have total_tokens = prompt + completion
if (model.includes("gemini-2.5") || model.includes("gemini-2.0")) {
expect(usage!.total_tokens).toBeGreaterThanOrEqual(
usage!.prompt_tokens + usage!.completion_tokens,
);
} else {
expect(usage!.total_tokens).toEqual(
usage!.prompt_tokens + usage!.completion_tokens,
);
}
}
});

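A hypothetical illustration of why the relaxed assertion is needed, assuming the Gemini API reports reasoning tokens via a `thoughtsTokenCount` field in `usageMetadata` that is counted in `totalTokenCount` but not in `candidatesTokenCount` (the numbers are made up):

```typescript
// Illustrative numbers only, not real API output.
const usageMetadata = {
  promptTokenCount: 12,
  candidatesTokenCount: 40,
  thoughtsTokenCount: 95, // reasoning tokens (assumed field name)
  totalTokenCount: 147,   // 12 + 40 + 95
};

// After the mapping added in Gemini.ts:
const usage = {
  prompt_tokens: usageMetadata.promptTokenCount,         // 12
  completion_tokens: usageMetadata.candidatesTokenCount, // 40
  total_tokens: usageMetadata.totalTokenCount,           // 147
};

// 147 >= 12 + 40 holds, but 147 !== 52, hence the relaxed check above.
console.assert(
  usage.total_tokens >= usage.prompt_tokens + usage.completion_tokens,
);
```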