Skip to content

Commit 3b86586

Browse files
committed
fix structured output detection
1 parent 078734b commit 3b86586

File tree

6 files changed

+49
-23
lines changed

6 files changed

+49
-23
lines changed

src/lib/components/inference-playground/code-snippets.svelte

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
<script lang="ts">
22
import { type ConversationClass } from "$lib/state/conversations.svelte";
3-
import { structuredForbiddenProviders } from "$lib/state/models.svelte";
43
import { token } from "$lib/state/token.svelte.js";
54
import { billing } from "$lib/state/billing.svelte";
65
import { isCustomModel } from "$lib/types.js";
@@ -59,7 +58,7 @@
5958
// eslint-disable-next-line @typescript-eslint/no-explicit-any
6059
} as any;
6160
// eslint-disable-next-line @typescript-eslint/no-explicit-any
62-
if (data.structuredOutput && !structuredForbiddenProviders.includes(conversation.data.provider as any)) {
61+
if (data.structuredOutput && conversation.isStructuredOutputAllowed) {
6362
opts.structured_output = data.structuredOutput;
6463
}
6564

src/lib/components/inference-playground/generation-config.svelte

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
<script lang="ts">
22
import type { ConversationClass } from "$lib/state/conversations.svelte.js";
3-
import { structuredForbiddenProviders } from "$lib/state/models.svelte.js";
43
import { maxAllowedTokens } from "$lib/utils/business.svelte.js";
4+
import { cn } from "$lib/utils/cn.js";
55
import { isNumber } from "$lib/utils/is.js";
66
import { watch } from "runed";
77
import IconX from "~icons/carbon/close";
8+
import ExtraParamsModal, { openExtraParamsModal } from "./extra-params-modal.svelte";
89
import { GENERATION_CONFIG_KEYS, GENERATION_CONFIG_SETTINGS } from "./generation-config-settings.js";
910
import StructuredOutputModal, { openStructuredOutputModal } from "./structured-output-modal.svelte";
10-
import ExtraParamsModal, { openExtraParamsModal } from "./extra-params-modal.svelte";
11-
import { cn } from "$lib/utils/cn.js";
1211
1312
interface Props {
1413
conversation: ConversationClass;
@@ -103,7 +102,7 @@
103102
</label>
104103

105104
<!-- eslint-disable-next-line @typescript-eslint/no-explicit-any -->
106-
{#if !structuredForbiddenProviders.includes(conversation.data.provider as any)}
105+
{#if conversation.isStructuredOutputAllowed}
107106
<label class="mt-2 flex cursor-pointer items-center justify-between" for="structured-output">
108107
<span class="text-sm font-medium text-gray-900 dark:text-gray-300">Structured Output</span>
109108
<div class="flex items-center gap-2">

src/lib/state/conversations.svelte.ts

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ import {
44
} from "$lib/components/inference-playground/generation-config-settings.js";
55
import { addToast } from "$lib/components/toaster.svelte.js";
66
import { AbortManager } from "$lib/spells/abort-manager.svelte";
7-
import { PipelineTag, Provider, type ConversationMessage, type GenerationStatistics, type Model } from "$lib/types.js";
7+
import { PipelineTag, type ConversationMessage, type GenerationStatistics, type Model } from "$lib/types.js";
88
import { handleNonStreamingResponse, handleStreamingResponse, estimateTokens } from "$lib/utils/business.svelte.js";
99
import { omit, snapshot } from "$lib/utils/object.svelte";
10-
import { models, structuredForbiddenProviders } from "./models.svelte";
10+
import { models } from "./models.svelte";
1111
import { pricing } from "./pricing.svelte.js";
1212
import { DEFAULT_PROJECT_ID, ProjectEntity, projects } from "./projects.svelte";
1313
import { token } from "./token.svelte";
@@ -107,9 +107,7 @@ export class ConversationClass {
107107
}
108108

109109
get isStructuredOutputAllowed() {
110-
const forbiddenProvider =
111-
this.data.provider && structuredForbiddenProviders.includes(this.data.provider as Provider);
112-
return !forbiddenProvider;
110+
return models.supportsStructuredOutput(this.model, this.data.provider);
113111
}
114112

115113
get isStructuredOutputEnabled() {

src/lib/state/models.svelte.ts

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { page } from "$app/state";
2-
import { Provider, type CustomModel } from "$lib/types.js";
2+
import { Provider, type CustomModel, type Model } from "$lib/types.js";
33
import { edit, randomPick } from "$lib/utils/array.js";
44
import { safeParse } from "$lib/utils/json.js";
55
import typia from "typia";
@@ -10,13 +10,6 @@ const LOCAL_STORAGE_KEY = "hf_inference_playground_custom_models";
1010

1111
const pageData = $derived(page.data as PageData);
1212

13-
export const structuredForbiddenProviders: Provider[] = [
14-
Provider.Hyperbolic,
15-
Provider.Nebius,
16-
Provider.Novita,
17-
Provider.Sambanova,
18-
];
19-
2013
class Models {
2114
remote = $derived(pageData.models);
2215
trending = $derived(this.remote.toSorted((a, b) => b.trendingScore - a.trendingScore).slice(0, 5));
@@ -74,6 +67,13 @@ class Models {
7467
c.update({ modelId: randomPick(models.trending)?.id });
7568
});
7669
}
70+
71+
supportsStructuredOutput(model: Model | CustomModel, provider?: string) {
72+
if (typia.is<CustomModel>(model)) return true;
73+
const routerDataEntry = pageData.routerData.data.find(d => d.id === model.id);
74+
if (!routerDataEntry) return false;
75+
return routerDataEntry.providers.find(p => p.provider === provider)?.supports_structured_output ?? false;
76+
}
7777
}
7878

7979
export const models = new Models();

src/lib/utils/business.svelte.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
2828
import OpenAI from "openai";
2929
import { images } from "$lib/state/images.svelte.js";
3030
import { projects } from "$lib/state/projects.svelte.js";
31-
import { structuredForbiddenProviders } from "$lib/state/models.svelte.js";
3231
import { modifySnippet } from "$lib/utils/snippets.js";
32+
import { models } from "$lib/state/models.svelte";
3333

3434
type ChatCompletionInputMessageChunk =
3535
NonNullable<ChatCompletionInputMessage["content"]> extends string | (infer U)[] ? U : never;
@@ -89,7 +89,7 @@ function getResponseFormatObj(conversation: ConversationClass | Conversation) {
8989
const data = conversation instanceof ConversationClass ? conversation.data : conversation;
9090
const json = safeParse(data.structuredOutput?.schema ?? "");
9191
// eslint-disable-next-line @typescript-eslint/no-explicit-any
92-
if (json && data.structuredOutput?.enabled && !structuredForbiddenProviders.includes(data.provider as any)) {
92+
if (json && data.structuredOutput?.enabled && models.supportsStructuredOutput(conversation.model, data.provider)) {
9393
switch (data.provider) {
9494
case "cohere": {
9595
return {
@@ -366,7 +366,7 @@ export function getInferenceSnippet(
366366
if (
367367
opts?.structured_output?.schema &&
368368
opts.structured_output.enabled &&
369-
!structuredForbiddenProviders.includes(provider as Provider)
369+
models.supportsStructuredOutput(conversation.model, provider)
370370
) {
371371
return {
372372
...s,

src/routes/+page.ts

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,44 @@
1+
import type { Provider } from "$lib/types.js";
12
import type { PageLoad } from "./$types.js";
23
import type { ApiModelsResponse } from "./api/models/+server.js";
34

5+
export type RouterData = {
6+
object: string;
7+
data: Datum[];
8+
};
9+
10+
type Datum = {
11+
id: string;
12+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
13+
object: any;
14+
created: number;
15+
owned_by: string;
16+
providers: ProviderElement[];
17+
};
18+
19+
type ProviderElement = {
20+
provider: Provider;
21+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
22+
status: any;
23+
context_length?: number;
24+
pricing?: Pricing;
25+
supports_tools?: boolean;
26+
supports_structured_output?: boolean;
27+
};
28+
29+
type Pricing = {
30+
input: number;
31+
output: number;
32+
};
33+
434
export const load: PageLoad = async ({ fetch }) => {
535
const [modelsRes, routerRes] = await Promise.all([
636
fetch("/api/models"),
737
fetch("https://router.huggingface.co/v1/models"),
838
]);
939

1040
const models: ApiModelsResponse = await modelsRes.json();
11-
const routerData = await routerRes.json();
41+
const routerData = (await routerRes.json()) as RouterData;
1242

1343
return {
1444
...models,

0 commit comments

Comments
 (0)