Skip to content

Commit 8ba47d1

Browse files
ItsOnlyBinary and brunobergher
authored and committed
Make ollama models info transport work like lmstudio (#7679)
1 parent 965dfc0 commit 8ba47d1

File tree

6 files changed

+99
-12
lines changed

6 files changed

+99
-12
lines changed

src/core/webview/__tests__/webviewMessageHandler.spec.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,48 @@ describe("webviewMessageHandler - requestLmStudioModels", () => {
136136
})
137137
})
138138

139+
describe("webviewMessageHandler - requestOllamaModels", () => {
140+
beforeEach(() => {
141+
vi.clearAllMocks()
142+
mockClineProvider.getState = vi.fn().mockResolvedValue({
143+
apiConfiguration: {
144+
ollamaModelId: "model-1",
145+
ollamaBaseUrl: "http://localhost:1234",
146+
},
147+
})
148+
})
149+
150+
it("successfully fetches models from Ollama", async () => {
151+
const mockModels: ModelRecord = {
152+
"model-1": {
153+
maxTokens: 4096,
154+
contextWindow: 8192,
155+
supportsPromptCache: false,
156+
description: "Test model 1",
157+
},
158+
"model-2": {
159+
maxTokens: 8192,
160+
contextWindow: 16384,
161+
supportsPromptCache: false,
162+
description: "Test model 2",
163+
},
164+
}
165+
166+
mockGetModels.mockResolvedValue(mockModels)
167+
168+
await webviewMessageHandler(mockClineProvider, {
169+
type: "requestOllamaModels",
170+
})
171+
172+
expect(mockGetModels).toHaveBeenCalledWith({ provider: "ollama", baseUrl: "http://localhost:1234" })
173+
174+
expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
175+
type: "ollamaModels",
176+
ollamaModels: mockModels,
177+
})
178+
})
179+
})
180+
139181
describe("webviewMessageHandler - requestRouterModels", () => {
140182
beforeEach(() => {
141183
vi.clearAllMocks()

src/core/webview/webviewMessageHandler.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ export const webviewMessageHandler = async (
797797
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
798798
provider.postMessageToWebview({
799799
type: "ollamaModels",
800-
ollamaModels: Object.keys(result.value.models),
800+
ollamaModels: result.value.models,
801801
})
802802
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
803803
provider.postMessageToWebview({
@@ -842,7 +842,7 @@ export const webviewMessageHandler = async (
842842
if (Object.keys(ollamaModels).length > 0) {
843843
provider.postMessageToWebview({
844844
type: "ollamaModels",
845-
ollamaModels: Object.keys(ollamaModels),
845+
ollamaModels: ollamaModels,
846846
})
847847
}
848848
} catch (error) {

src/shared/ExtensionMessage.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ export interface ExtensionMessage {
148148
clineMessage?: ClineMessage
149149
routerModels?: RouterModels
150150
openAiModels?: string[]
151-
ollamaModels?: string[]
151+
ollamaModels?: ModelRecord
152152
lmStudioModels?: ModelRecord
153153
vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
154154
huggingFaceModels?: Array<{

webview-ui/src/components/settings/providers/Ollama.tsx

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
1111
import { vscode } from "@src/utils/vscode"
1212

1313
import { inputEventTransform } from "../transforms"
14+
import { ModelRecord } from "@roo/api"
1415

1516
type OllamaProps = {
1617
apiConfiguration: ProviderSettings
@@ -20,7 +21,7 @@ type OllamaProps = {
2021
export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaProps) => {
2122
const { t } = useAppTranslation()
2223

23-
const [ollamaModels, setOllamaModels] = useState<string[]>([])
24+
const [ollamaModels, setOllamaModels] = useState<ModelRecord>({})
2425
const routerModels = useRouterModels()
2526

2627
const handleInputChange = useCallback(
@@ -40,7 +41,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
4041
switch (message.type) {
4142
case "ollamaModels":
4243
{
43-
const newModels = message.ollamaModels ?? []
44+
const newModels = message.ollamaModels ?? {}
4445
setOllamaModels(newModels)
4546
}
4647
break
@@ -61,7 +62,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
6162
if (!selectedModel) return false
6263

6364
// Check if model exists in local ollama models
64-
if (ollamaModels.length > 0 && ollamaModels.includes(selectedModel)) {
65+
if (Object.keys(ollamaModels).length > 0 && selectedModel in ollamaModels) {
6566
return false // Model is available locally
6667
}
6768

@@ -116,15 +117,13 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
116117
</div>
117118
</div>
118119
)}
119-
{ollamaModels.length > 0 && (
120+
{Object.keys(ollamaModels).length > 0 && (
120121
<VSCodeRadioGroup
121122
value={
122-
ollamaModels.includes(apiConfiguration?.ollamaModelId || "")
123-
? apiConfiguration?.ollamaModelId
124-
: ""
123+
(apiConfiguration?.ollamaModelId || "") in ollamaModels ? apiConfiguration?.ollamaModelId : ""
125124
}
126125
onChange={handleInputChange("ollamaModelId")}>
127-
{ollamaModels.map((model) => (
126+
{Object.keys(ollamaModels).map((model) => (
128127
<VSCodeRadio key={model} value={model} checked={apiConfiguration?.ollamaModelId === model}>
129128
{model}
130129
</VSCodeRadio>
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import { useQuery } from "@tanstack/react-query"
2+
3+
import { ModelRecord } from "@roo/api"
4+
import { ExtensionMessage } from "@roo/ExtensionMessage"
5+
6+
import { vscode } from "@src/utils/vscode"
7+
8+
const getOllamaModels = async () =>
9+
new Promise<ModelRecord>((resolve, reject) => {
10+
const cleanup = () => {
11+
window.removeEventListener("message", handler)
12+
}
13+
14+
const timeout = setTimeout(() => {
15+
cleanup()
16+
reject(new Error("Ollama models request timed out"))
17+
}, 10000)
18+
19+
const handler = (event: MessageEvent) => {
20+
const message: ExtensionMessage = event.data
21+
22+
if (message.type === "ollamaModels") {
23+
clearTimeout(timeout)
24+
cleanup()
25+
26+
if (message.ollamaModels) {
27+
resolve(message.ollamaModels)
28+
} else {
29+
reject(new Error("No Ollama models in response"))
30+
}
31+
}
32+
}
33+
34+
window.addEventListener("message", handler)
35+
vscode.postMessage({ type: "requestOllamaModels" })
36+
})
37+
38+
export const useOllamaModels = (modelId?: string) =>
39+
useQuery({ queryKey: ["ollamaModels"], queryFn: () => (modelId ? getOllamaModels() : {}) })

webview-ui/src/components/ui/hooks/useSelectedModel.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,23 @@ import type { ModelRecord, RouterModels } from "@roo/api"
6464
import { useRouterModels } from "./useRouterModels"
6565
import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
6666
import { useLmStudioModels } from "./useLmStudioModels"
67+
import { useOllamaModels } from "./useOllamaModels"
6768

6869
export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
6970
const provider = apiConfiguration?.apiProvider || "anthropic"
7071
const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
7172
const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined
73+
const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined
7274

7375
const routerModels = useRouterModels()
7476
const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
7577
const lmStudioModels = useLmStudioModels(lmStudioModelId)
78+
const ollamaModels = useOllamaModels(ollamaModelId)
7679

7780
const { id, info } =
7881
apiConfiguration &&
7982
(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
83+
(typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") &&
8084
typeof routerModels.data !== "undefined" &&
8185
typeof openRouterModelProviders.data !== "undefined"
8286
? getSelectedModel({
@@ -85,6 +89,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
8589
routerModels: routerModels.data,
8690
openRouterModelProviders: openRouterModelProviders.data,
8791
lmStudioModels: lmStudioModels.data,
92+
ollamaModels: ollamaModels.data,
8893
})
8994
: { id: anthropicDefaultModelId, info: undefined }
9095

@@ -109,12 +114,14 @@ function getSelectedModel({
109114
routerModels,
110115
openRouterModelProviders,
111116
lmStudioModels,
117+
ollamaModels,
112118
}: {
113119
provider: ProviderName
114120
apiConfiguration: ProviderSettings
115121
routerModels: RouterModels
116122
openRouterModelProviders: Record<string, ModelInfo>
117123
lmStudioModels: ModelRecord | undefined
124+
ollamaModels: ModelRecord | undefined
118125
}): { id: string; info: ModelInfo | undefined } {
119126
// the `undefined` case are used to show the invalid selection to prevent
120127
// users from seeing the default model if their selection is invalid
@@ -255,7 +262,7 @@ function getSelectedModel({
255262
}
256263
case "ollama": {
257264
const id = apiConfiguration.ollamaModelId ?? ""
258-
const info = routerModels.ollama && routerModels.ollama[id]
265+
const info = ollamaModels && ollamaModels[apiConfiguration.ollamaModelId!]
259266
return {
260267
id,
261268
info: info || undefined,

0 commit comments

Comments (0)