Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/core/webview/__tests__/webviewMessageHandler.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,48 @@ describe("webviewMessageHandler - requestLmStudioModels", () => {
})
})

describe("webviewMessageHandler - requestOllamaModels", () => {
beforeEach(() => {
vi.clearAllMocks()
mockClineProvider.getState = vi.fn().mockResolvedValue({
apiConfiguration: {
ollamaModelId: "model-1",
ollamaBaseUrl: "http://localhost:1234",
},
})
})

it("successfully fetches models from Ollama", async () => {
const mockModels: ModelRecord = {
"model-1": {
maxTokens: 4096,
contextWindow: 8192,
supportsPromptCache: false,
description: "Test model 1",
},
"model-2": {
maxTokens: 8192,
contextWindow: 16384,
supportsPromptCache: false,
description: "Test model 2",
},
}

mockGetModels.mockResolvedValue(mockModels)

await webviewMessageHandler(mockClineProvider, {
type: "requestOllamaModels",
})

expect(mockGetModels).toHaveBeenCalledWith({ provider: "ollama", baseUrl: "http://localhost:1234" })

expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({
type: "ollamaModels",
ollamaModels: mockModels,
})
})
})

describe("webviewMessageHandler - requestRouterModels", () => {
beforeEach(() => {
vi.clearAllMocks()
Expand Down
4 changes: 2 additions & 2 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ export const webviewMessageHandler = async (
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "ollamaModels",
ollamaModels: Object.keys(result.value.models),
ollamaModels: result.value.models,
})
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
Expand Down Expand Up @@ -669,7 +669,7 @@ export const webviewMessageHandler = async (
if (Object.keys(ollamaModels).length > 0) {
provider.postMessageToWebview({
type: "ollamaModels",
ollamaModels: Object.keys(ollamaModels),
ollamaModels: ollamaModels,
})
}
} catch (error) {
Expand Down
2 changes: 1 addition & 1 deletion src/shared/ExtensionMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ export interface ExtensionMessage {
clineMessage?: ClineMessage
routerModels?: RouterModels
openAiModels?: string[]
ollamaModels?: string[]
ollamaModels?: ModelRecord
lmStudioModels?: ModelRecord
vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
huggingFaceModels?: Array<{
Expand Down
15 changes: 7 additions & 8 deletions webview-ui/src/components/settings/providers/Ollama.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels"
import { vscode } from "@src/utils/vscode"

import { inputEventTransform } from "../transforms"
import { ModelRecord } from "@roo/api"

type OllamaProps = {
apiConfiguration: ProviderSettings
Expand All @@ -20,7 +21,7 @@ type OllamaProps = {
export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaProps) => {
const { t } = useAppTranslation()

const [ollamaModels, setOllamaModels] = useState<string[]>([])
const [ollamaModels, setOllamaModels] = useState<ModelRecord>({})
const routerModels = useRouterModels()

const handleInputChange = useCallback(
Expand All @@ -40,7 +41,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
switch (message.type) {
case "ollamaModels":
{
const newModels = message.ollamaModels ?? []
const newModels = message.ollamaModels ?? {}
setOllamaModels(newModels)
}
break
Expand All @@ -61,7 +62,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
if (!selectedModel) return false

// Check if model exists in local ollama models
if (ollamaModels.length > 0 && ollamaModels.includes(selectedModel)) {
if (Object.keys(ollamaModels).length > 0 && selectedModel in ollamaModels) {
return false // Model is available locally
}

Expand Down Expand Up @@ -116,15 +117,13 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro
</div>
</div>
)}
{ollamaModels.length > 0 && (
{Object.keys(ollamaModels).length > 0 && (
<VSCodeRadioGroup
value={
ollamaModels.includes(apiConfiguration?.ollamaModelId || "")
? apiConfiguration?.ollamaModelId
: ""
(apiConfiguration?.ollamaModelId || "") in ollamaModels ? apiConfiguration?.ollamaModelId : ""
}
onChange={handleInputChange("ollamaModelId")}>
{ollamaModels.map((model) => (
{Object.keys(ollamaModels).map((model) => (
<VSCodeRadio key={model} value={model} checked={apiConfiguration?.ollamaModelId === model}>
{model}
</VSCodeRadio>
Expand Down
39 changes: 39 additions & 0 deletions webview-ui/src/components/ui/hooks/useOllamaModels.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { useQuery } from "@tanstack/react-query"

import { ModelRecord } from "@roo/api"
import { ExtensionMessage } from "@roo/ExtensionMessage"

import { vscode } from "@src/utils/vscode"

const getOllamaModels = async () =>
new Promise<ModelRecord>((resolve, reject) => {
const cleanup = () => {
window.removeEventListener("message", handler)
}

const timeout = setTimeout(() => {
cleanup()
reject(new Error("Ollama models request timed out"))
}, 10000)

const handler = (event: MessageEvent) => {
const message: ExtensionMessage = event.data

if (message.type === "ollamaModels") {
clearTimeout(timeout)
cleanup()

if (message.ollamaModels) {
resolve(message.ollamaModels)
} else {
reject(new Error("No Ollama models in response"))
}
}
}

window.addEventListener("message", handler)
vscode.postMessage({ type: "requestOllamaModels" })
})

export const useOllamaModels = (modelId?: string) =>
useQuery({ queryKey: ["ollamaModels"], queryFn: () => (modelId ? getOllamaModels() : {}) })
9 changes: 8 additions & 1 deletion webview-ui/src/components/ui/hooks/useSelectedModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,23 @@ import type { ModelRecord, RouterModels } from "@roo/api"
import { useRouterModels } from "./useRouterModels"
import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
import { useLmStudioModels } from "./useLmStudioModels"
import { useOllamaModels } from "./useOllamaModels"

export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
const provider = apiConfiguration?.apiProvider || "anthropic"
const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined
const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined

const routerModels = useRouterModels()
const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
const lmStudioModels = useLmStudioModels(lmStudioModelId)
const ollamaModels = useOllamaModels(ollamaModelId)

const { id, info } =
apiConfiguration &&
(typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") &&
(typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") &&
typeof routerModels.data !== "undefined" &&
typeof openRouterModelProviders.data !== "undefined"
? getSelectedModel({
Expand All @@ -84,6 +88,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
routerModels: routerModels.data,
openRouterModelProviders: openRouterModelProviders.data,
lmStudioModels: lmStudioModels.data,
ollamaModels: ollamaModels.data,
})
: { id: anthropicDefaultModelId, info: undefined }

Expand All @@ -108,12 +113,14 @@ function getSelectedModel({
routerModels,
openRouterModelProviders,
lmStudioModels,
ollamaModels,
}: {
provider: ProviderName
apiConfiguration: ProviderSettings
routerModels: RouterModels
openRouterModelProviders: Record<string, ModelInfo>
lmStudioModels: ModelRecord | undefined
ollamaModels: ModelRecord | undefined
}): { id: string; info: ModelInfo | undefined } {
// the `undefined` case are used to show the invalid selection to prevent
// users from seeing the default model if their selection is invalid
Expand Down Expand Up @@ -254,7 +261,7 @@ function getSelectedModel({
}
case "ollama": {
const id = apiConfiguration.ollamaModelId ?? ""
const info = routerModels.ollama && routerModels.ollama[id]
const info = ollamaModels && ollamaModels[apiConfiguration.ollamaModelId!]
return {
id,
info: info || undefined,
Expand Down
Loading