diff --git a/CLAUDE.md b/CLAUDE.md index 568585f4..f1307530 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,41 +80,46 @@ Claudish supports local models via: Local model APIs (LM Studio, Ollama) report `prompt_tokens` as the **full conversation context** each request, not incremental tokens. The `writeTokenFile` function uses assignment (`=`) not accumulation (`+=`) for input tokens to handle this correctly. -## Three-Layer Adapter Architecture (v5.14.0+) +## Three-Layer Provider Architecture -The translation pipeline has three decoupled layers: +All provider-specific logic lives in `providers/provider-profiles.ts`. Four resolvers and one effective-definition resolver. `createHandlerForProvider(def, modelName, apiKey, targetModel, port, opts)` composes them into a ComposedHandler. -### Layer 1: FormatConverter — wire format translation -Translates between Claude API format and target model's wire format (messages, tools, payload). -Each converter declares its stream format via `getStreamFormat()`. -- **Interface**: `adapters/format-converter.ts` -- **Implementations**: OpenAIAdapter, AnthropicPassthroughAdapter, GeminiAdapter, CodexAdapter, OllamaCloudAdapter, LiteLLMAdapter -- **Message/tool conversion**: `handlers/shared/format/openai-messages.ts`, `openai-tools.ts` +### Layer 1: APIFormat — wire format +- **Base**: `BaseAPIFormat` (`adapters/base-api-format.ts`) +- **Classes**: OpenAIAPIFormat, AnthropicAPIFormat, GeminiAPIFormat, CodexAPIFormat, OllamaAPIFormat, LiteLLMAPIFormat, OpenRouterAPIFormat +- **Resolver**: `resolveAPIFormat(def, modelName)` switches on `def.transport` +- **Key methods**: `convertMessages()`, `convertTools()`, `buildPayload()`, `getStreamFormat()`, `parseResponse()` -### Layer 2: ModelTranslator — model dialect translation -Translates model-specific dialect differences (context windows, thinking→reasoning_effort, vision rules). -- **Interface**: `adapters/model-translator.ts` -- **Implementations**: GLMAdapter, GrokAdapter, MiniMaxAdapter, DeepSeekAdapter, QwenAdapter, CodexAdapter -- **Selection**: `AdapterManager` auto-selects based on model ID +### Layer 2: ModelDialect — model quirks +- **Base**: `BaseModelDialect` (`adapters/base-model-dialect.ts`) +- **Classes**: GrokModelDialect, GeminiModelDialect, DeepSeekModelDialect, QwenModelDialect, MiniMaxModelDialect, GLMModelDialect, XiaomiModelDialect +- **Resolver**: `resolveModelDialect(modelId)` matches model family patterns +- **Key methods**: `processTextContent()`, `getContextWindow()`, `supportsVision()`, `prepareRequest()` ### Layer 3: ProviderTransport — HTTP transport -Handles auth, endpoints, headers, rate limiting. Optionally overrides stream format for aggregators. -- **Interface**: `providers/transport/types.ts` -- **Stream format override**: LiteLLM and OpenRouter implement `overrideStreamFormat()` → `"openai-sse"` - -### Composition in ComposedHandler -``` -ComposedHandler = FormatConverter (explicit adapter) + ModelTranslator (auto-selected) + ProviderTransport -``` - -**Stream parser selection** (3-tier priority): +- **Base**: `BaseTransport` (`providers/transport/base.ts`) owns a `RequestQueue` +- **Subclasses**: `ApiKeyTransport` (Bearer auth), `OAuthTransport` (token refresh) +- **Classes**: OpenAI, Gemini, CodeAssist, Anthropic, OpenRouter, OllamaCloud, LiteLLM, Vertex, Poe, Local, Zen, KimiCoding +- **Resolver**: `resolveTransport(def, modelName, apiKey)` switches on `def.transport` +- **Rate limiting**: `onResponse()`, `shouldRetry()`, `calculateDelay()` as overridable methods (Gemini parses quotaResetDelay, OpenRouter tracks X-RateLimit headers) + +### How to add a provider +1. Add entry to `BUILTIN_PROVIDERS` in `provider-definitions.ts` with `transport` type +2. Add case to `resolveTransport` and `resolveAPIFormat` in `provider-profiles.ts` +3. If model has quirks, add case to `resolveModelDialect` and create a ModelDialect class +4. If transport needs custom auth/queue behavior, create a transport subclass extending `BaseTransport` + +### Additional resolvers +- `resolveMiddlewares(modelId)`: selects middleware instances (e.g., GeminiThoughtSignatureMiddleware for Gemini models) +- `resolveEffective(def, modelName, apiKey)`: handles definition swaps (zen + minimax) and publicKeyFallback +- `resolveModelDialect(modelId)`: returns `BaseModelDialect | null` (null = no dialect quirks) + +### Stream parser selection ```typescript -transport.overrideStreamFormat() ?? modelAdapter.getStreamFormat() ?? providerAdapter.getStreamFormat() +transport.overrideStreamFormat() ?? modelAdapter.getStreamFormat() ?? adapter.getStreamFormat() ``` -**Adding a new provider**: Add one entry to `PROVIDER_PROFILES` table in `providers/provider-profiles.ts`. -**Adding a new model**: Create a ModelTranslator adapter, register in `adapters/adapter-manager.ts`. -**Verifying wiring**: `claudish --probe ` shows the full adapter composition. +**Verifying wiring**: `claudish --probe ` shows the full composition. ### Stream Parsers Located in `handlers/shared/stream-parsers/`: diff --git a/README.md b/README.md index 3e404700..e76382b6 100644 --- a/README.md +++ b/README.md @@ -905,18 +905,20 @@ claudish "your prompt" 2. Find available port (random or specified) 3. Start local proxy on http://127.0.0.1:PORT 4. Spawn: claude --auto-approve --env ANTHROPIC_BASE_URL=http://127.0.0.1:PORT -5. Proxy translates: Anthropic API → OpenRouter API +5. Proxy translates via three-layer pipeline: + - ProviderTransport: auth, endpoint, rate limiting + - APIFormat: wire format conversion (OpenAI, Gemini, Anthropic, etc.) + - ModelDialect: model-specific quirks (reasoning filters, tool format) 6. Stream output in real-time 7. Cleanup proxy on exit ``` ### Request Flow -**Normal Mode (OpenRouter):** ``` -Claude Code → Anthropic API format → Local Proxy → OpenRouter API format → OpenRouter - ↓ -Claude Code ← Anthropic API format ← Local Proxy ← OpenRouter API format ← OpenRouter +Claude Code → Anthropic API → Local Proxy → [APIFormat + ProviderTransport] → Any Provider + ↓ +Claude Code ← Anthropic API ← Local Proxy ← [Stream Parser] ← Any Provider ``` **Monitor Mode (Anthropic Passthrough):** diff --git a/packages/cli/scripts/smoke/providers.ts b/packages/cli/scripts/smoke/providers.ts index 7234e6f5..7b8ec29e 100644 --- a/packages/cli/scripts/smoke/providers.ts +++ b/packages/cli/scripts/smoke/providers.ts @@ -7,7 +7,7 @@ */ import type { RemoteProvider } from "../../src/handlers/shared/remote-provider-types.js"; -import { getRegisteredRemoteProviders } from "../../src/providers/remote-provider-registry.js"; +import { getAllProviders, toRemoteProvider } from "../../src/providers/provider-definitions.js"; import type { SmokeProviderConfig, WireFormat } from "./types.js"; // Providers to skip in v1 smoke tests @@ -154,7 +154,9 @@ function getBaseUrl(provider: RemoteProvider): string { * @returns Array of SmokeProviderConfig for providers ready to test. */ export function discoverProviders(filterName?: string): SmokeProviderConfig[] { - const all = getRegisteredRemoteProviders(); + const all = getAllProviders() + .filter((def) => !def.isLocal && def.baseUrl !== "" && def.name !== "qwen" && def.name !== "native-anthropic") + .map(toRemoteProvider); return all .filter((p) => { diff --git a/packages/cli/src/adapters/anthropic-api-format.ts b/packages/cli/src/adapters/anthropic-api-format.ts index 71242fe1..ac8ed3b3 100644 --- a/packages/cli/src/adapters/anthropic-api-format.ts +++ b/packages/cli/src/adapters/anthropic-api-format.ts @@ -26,10 +26,6 @@ export class AnthropicAPIFormat extends BaseAPIFormat { }; } - shouldHandle(modelId: string): boolean { - return false; // Not auto-selected; always explicitly passed - } - getName(): string { return "AnthropicAPIFormat"; } @@ -114,6 +110,15 @@ export class AnthropicAPIFormat extends BaseAPIFormat { return "anthropic-sse"; } + override parseResponse(data: any): { content: string; usage?: { input: number; output: number } } { + const blocks = data?.content ?? []; + const content = blocks.map((b: any) => b.text ?? "").join(""); + const usage = data?.usage + ? { input: data.usage.input_tokens ?? 0, output: data.usage.output_tokens ?? 0 } + : undefined; + return { content, usage }; + } + override getContextWindow(): number { // Try catalog lookup first (handles kimi/minimax model name variants) const catalogEntry = lookupModel(this.modelId); diff --git a/packages/cli/src/adapters/api-format.ts b/packages/cli/src/adapters/api-format.ts index fe794157..f765c588 100644 --- a/packages/cli/src/adapters/api-format.ts +++ b/packages/cli/src/adapters/api-format.ts @@ -34,4 +34,8 @@ export interface APIFormat { textContent: string, accumulatedText: string ): import("./base-api-format.js").AdapterResult; + + /** Parse a non-streaming response body into content + usage. + * Default handles OpenAI format. Override for other wire formats. */ + parseResponse(data: any): { content: string; usage?: { input: number; output: number } }; } diff --git a/packages/cli/src/adapters/base-api-format.ts b/packages/cli/src/adapters/base-api-format.ts index 1a492599..7c375022 100644 --- a/packages/cli/src/adapters/base-api-format.ts +++ b/packages/cli/src/adapters/base-api-format.ts @@ -68,11 +68,6 @@ export abstract class BaseAPIFormat implements APIFormat, ModelDialect { */ abstract processTextContent(textContent: string, accumulatedText: string): AdapterResult; - /** - * Check if this format/dialect should be used for the given model - */ - abstract shouldHandle(modelId: string): boolean; - /** * Get name for logging */ @@ -172,6 +167,20 @@ export abstract class BaseAPIFormat implements APIFormat, ModelDialect { return "openai-sse"; } + /** + * Parse a non-streaming response into content + usage. + * Default handles OpenAI Chat Completions format. + * Override for Gemini, Anthropic, etc. + */ + parseResponse(data: any): { content: string; usage?: { input: number; output: number } } { + const choice = data?.choices?.[0]; + const content = choice?.message?.content ?? ""; + const usage = data?.usage + ? { input: data.usage.prompt_tokens ?? 0, output: data.usage.completion_tokens ?? 0 } + : undefined; + return { content, usage }; + } + /** * Context window size for this model (tokens). * Used for token tracking and context-left-percent calculation. @@ -257,10 +266,6 @@ export class DefaultAPIFormat extends BaseAPIFormat { }; } - shouldHandle(modelId: string): boolean { - return false; // Default is fallback - } - getName(): string { return "DefaultAPIFormat"; } diff --git a/packages/cli/src/adapters/base-model-dialect.ts b/packages/cli/src/adapters/base-model-dialect.ts new file mode 100644 index 00000000..4f9cacd4 --- /dev/null +++ b/packages/cli/src/adapters/base-model-dialect.ts @@ -0,0 +1,84 @@ +/** + * BaseModelDialect — abstract base for model dialect implementations (Layer 2). + * + * Model dialects handle per-model-family quirks: context windows, parameter mappings + * (thinking -> reasoning_effort), vision support, tool name limits, and text + * post-processing (reasoning filters, XML parsing, token stripping). + * + * This class is for pure dialect adapters that do NOT define a wire format. + * Wire format adapters (OpenAI, Gemini, Anthropic, etc.) extend BaseAPIFormat instead. + */ + +import type { ModelDialect } from "./model-dialect.js"; +import type { StreamFormat } from "../providers/transport/types.js"; +import type { AdapterResult } from "./base-api-format.js"; +import type { ModelPricing } from "../handlers/shared/remote-provider-types.js"; +import { getModelPricing } from "../handlers/shared/remote-provider-types.js"; + +export abstract class BaseModelDialect implements ModelDialect { + protected modelId: string; + + constructor(modelId: string) { + this.modelId = modelId; + } + + /** + * Process text content and extract any model-specific tool call formats. + */ + abstract processTextContent(textContent: string, accumulatedText: string): AdapterResult; + + /** + * Get name for logging. + */ + abstract getName(): string; + + /** + * Maximum tool name length allowed by this model's API. + * Returns null if no limit (default). + */ + getToolNameLimit(): number | null { + return null; + } + + /** + * Handle any request preparation before sending to the model. + */ + prepareRequest(request: any, originalRequest: any): any { + return request; + } + + /** + * Reset internal state between requests. + */ + reset(): void {} + + /** + * Context window size for this model (tokens). + */ + getContextWindow(): number { + return 200_000; + } + + /** + * Whether this model supports vision/image input. + * Default false; override in dialects for models with vision support. + */ + supportsVision(): boolean { + return false; + } + + /** + * Stream format: dialects have no opinion (wire format is the adapter's job). + * Returns undefined so ComposedHandler falls through to the format adapter. + */ + getStreamFormat(): StreamFormat | undefined { + return undefined; + } + + /** + * Pricing info for this model. + */ + getPricing(providerName: string): ModelPricing { + return getModelPricing(providerName, this.modelId); + } +} diff --git a/packages/cli/src/adapters/codex-api-format.ts b/packages/cli/src/adapters/codex-api-format.ts index 81ba952a..e5f94561 100644 --- a/packages/cli/src/adapters/codex-api-format.ts +++ b/packages/cli/src/adapters/codex-api-format.ts @@ -11,7 +11,7 @@ * This format handles Codex models only. All other OpenAI models use OpenAIAPIFormat. */ -import { BaseAPIFormat, type AdapterResult, matchesModelFamily } from "./base-api-format.js"; +import { BaseAPIFormat, type AdapterResult } from "./base-api-format.js"; import type { StreamFormat } from "../providers/transport/types.js"; export class CodexAPIFormat extends BaseAPIFormat { @@ -27,10 +27,6 @@ export class CodexAPIFormat extends BaseAPIFormat { }; } - shouldHandle(modelId: string): boolean { - return matchesModelFamily(modelId, "codex"); - } - getName(): string { return "CodexAPIFormat"; } diff --git a/packages/cli/src/adapters/deepseek-model-dialect.ts b/packages/cli/src/adapters/deepseek-model-dialect.ts index ec6e5525..628cdcce 100644 --- a/packages/cli/src/adapters/deepseek-model-dialect.ts +++ b/packages/cli/src/adapters/deepseek-model-dialect.ts @@ -5,10 +5,11 @@ * - Strips unsupported thinking params (DeepSeek thinks automatically) */ -import { BaseAPIFormat, AdapterResult, matchesModelFamily } from "./base-api-format.js"; +import { BaseModelDialect } from "./base-model-dialect.js"; +import type { AdapterResult } from "./base-api-format.js"; import { log } from "../logger.js"; -export class DeepSeekModelDialect extends BaseAPIFormat { +export class DeepSeekModelDialect extends BaseModelDialect { processTextContent(textContent: string, accumulatedText: string): AdapterResult { return { cleanedText: textContent, @@ -35,10 +36,6 @@ export class DeepSeekModelDialect extends BaseAPIFormat { return request; } - shouldHandle(modelId: string): boolean { - return matchesModelFamily(modelId, "deepseek"); - } - getName(): string { return "DeepSeekModelDialect"; } diff --git a/packages/cli/src/adapters/dialect-manager.ts b/packages/cli/src/adapters/dialect-manager.ts deleted file mode 100644 index 99a5ca0e..00000000 --- a/packages/cli/src/adapters/dialect-manager.ts +++ /dev/null @@ -1,64 +0,0 @@ -/** - * DialectManager — selects the appropriate Layer 2 ModelDialect for a given model. - * - * This allows ComposedHandler to apply model-specific quirks independent of - * which Layer 1 APIFormat or Layer 3 ProviderTransport are used: - * - Grok: XML function calls - * - Gemini: Thought signatures in reasoning_details - * - DeepSeek, GLM, etc.: thinking param stripping / mapping - */ - -import { BaseAPIFormat, DefaultAPIFormat } from "./base-api-format.js"; -import { GrokModelDialect } from "./grok-model-dialect.js"; -import { GeminiAPIFormat } from "./gemini-api-format.js"; -import { CodexAPIFormat } from "./codex-api-format.js"; -import { OpenAIAPIFormat } from "./openai-api-format.js"; -import { QwenModelDialect } from "./qwen-model-dialect.js"; -import { MiniMaxModelDialect } from "./minimax-model-dialect.js"; -import { DeepSeekModelDialect } from "./deepseek-model-dialect.js"; -import { GLMModelDialect } from "./glm-model-dialect.js"; -import { XiaomiModelDialect } from "./xiaomi-model-dialect.js"; - -export class DialectManager { - private adapters: BaseAPIFormat[]; - private defaultAdapter: DefaultAPIFormat; - - constructor(modelId: string) { - // Register all available dialects/formats - this.adapters = [ - new GrokModelDialect(modelId), - new GeminiAPIFormat(modelId), - new CodexAPIFormat(modelId), // Must be before OpenAIAPIFormat (codex matches first) - new OpenAIAPIFormat(modelId), - new QwenModelDialect(modelId), - new MiniMaxModelDialect(modelId), - new DeepSeekModelDialect(modelId), - new GLMModelDialect(modelId), - new XiaomiModelDialect(modelId), - ]; - this.defaultAdapter = new DefaultAPIFormat(modelId); - } - - /** - * Get the appropriate dialect/format for the current model - */ - getAdapter(): BaseAPIFormat { - for (const adapter of this.adapters) { - if (adapter.shouldHandle(this.defaultAdapter["modelId"])) { - return adapter; - } - } - return this.defaultAdapter; - } - - /** - * Check if current model needs special handling - */ - needsTransformation(): boolean { - return this.getAdapter() !== this.defaultAdapter; - } -} - -// Backward-compatible alias -/** @deprecated Use DialectManager */ -export { DialectManager as AdapterManager }; diff --git a/packages/cli/src/adapters/gemini-api-format.ts b/packages/cli/src/adapters/gemini-api-format.ts index bc447237..d929c6bd 100644 --- a/packages/cli/src/adapters/gemini-api-format.ts +++ b/packages/cli/src/adapters/gemini-api-format.ts @@ -6,61 +6,18 @@ * - Tool conversion: Claude tools → Gemini function declarations * - Payload building: generationConfig, systemInstruction, thinkingConfig * - thoughtSignature tracking across requests (required for Gemini 3/2.5 thinking) - * - Reasoning text filtering (removes leaked internal monologue) + * + * Reasoning text filtering lives in GeminiModelDialect (Layer 2). * * Used with GeminiProviderTransport (direct API) and GeminiCodeAssistProviderTransport (OAuth). */ -import { BaseAPIFormat, type AdapterResult, matchesModelFamily } from "./base-api-format.js"; +import { BaseAPIFormat, type AdapterResult } from "./base-api-format.js"; import { convertToolsToGemini } from "../handlers/shared/gemini-schema.js"; import { filterIdentity } from "../handlers/shared/openai-compat.js"; import { log } from "../logger.js"; import type { StreamFormat } from "../providers/transport/types.js"; -/** - * Patterns that indicate internal reasoning/monologue that should be filtered. - * Gemini sometimes leaks reasoning as regular text instead of keeping it in thinking blocks. - */ -const REASONING_PATTERNS = [ - /^Wait,?\s+I(?:'m|\s+am)\s+\w+ing\b/i, - /^Wait,?\s+(?:if|that|the|this|I\s+(?:need|should|will|have|already))/i, - /^Wait[.!]?\s*$/i, - /^Let\s+me\s+(think|check|verify|see|look|analyze|consider|first|start)/i, - /^Let's\s+(check|see|look|start|first|try|think|verify|examine|analyze)/i, - /^I\s+need\s+to\s+/i, - /^O[kK](?:ay)?[.,!]?\s*(?:so|let|I|now|first)?/i, - /^[Hh]mm+/, - /^So[,.]?\s+(?:I|let|first|now|the)/i, - /^(?:First|Next|Then|Now)[,.]?\s+(?:I|let|we)/i, - /^(?:Thinking\s+about|Considering)/i, - /^I(?:'ll|\s+will)\s+(?:first|now|start|begin|try|check|fix|look|examine|modify|create|update|read|investigate|adjust|improve|integrate|mark|also|verify|need|rethink|add|help|use|run|search|find|explore|analyze|review|test|implement|write|make|set|get|see|open|close|save|load|fetch|call|send|build|compile|execute|process|handle|parse|format|validate|clean|clear|remove|delete|move|copy|rename|install|configure|setup|initialize|prepare|work|continue|proceed|ensure|confirm)/i, - /^I\s+should\s+/i, - /^I\s+will\s+(?:first|now|start|verify|check|create|modify|look|need|also|add|help|use|run|search|find|explore|analyze|review|test|implement|write)/i, - /^(?:Debug|Checking|Verifying|Looking\s+at):/i, - /^I\s+also\s+(?:notice|need|see|want)/i, - /^The\s+(?:goal|issue|problem|idea|plan)\s+is/i, - /^In\s+the\s+(?:old|current|previous|new|existing)\s+/i, - /^`[^`]+`\s+(?:is|has|does|needs|should|will|doesn't|hasn't)/i, -]; - -const REASONING_CONTINUATION_PATTERNS = [ - /^And\s+(?:then|I|now|so)/i, - /^And\s+I(?:'ll|\s+will)/i, - /^But\s+(?:I|first|wait|actually|the|if)/i, - /^Actually[,.]?\s+/i, - /^Also[,.]?\s+(?:I|the|check|note)/i, - /^\d+\.\s+(?:I|First|Check|Run|Create|Update|Read|Modify|Add|Fix|Look)/i, - /^-\s+(?:I|First|Check|Run|Create|Update|Read|Modify|Add|Fix)/i, - /^Or\s+(?:I|just|we|maybe|perhaps)/i, - /^Since\s+(?:I|the|this|we|it)/i, - /^Because\s+(?:I|the|this|we|it)/i, - /^If\s+(?:I|the|this|we|it)\s+/i, - /^This\s+(?:is|means|requires|should|will|confirms|suggests)/i, - /^That\s+(?:means|is|should|will|explains|confirms)/i, - /^Lines?\s+\d+/i, - /^The\s+`[^`]+`\s+(?:is|has|contains|needs|should)/i, -]; - export class GeminiAPIFormat extends BaseAPIFormat { /** * Map of tool_use_id → { name, thoughtSignature }. @@ -69,10 +26,6 @@ export class GeminiAPIFormat extends BaseAPIFormat { */ private toolCallMap = new Map(); - /** Reasoning filter state */ - private inReasoningBlock = false; - private reasoningBlockDepth = 0; - constructor(modelId: string) { super(modelId); } @@ -258,78 +211,38 @@ export class GeminiAPIFormat extends BaseAPIFormat { } } - // ─── Text Processing (reasoning filter) ─────────────────────────── + // ─── Text Processing (passthrough — reasoning filter is in GeminiModelDialect) ── processTextContent(textContent: string, _accumulatedText: string): AdapterResult { - if (!textContent || textContent.trim() === "") { - return { cleanedText: textContent, extractedToolCalls: [], wasTransformed: false }; - } - - const lines = textContent.split("\n"); - const cleanedLines: string[] = []; - let wasFiltered = false; - - for (const line of lines) { - const trimmed = line.trim(); - - if (!trimmed) { - cleanedLines.push(line); - continue; - } - - if (this.isReasoningLine(trimmed)) { - log(`[GeminiAPIFormat] Filtered reasoning: "${trimmed.substring(0, 50)}..."`); - wasFiltered = true; - this.inReasoningBlock = true; - this.reasoningBlockDepth++; - continue; - } - - if (this.inReasoningBlock && this.isReasoningContinuation(trimmed)) { - log(`[GeminiAPIFormat] Filtered reasoning continuation: "${trimmed.substring(0, 50)}..."`); - wasFiltered = true; - continue; - } - - if (this.inReasoningBlock && trimmed.length > 20 && !this.isReasoningContinuation(trimmed)) { - this.inReasoningBlock = false; - this.reasoningBlockDepth = 0; - } - - cleanedLines.push(line); - } - - const cleanedText = cleanedLines.join("\n"); - return { - cleanedText: wasFiltered ? cleanedText : textContent, + cleanedText: textContent, extractedToolCalls: [], - wasTransformed: wasFiltered, + wasTransformed: false, }; } - private isReasoningLine(line: string): boolean { - return REASONING_PATTERNS.some((pattern) => pattern.test(line)); - } - - private isReasoningContinuation(line: string): boolean { - return REASONING_CONTINUATION_PATTERNS.some((pattern) => pattern.test(line)); - } - // ─── Format metadata ───────────────────────────────────────────── override getStreamFormat(): StreamFormat { return "gemini-sse"; } + override parseResponse(data: any): { content: string; usage?: { input: number; output: number } } { + const parts = data?.candidates?.[0]?.content?.parts; + const content = parts?.map((p: any) => p.text ?? "").join("") ?? ""; + const meta = data?.usageMetadata; + const usage = meta + ? { input: meta.promptTokenCount ?? 0, output: meta.candidatesTokenCount ?? 0 } + : undefined; + return { content, usage }; + } + /** - * Reset reasoning filter state between requests. + * Reset state between requests. * NOTE: toolCallMap is intentionally NOT cleared — it persists across requests * because Gemini requires thoughtSignatures from previous responses. */ override reset(): void { - this.inReasoningBlock = false; - this.reasoningBlockDepth = 0; // Do NOT clear toolCallMap or toolNameMap } @@ -337,10 +250,6 @@ export class GeminiAPIFormat extends BaseAPIFormat { return 1_048_576; // Gemini models have 1M context (2^20 tokens) } - shouldHandle(modelId: string): boolean { - return matchesModelFamily(modelId, "gemini") || modelId.toLowerCase().includes("google/"); - } - getName(): string { return "GeminiAPIFormat"; } diff --git a/packages/cli/src/adapters/gemini-model-dialect.ts b/packages/cli/src/adapters/gemini-model-dialect.ts new file mode 100644 index 00000000..1cc2e619 --- /dev/null +++ b/packages/cli/src/adapters/gemini-model-dialect.ts @@ -0,0 +1,141 @@ +/** + * GeminiModelDialect — Gemini model-specific quirks (Layer 2 dialect). + * + * Handles reasoning text filtering: Gemini sometimes leaks internal monologue + * as regular text instead of keeping it in thinking blocks. + * + * Applied regardless of transport path (direct Gemini API, OpenRouter, LiteLLM). + */ + +import { BaseModelDialect } from "./base-model-dialect.js"; +import type { AdapterResult } from "./base-api-format.js"; +import { log } from "../logger.js"; + +const REASONING_PATTERNS = [ + /^Wait,?\s+I(?:'m|\s+am)\s+\w+ing\b/i, + /^Wait,?\s+(?:if|that|the|this|I\s+(?:need|should|will|have|already))/i, + /^Wait[.!]?\s*$/i, + /^Let\s+me\s+(think|check|verify|see|look|analyze|consider|first|start)/i, + /^Let's\s+(check|see|look|start|first|try|think|verify|examine|analyze)/i, + /^I\s+need\s+to\s+/i, + /^O[kK](?:ay)?[.,!]?\s*(?:so|let|I|now|first)?/i, + /^[Hh]mm+/, + /^So[,.]?\s+(?:I|let|first|now|the)/i, + /^(?:First|Next|Then|Now)[,.]?\s+(?:I|let|we)/i, + /^(?:Thinking\s+about|Considering)/i, + /^I(?:'ll|\s+will)\s+(?:first|now|start|begin|try|check|fix|look|examine|modify|create|update|read|investigate|adjust|improve|integrate|mark|also|verify|need|rethink|add|help|use|run|search|find|explore|analyze|review|test|implement|write|make|set|get|see|open|close|save|load|fetch|call|send|build|compile|execute|process|handle|parse|format|validate|clean|clear|remove|delete|move|copy|rename|install|configure|setup|initialize|prepare|work|continue|proceed|ensure|confirm)/i, + /^I\s+should\s+/i, + /^I\s+will\s+(?:first|now|start|verify|check|create|modify|look|need|also|add|help|use|run|search|find|explore|analyze|review|test|implement|write)/i, + /^(?:Debug|Checking|Verifying|Looking\s+at):/i, + /^I\s+also\s+(?:notice|need|see|want)/i, + /^The\s+(?:goal|issue|problem|idea|plan)\s+is/i, + /^In\s+the\s+(?:old|current|previous|new|existing)\s+/i, + /^`[^`]+`\s+(?:is|has|does|needs|should|will|doesn't|hasn't)/i, +]; + +const REASONING_CONTINUATION_PATTERNS = [ + /^And\s+(?:then|I|now|so)/i, + /^And\s+I(?:'ll|\s+will)/i, + /^But\s+(?:I|first|wait|actually|the|if)/i, + /^Actually[,.]?\s+/i, + /^Also[,.]?\s+(?:I|the|check|note)/i, + /^\d+\.\s+(?:I|First|Check|Run|Create|Update|Read|Modify|Add|Fix|Look)/i, + /^-\s+(?:I|First|Check|Run|Create|Update|Read|Modify|Add|Fix)/i, + /^Or\s+(?:I|just|we|maybe|perhaps)/i, + /^Since\s+(?:I|the|this|we|it)/i, + /^Because\s+(?:I|the|this|we|it)/i, + /^If\s+(?:I|the|this|we|it)\s+/i, + /^This\s+(?:is|means|requires|should|will|confirms|suggests)/i, + /^That\s+(?:means|is|should|will|explains|confirms)/i, + /^Lines?\s+\d+/i, + /^The\s+`[^`]+`\s+(?:is|has|contains|needs|should)/i, +]; + +export class GeminiModelDialect extends BaseModelDialect { + private inReasoningBlock = false; + + getName(): string { + return "GeminiModelDialect"; + } + + override prepareRequest(request: any, originalRequest: any): any { + // Inject output format instruction to suppress leaked reasoning text + if (request.messages && Array.isArray(request.messages)) { + const msg = `CRITICAL INSTRUCTION FOR OUTPUT FORMAT: +1. Keep ALL internal reasoning INTERNAL. Never output your thought process as visible text. +2. Do NOT start responses with phrases like "Wait, I'm...", "Let me think...", "Okay, so...", "First, I need to..." +3. Do NOT output numbered planning steps or internal debugging statements. +4. Only output: final responses, tool calls, and code. Nothing else. +5. When calling tools, proceed directly without announcing your intentions. +6. Your internal thinking should use the reasoning/thinking API, not visible text output.`; + if (request.messages.length > 0 && request.messages[0].role === "system") { + request.messages[0].content += "\n\n" + msg; + } else { + request.messages.unshift({ role: "system", content: msg }); + } + } + return request; + } + + processTextContent(textContent: string, _accumulatedText: string): AdapterResult { + if (!textContent || textContent.trim() === "") { + return { cleanedText: textContent, extractedToolCalls: [], wasTransformed: false }; + } + + const lines = textContent.split("\n"); + const cleanedLines: string[] = []; + let wasFiltered = false; + + for (const line of lines) { + const trimmed = line.trim(); + + if (!trimmed) { + cleanedLines.push(line); + continue; + } + + if (REASONING_PATTERNS.some((p) => p.test(trimmed))) { + log(`[GeminiModelDialect] Filtered reasoning: "${trimmed.substring(0, 50)}..."`); + wasFiltered = true; + this.inReasoningBlock = true; + continue; + } + + if (this.inReasoningBlock && REASONING_CONTINUATION_PATTERNS.some((p) => p.test(trimmed))) { + log( + `[GeminiModelDialect] Filtered reasoning continuation: "${trimmed.substring(0, 50)}..."` + ); + wasFiltered = true; + continue; + } + + if ( + this.inReasoningBlock && + !REASONING_CONTINUATION_PATTERNS.some((p) => p.test(trimmed)) + ) { + this.inReasoningBlock = false; + } + + cleanedLines.push(line); + } + + const cleanedText = cleanedLines.join("\n"); + return { + cleanedText: wasFiltered ? cleanedText : textContent, + extractedToolCalls: [], + wasTransformed: wasFiltered, + }; + } + + override reset(): void { + this.inReasoningBlock = false; + } + + override getContextWindow(): number { + return 1_048_576; // Gemini models have 1M context (2^20 tokens) + } + + override supportsVision(): boolean { + return true; // All Gemini models support vision + } +} diff --git a/packages/cli/src/adapters/glm-model-dialect.ts b/packages/cli/src/adapters/glm-model-dialect.ts index c76a9773..0c42f75e 100644 --- a/packages/cli/src/adapters/glm-model-dialect.ts +++ b/packages/cli/src/adapters/glm-model-dialect.ts @@ -7,11 +7,12 @@ * - Vision support detection (sourced from model-catalog.ts) */ -import { BaseAPIFormat, AdapterResult, matchesModelFamily } from "./base-api-format.js"; +import { BaseModelDialect } from "./base-model-dialect.js"; +import type { AdapterResult } from "./base-api-format.js"; import { log } from "../logger.js"; import { lookupModel } from "./model-catalog.js"; -export class GLMModelDialect extends BaseAPIFormat { +export class GLMModelDialect extends BaseModelDialect { processTextContent(textContent: string, accumulatedText: string): AdapterResult { return { cleanedText: textContent, @@ -30,14 +31,6 @@ export class GLMModelDialect extends BaseAPIFormat { return request; } - shouldHandle(modelId: string): boolean { - return ( - matchesModelFamily(modelId, "glm-") || - matchesModelFamily(modelId, "chatglm-") || - modelId.toLowerCase().includes("zhipu/") - ); - } - getName(): string { return "GLMModelDialect"; } diff --git a/packages/cli/src/adapters/grok-model-dialect.ts b/packages/cli/src/adapters/grok-model-dialect.ts index a168c9f6..9dfd5e2a 100644 --- a/packages/cli/src/adapters/grok-model-dialect.ts +++ b/packages/cli/src/adapters/grok-model-dialect.ts @@ -10,11 +10,12 @@ * This dialect translates that to Claude Code's expected tool_calls format. */ -import { BaseAPIFormat, AdapterResult, ToolCall, matchesModelFamily } from "./base-api-format.js"; +import { BaseModelDialect } from "./base-model-dialect.js"; +import type { AdapterResult, ToolCall } from "./base-api-format.js"; import { log } from "../logger.js"; import { lookupModel } from "./model-catalog.js"; -export class GrokModelDialect extends BaseAPIFormat { +export class GrokModelDialect extends BaseModelDialect { private xmlBuffer: string = ""; processTextContent(textContent: string, accumulatedText: string): AdapterResult { @@ -83,6 +84,16 @@ export class GrokModelDialect extends BaseAPIFormat { override prepareRequest(request: any, originalRequest: any): any { const modelId = this.modelId || ""; + // Inject tool format instruction to prevent XML function calls + if (request.messages && Array.isArray(request.messages)) { + const msg = "IMPORTANT: When calling tools, you MUST use the OpenAI tool_calls format with JSON. NEVER use XML format like ."; + if (request.messages.length > 0 && request.messages[0].role === "system") { + request.messages[0].content += "\n\n" + msg; + } else { + request.messages.unshift({ role: "system", content: msg }); + } + } + if (originalRequest.thinking) { // Only Grok 3 Mini supports reasoning_effort const supportsReasoningEffort = modelId.includes("mini"); @@ -131,10 +142,6 @@ export class GrokModelDialect extends BaseAPIFormat { return params; } - shouldHandle(modelId: string): boolean { - return matchesModelFamily(modelId, "grok") || modelId.toLowerCase().includes("x-ai/"); - } - getName(): string { return "GrokModelDialect"; } @@ -143,6 +150,10 @@ export class GrokModelDialect extends BaseAPIFormat { return lookupModel(this.modelId)?.contextWindow ?? 131_072; } + override supportsVision(): boolean { + return lookupModel(this.modelId)?.supportsVision ?? false; + } + /** * Reset internal state (useful between requests) */ diff --git a/packages/cli/src/adapters/index.ts b/packages/cli/src/adapters/index.ts index d184b635..b55fa3d3 100644 --- a/packages/cli/src/adapters/index.ts +++ b/packages/cli/src/adapters/index.ts @@ -1,16 +1,8 @@ /** - * Model format and dialect implementations + * Model format and dialect exports */ export { BaseAPIFormat, DefaultAPIFormat } from "./base-api-format.js"; +export { BaseModelDialect } from "./base-model-dialect.js"; export type { ToolCall, AdapterResult } from "./base-api-format.js"; export { GrokModelDialect } from "./grok-model-dialect.js"; -export { DialectManager } from "./dialect-manager.js"; - -// Backward-compatible aliases -export { - BaseAPIFormat as BaseModelAdapter, - DefaultAPIFormat as DefaultAdapter, -} from "./base-api-format.js"; -export { GrokModelDialect as GrokAdapter } from "./grok-model-dialect.js"; -export { DialectManager as AdapterManager } from "./dialect-manager.js"; diff --git a/packages/cli/src/adapters/litellm-api-format.ts b/packages/cli/src/adapters/litellm-api-format.ts index 91e70d5a..65709a72 100644 --- a/packages/cli/src/adapters/litellm-api-format.ts +++ b/packages/cli/src/adapters/litellm-api-format.ts @@ -36,10 +36,6 @@ export class LiteLLMAPIFormat extends DefaultAPIFormat { return "LiteLLMAPIFormat"; } - shouldHandle(modelId: string): boolean { - return false; // Always used explicitly, not via DialectManager matching - } - supportsVision(): boolean { return this.visionSupported; } diff --git a/packages/cli/src/adapters/local-adapter.ts b/packages/cli/src/adapters/local-adapter.ts index c45073bb..5ad9c040 100644 --- a/packages/cli/src/adapters/local-adapter.ts +++ b/packages/cli/src/adapters/local-adapter.ts @@ -12,10 +12,11 @@ */ import { BaseAPIFormat, type AdapterResult } from "./base-api-format.js"; -import { DialectManager } from "./dialect-manager.js"; +import type { BaseModelDialect } from "./base-model-dialect.js"; +import { resolveModelDialect } from "../providers/provider-profiles.js"; import { log } from "../logger.js"; -interface SamplingParams { +export interface SamplingParams { temperature: number; top_p: number; top_k: number; @@ -24,15 +25,14 @@ interface SamplingParams { } export class LocalModelAdapter extends BaseAPIFormat { - private innerAdapter: BaseAPIFormat; + private innerAdapter: BaseAPIFormat | BaseModelDialect; private providerName: string; constructor(modelId: string, providerName: string) { super(modelId); this.providerName = providerName; - const manager = new DialectManager(modelId); - this.innerAdapter = manager.getAdapter(); + this.innerAdapter = resolveModelDialect(modelId); } // ─── Text processing delegates to inner adapter ─────────────────── @@ -41,10 +41,6 @@ export class LocalModelAdapter extends BaseAPIFormat { return this.innerAdapter.processTextContent(textContent, accumulatedText); } - shouldHandle(modelId: string): boolean { - return true; // Always used explicitly - } - getName(): string { return `LocalModelAdapter(${this.innerAdapter.getName()})`; } @@ -75,14 +71,6 @@ export class LocalModelAdapter extends BaseAPIFormat { messages[0].content += this.buildSystemGuidance(claudeRequest.tools?.length || 0); } - // Qwen /no_think toggle - if (this.modelId.toLowerCase().includes("qwen") && process.env.CLAUDISH_QWEN_NO_THINK === "1") { - if (messages.length > 0 && messages[0].role === "system") { - messages[0].content = "/no_think\n\n" + messages[0].content; - log(`[${this.getName()}] Added /no_think to disable Qwen thinking mode`); - } - } - return messages; } @@ -164,29 +152,20 @@ export class LocalModelAdapter extends BaseAPIFormat { // ─── Model-family sampling parameters ─────────────────────────────── - private getSamplingParams(): SamplingParams { - const id = this.modelId.toLowerCase(); - - if (id.includes("qwen")) { - // Qwen3 Instruct recommended settings - return { temperature: 0.7, top_p: 0.8, top_k: 20, min_p: 0.0, repetition_penalty: 1.05 }; - } - if (id.includes("deepseek")) { - return { temperature: 0.6, top_p: 0.95, top_k: 40, min_p: 0.0, repetition_penalty: 1.0 }; - } - if (id.includes("llama")) { - return { temperature: 0.7, top_p: 0.9, top_k: 40, min_p: 0.05, repetition_penalty: 1.1 }; - } - if (id.includes("mistral")) { - return { temperature: 0.7, top_p: 0.9, top_k: 50, min_p: 0.0, repetition_penalty: 1.0 }; - } - // Generic defaults + protected getSamplingParams(): SamplingParams { + // Generic defaults; model-family subclasses override with tuned values return { temperature: 0.7, top_p: 0.9, top_k: 40, min_p: 0.0, repetition_penalty: 1.0 }; } // ─── System prompt guidance ───────────────────────────────────────── - private buildSystemGuidance(toolCount: number): string { + protected getToolGuidanceHeader(): string { + return ` + +4. TOOL CALLING REQUIREMENTS:`; + } + + protected buildSystemGuidance(toolCount: number): string { let guidance = ` IMPORTANT INSTRUCTIONS FOR THIS MODEL: @@ -211,25 +190,7 @@ IMPORTANT INSTRUCTIONS FOR THIS MODEL: - If you called a Glob/Search and got files, READ important files next, then ANALYZE, then SUGGEST improvements.`; if (toolCount > 0) { - const isQwen = this.modelId.toLowerCase().includes("qwen"); - - if (isQwen) { - guidance += ` - -4. TOOL CALLING FORMAT (CRITICAL FOR QWEN): -You MUST use proper OpenAI-style function calling. Do NOT output tool calls as XML text. -When you want to call a tool, use the API's tool_calls mechanism, NOT text like . -The tool calls must be structured JSON in the API response, not XML in your text output. - -If you cannot use structured tool_calls, format as JSON: -{"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}} - -5. TOOL PARAMETER REQUIREMENTS:`; - } else { - guidance += ` - -4. TOOL CALLING REQUIREMENTS:`; - } + guidance += this.getToolGuidanceHeader(); guidance += ` - When calling tools, you MUST include ALL required parameters. Incomplete tool calls will fail. diff --git a/packages/cli/src/adapters/local-deepseek-format-adapter.ts b/packages/cli/src/adapters/local-deepseek-format-adapter.ts new file mode 100644 index 00000000..744742fc --- /dev/null +++ b/packages/cli/src/adapters/local-deepseek-format-adapter.ts @@ -0,0 +1,13 @@ +/** + * LocalDeepSeekFormatAdapter — DeepSeek-family local model adapter. + * + * Overrides sampling parameters for DeepSeek models. + */ + +import { LocalModelAdapter, type SamplingParams } from "./local-adapter.js"; + +export class LocalDeepSeekFormatAdapter extends LocalModelAdapter { + protected override getSamplingParams(): SamplingParams { + return { temperature: 0.6, top_p: 0.95, top_k: 40, min_p: 0.0, repetition_penalty: 1.0 }; + } +} diff --git a/packages/cli/src/adapters/local-llama-format-adapter.ts b/packages/cli/src/adapters/local-llama-format-adapter.ts new file mode 100644 index 00000000..c22eb341 --- /dev/null +++ b/packages/cli/src/adapters/local-llama-format-adapter.ts @@ -0,0 +1,13 @@ +/** + * LocalLlamaFormatAdapter — Llama-family local model adapter. + * + * Overrides sampling parameters for Llama models. + */ + +import { LocalModelAdapter, type SamplingParams } from "./local-adapter.js"; + +export class LocalLlamaFormatAdapter extends LocalModelAdapter { + protected override getSamplingParams(): SamplingParams { + return { temperature: 0.7, top_p: 0.9, top_k: 40, min_p: 0.05, repetition_penalty: 1.1 }; + } +} diff --git a/packages/cli/src/adapters/local-mistral-format-adapter.ts b/packages/cli/src/adapters/local-mistral-format-adapter.ts new file mode 100644 index 00000000..0274dc9a --- /dev/null +++ b/packages/cli/src/adapters/local-mistral-format-adapter.ts @@ -0,0 +1,13 @@ +/** + * LocalMistralFormatAdapter — Mistral-family local model adapter. + * + * Overrides sampling parameters for Mistral models. + */ + +import { LocalModelAdapter, type SamplingParams } from "./local-adapter.js"; + +export class LocalMistralFormatAdapter extends LocalModelAdapter { + protected override getSamplingParams(): SamplingParams { + return { temperature: 0.7, top_p: 0.9, top_k: 50, min_p: 0.0, repetition_penalty: 1.0 }; + } +} diff --git a/packages/cli/src/adapters/local-qwen-format-adapter.ts b/packages/cli/src/adapters/local-qwen-format-adapter.ts new file mode 100644 index 00000000..04d48f30 --- /dev/null +++ b/packages/cli/src/adapters/local-qwen-format-adapter.ts @@ -0,0 +1,44 @@ +/** + * LocalQwenFormatAdapter — Qwen-family local model adapter. + * + * Overrides sampling parameters (Qwen3 Instruct recommended settings) + * and adds /no_think toggle support for disabling Qwen thinking mode. + */ + +import { LocalModelAdapter, type SamplingParams } from "./local-adapter.js"; +import { log } from "../logger.js"; + +export class LocalQwenFormatAdapter extends LocalModelAdapter { + protected override getSamplingParams(): SamplingParams { + // Qwen3 Instruct recommended settings + return { temperature: 0.7, top_p: 0.8, top_k: 20, min_p: 0.0, repetition_penalty: 1.05 }; + } + + protected override getToolGuidanceHeader(): string { + return ` + +4. TOOL CALLING FORMAT (CRITICAL FOR QWEN): +You MUST use proper OpenAI-style function calling. Do NOT output tool calls as XML text. +When you want to call a tool, use the API's tool_calls mechanism, NOT text like . +The tool calls must be structured JSON in the API response, not XML in your text output. + +If you cannot use structured tool_calls, format as JSON: +{"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}} + +5. TOOL PARAMETER REQUIREMENTS:`; + } + + override convertMessages(claudeRequest: any, filterIdentityFn?: (s: string) => string): any[] { + const messages = super.convertMessages(claudeRequest, filterIdentityFn); + + // Qwen /no_think toggle + if (process.env.CLAUDISH_QWEN_NO_THINK === "1") { + if (messages.length > 0 && messages[0].role === "system") { + messages[0].content = "/no_think\n\n" + messages[0].content; + log(`[${this.getName()}] Added /no_think to disable Qwen thinking mode`); + } + } + + return messages; + } +} diff --git a/packages/cli/src/adapters/minimax-model-dialect.ts b/packages/cli/src/adapters/minimax-model-dialect.ts index 25b73c14..142625ff 100644 --- a/packages/cli/src/adapters/minimax-model-dialect.ts +++ b/packages/cli/src/adapters/minimax-model-dialect.ts @@ -8,11 +8,12 @@ * - Vision: not supported — supportsVision() returns false so ComposedHandler strips images */ -import { BaseAPIFormat, AdapterResult, matchesModelFamily } from "./base-api-format.js"; +import { BaseModelDialect } from "./base-model-dialect.js"; +import type { AdapterResult } from "./base-api-format.js"; import { log } from "../logger.js"; import { lookupModel } from "./model-catalog.js"; -export class MiniMaxModelDialect extends BaseAPIFormat { +export class MiniMaxModelDialect extends BaseModelDialect { processTextContent(textContent: string, accumulatedText: string): AdapterResult { // MiniMax interleaved thinking is handled by the model return { @@ -66,10 +67,6 @@ export class MiniMaxModelDialect extends BaseAPIFormat { return lookupModel(this.modelId)?.supportsVision ?? false; } - shouldHandle(modelId: string): boolean { - return matchesModelFamily(modelId, "minimax"); - } - getName(): string { return "MiniMaxModelDialect"; } diff --git a/packages/cli/src/adapters/model-dialect.ts b/packages/cli/src/adapters/model-dialect.ts index 768c8571..c8d893dc 100644 --- a/packages/cli/src/adapters/model-dialect.ts +++ b/packages/cli/src/adapters/model-dialect.ts @@ -24,9 +24,6 @@ export interface ModelDialect { /** Maximum tool name length, or null if unlimited */ getToolNameLimit(): number | null; - /** Check if this dialect handles the given model ID */ - shouldHandle(modelId: string): boolean; - /** Dialect name for logging */ getName(): string; } diff --git a/packages/cli/src/adapters/ollama-api-format.ts b/packages/cli/src/adapters/ollama-api-format.ts index a0957b35..50cf9a2d 100644 --- a/packages/cli/src/adapters/ollama-api-format.ts +++ b/packages/cli/src/adapters/ollama-api-format.ts @@ -24,10 +24,6 @@ export class OllamaAPIFormat extends BaseAPIFormat { }; } - shouldHandle(_modelId: string): boolean { - return false; // Not auto-selected; always explicitly passed - } - getName(): string { return "OllamaAPIFormat"; } diff --git a/packages/cli/src/adapters/openai-api-format.ts b/packages/cli/src/adapters/openai-api-format.ts index 86b79b00..20088630 100644 --- a/packages/cli/src/adapters/openai-api-format.ts +++ b/packages/cli/src/adapters/openai-api-format.ts @@ -59,10 +59,6 @@ export class OpenAIAPIFormat extends BaseAPIFormat { return request; } - shouldHandle(modelId: string): boolean { - return modelId.startsWith("oai/") || modelId.includes("o1") || modelId.includes("o3"); - } - getName(): string { return "OpenAIAPIFormat"; } diff --git a/packages/cli/src/adapters/openrouter-api-format.ts b/packages/cli/src/adapters/openrouter-api-format.ts index c78e9db4..2993ae9e 100644 --- a/packages/cli/src/adapters/openrouter-api-format.ts +++ b/packages/cli/src/adapters/openrouter-api-format.ts @@ -11,7 +11,7 @@ */ import { BaseAPIFormat, type AdapterResult } from "./base-api-format.js"; -import { DialectManager } from "./dialect-manager.js"; +import { resolveModelDialect } from "../providers/provider-profiles.js"; import { removeUriFormat } from "../transform.js"; import { log } from "../logger.js"; @@ -22,8 +22,7 @@ export class OpenRouterAPIFormat extends BaseAPIFormat { super(modelId); // Get model-specific dialect (GrokModelDialect, GeminiAPIFormat, etc.) - const manager = new DialectManager(modelId); - this.innerAdapter = manager.getAdapter(); + this.innerAdapter = resolveModelDialect(modelId); } /** Synchronous reasoning support check via model ID patterns */ @@ -44,10 +43,6 @@ export class OpenRouterAPIFormat extends BaseAPIFormat { return this.innerAdapter.processTextContent(textContent, accumulatedText); } - shouldHandle(modelId: string): boolean { - return true; // Always used explicitly - } - getName(): string { return `OpenRouterAPIFormat(${this.innerAdapter.getName()})`; } @@ -57,40 +52,8 @@ export class OpenRouterAPIFormat extends BaseAPIFormat { this.innerAdapter.reset(); } - // ─── Message conversion with model-specific system prompts ───────── - - override convertMessages(claudeRequest: any, filterIdentityFn?: (s: string) => string): any[] { - // Use default OpenAI conversion - const messages = super.convertMessages(claudeRequest, filterIdentityFn); - - // Add model-specific system prompt tweaks - if (this.modelId.includes("grok") || this.modelId.includes("x-ai")) { - const msg = - "IMPORTANT: When calling tools, you MUST use the OpenAI tool_calls format with JSON. NEVER use XML format like ."; - this.appendToSystemPrompt(messages, msg); - } - - if (this.modelId.includes("gemini") || this.modelId.includes("google/")) { - const geminiMsg = `CRITICAL INSTRUCTION FOR OUTPUT FORMAT: -1. Keep ALL internal reasoning INTERNAL. Never output your thought process as visible text. -2. Do NOT start responses with phrases like "Wait, I'm...", "Let me think...", "Okay, so...", "First, I need to..." -3. Do NOT output numbered planning steps or internal debugging statements. -4. Only output: final responses, tool calls, and code. Nothing else. -5. When calling tools, proceed directly without announcing your intentions. -6. Your internal thinking should use the reasoning/thinking API, not visible text output.`; - this.appendToSystemPrompt(messages, geminiMsg); - } - - return messages; - } - - private appendToSystemPrompt(messages: any[], text: string): void { - if (messages.length > 0 && messages[0].role === "system") { - messages[0].content += "\n\n" + text; - } else { - messages.unshift({ role: "system", content: text }); - } - } + // Model-specific system prompts are handled by the inner dialect's prepareRequest() + // (GrokModelDialect injects tool format instruction, GeminiModelDialect injects output format) // ─── Tool conversion with uri format removal ────────────────────── diff --git a/packages/cli/src/adapters/qwen-model-dialect.ts b/packages/cli/src/adapters/qwen-model-dialect.ts index c314b9f7..32afb4ce 100644 --- a/packages/cli/src/adapters/qwen-model-dialect.ts +++ b/packages/cli/src/adapters/qwen-model-dialect.ts @@ -6,7 +6,8 @@ * - Maps thinking → enable_thinking + thinking_budget params */ -import { BaseAPIFormat, AdapterResult, matchesModelFamily } from "./base-api-format.js"; +import { BaseModelDialect } from "./base-model-dialect.js"; +import type { AdapterResult } from "./base-api-format.js"; import { log } from "../logger.js"; // Qwen special tokens that should be stripped from output @@ -18,7 +19,7 @@ const QWEN_SPECIAL_TOKENS = [ "assistant\n", // Role marker that sometimes leaks ]; -export class QwenModelDialect extends BaseAPIFormat { +export class QwenModelDialect extends BaseModelDialect { processTextContent(textContent: string, accumulatedText: string): AdapterResult { // Strip Qwen special tokens that may leak through // This can happen when the model gets confused and outputs its chat template @@ -71,10 +72,6 @@ export class QwenModelDialect extends BaseAPIFormat { return request; } - shouldHandle(modelId: string): boolean { - return matchesModelFamily(modelId, "qwen") || matchesModelFamily(modelId, "alibaba"); - } - getName(): string { return "QwenModelDialect"; } diff --git a/packages/cli/src/adapters/xiaomi-model-dialect.ts b/packages/cli/src/adapters/xiaomi-model-dialect.ts index 6c23457c..4fe84758 100644 --- a/packages/cli/src/adapters/xiaomi-model-dialect.ts +++ b/packages/cli/src/adapters/xiaomi-model-dialect.ts @@ -7,11 +7,12 @@ * - Context window comes dynamically from OpenRouter model catalog */ -import { BaseAPIFormat, AdapterResult, matchesModelFamily } from "./base-api-format.js"; +import { BaseModelDialect } from "./base-model-dialect.js"; +import type { AdapterResult } from "./base-api-format.js"; import { log } from "../logger.js"; import { lookupModel } from "./model-catalog.js"; -export class XiaomiModelDialect extends BaseAPIFormat { +export class XiaomiModelDialect extends BaseModelDialect { processTextContent(textContent: string, accumulatedText: string): AdapterResult { return { cleanedText: textContent, @@ -40,10 +41,6 @@ export class XiaomiModelDialect extends BaseAPIFormat { return request; } - shouldHandle(modelId: string): boolean { - return matchesModelFamily(modelId, "xiaomi") || matchesModelFamily(modelId, "mimo"); - } - getName(): string { return "XiaomiModelDialect"; } diff --git a/packages/cli/src/claude-runner.ts b/packages/cli/src/claude-runner.ts index cc524860..0f338ac1 100644 --- a/packages/cli/src/claude-runner.ts +++ b/packages/cli/src/claude-runner.ts @@ -5,27 +5,10 @@ import { tmpdir, homedir } from "node:os"; import { join, basename } from "node:path"; import { ENV } from "./config.js"; import type { ClaudishConfig } from "./types.js"; -import { parseModelSpec } from "./providers/model-parser.js"; import type { MtmDiagRunner } from "./pty-diag-runner.js"; // Backward-compat alias type PtyDiagRunner = MtmDiagRunner; -/** - * Check if any resolved model mapping targets a native Anthropic model (claude-*). - * When true, placeholder auth tokens must NOT be set — Claude Code needs its real - * subscription credentials so NativeHandler can forward them to api.anthropic.com. - */ -function hasNativeAnthropicMapping(config: ClaudishConfig): boolean { - const models = [ - config.model, - config.modelOpus, - config.modelSonnet, - config.modelHaiku, - config.modelSubagent, - ]; - return models.some((m) => m && parseModelSpec(m).provider === "native-anthropic"); -} - // Use process.platform directly to ensure runtime evaluation // (module-level constants can be inlined by bundlers at build time) function isWindows(): boolean { @@ -251,7 +234,6 @@ function mergeUserSettingsIfPresent( export async function runClaudeWithProxy( config: ClaudishConfig, proxyUrl: string, - onCleanup?: () => void, ptyDiagRunner?: PtyDiagRunner | null ): Promise { // Use actual OpenRouter model ID (no translation) @@ -362,21 +344,16 @@ export async function runClaudeWithProxy( env[ENV.ANTHROPIC_MODEL] = modelId; env[ENV.ANTHROPIC_SMALL_FAST_MODEL] = modelId; } - if (hasNativeAnthropicMapping(config)) { - // Native Claude model detected — let Claude Code use its real subscription - // credentials. Don't set placeholders, but preserve any real keys the user has. - } else { - // Pure alternative mode: all models go through proxy providers - // Use placeholder to prevent Claude Code login dialog - env.ANTHROPIC_API_KEY = - process.env.ANTHROPIC_API_KEY || - "sk-ant-api03-placeholder-not-used-proxy-handles-auth-with-openrouter-key-xxxxxxxxxxxxxxxxxxxxx"; - - // Also set ANTHROPIC_AUTH_TOKEN to bypass login screen - // Claude Code checks both API_KEY and AUTH_TOKEN for authentication - env.ANTHROPIC_AUTH_TOKEN = - process.env.ANTHROPIC_AUTH_TOKEN || "placeholder-token-not-used-proxy-handles-auth"; - } + // Set placeholder credentials to prevent Claude Code login dialog. + // All models go through the proxy; real auth is handled per-provider there. + env.ANTHROPIC_API_KEY = + process.env.ANTHROPIC_API_KEY || + "sk-ant-api03-placeholder-not-used-proxy-handles-auth-with-openrouter-key-xxxxxxxxxxxxxxxxxxxxx"; + + // Also set ANTHROPIC_AUTH_TOKEN to bypass login screen + // Claude Code checks both API_KEY and AUTH_TOKEN for authentication + env.ANTHROPIC_AUTH_TOKEN = + process.env.ANTHROPIC_AUTH_TOKEN || "placeholder-token-not-used-proxy-handles-auth"; } // Helper function to log messages (respects quiet flag) @@ -386,10 +363,6 @@ export async function runClaudeWithProxy( } }; - if (!config.monitor && hasNativeAnthropicMapping(config)) { - log("[claudish] Native Claude model detected — using Claude Code subscription credentials"); - } - if (config.interactive) { log(`\n[claudish] Model: ${modelDisplayName}\n`); } else { @@ -441,7 +414,7 @@ export async function runClaudeWithProxy( }); // Handle process termination signals (includes cleanup) - setupSignalHandlers(proc, tempSettingsPath, config.quiet, onCleanup); + setupSignalHandlers(proc, tempSettingsPath, config.quiet); // Wait for claude to exit exitCode = await new Promise((resolve) => { @@ -468,7 +441,6 @@ function setupSignalHandlers( proc: ChildProcess, tempSettingsPath: string, quiet: boolean, - onCleanup?: () => void ): void { // Windows only supports SIGINT and SIGTERM reliably // SIGHUP doesn't exist on Windows @@ -482,14 +454,6 @@ function setupSignalHandlers( console.log(`\n[claudish] Received ${signal}, shutting down...`); } proc.kill(); - // Run optional cleanup (e.g. close diag tmux pane) before exit - if (onCleanup) { - try { - onCleanup(); - } catch { - // Ignore cleanup errors - } - } // Clean up temp settings file try { unlinkSync(tempSettingsPath); diff --git a/packages/cli/src/cli.ts b/packages/cli/src/cli.ts index bc86abc0..b39490f9 100644 --- a/packages/cli/src/cli.ts +++ b/packages/cli/src/cli.ts @@ -1395,9 +1395,9 @@ async function probeModelRouting(models: string[], jsonOutput: boolean): Promise const providerName = firstReadyRoute.provider; // Resolve model name from the model spec (strip provider prefix if present) - const { resolveRemoteProvider } = await import("./providers/remote-provider-registry.js"); - const resolvedSpec = resolveRemoteProvider(firstReadyRoute.modelSpec); - const modelName = resolvedSpec?.modelName || parsed.model; + const { parseModelSpec: parseSpec } = await import("./providers/model-parser.js"); + const resolvedSpec = parseSpec(firstReadyRoute.modelSpec); + const modelName = resolvedSpec.model || parsed.model; // Determine format adapter from provider name (mirrors provider-profiles.ts) let formatAdapterName = "OpenAIAPIFormat"; @@ -1430,10 +1430,9 @@ async function probeModelRouting(models: string[], jsonOutput: boolean): Promise declaredStreamFormat = "openai-sse"; } - // Get model dialect via DialectManager - const { DialectManager } = await import("./adapters/dialect-manager.js"); - const adapterManager = new DialectManager(modelName); - const modelTranslator = adapterManager.getAdapter(); + // Get model dialect via resolveModelDialect + const { resolveModelDialect } = await import("./providers/provider-profiles.js"); + const modelTranslator = resolveModelDialect(modelName); const modelTranslatorName = modelTranslator.getName(); // Transport overrides (aggregators that normalize responses to openai-sse) diff --git a/packages/cli/src/e2e-glm-adapter.test.ts b/packages/cli/src/e2e-glm-adapter.test.ts index 657930f0..25282be6 100644 --- a/packages/cli/src/e2e-glm-adapter.test.ts +++ b/packages/cli/src/e2e-glm-adapter.test.ts @@ -3,7 +3,7 @@ * * Validates: * 1. GLMModelDialect model detection, context windows, and vision support - * 2. DialectManager correctly selects GLMModelDialect for GLM models + * 2. resolveModelDialect correctly selects GLMModelDialect for GLM models * 3. ComposedHandler three-layer architecture — model dialect provides model-specific * overrides (context window, vision, prepareRequest) even when a provider format * (LiteLLMAPIFormat, OpenRouterAPIFormat) is set as the explicit adapter @@ -11,46 +11,15 @@ import { describe, test, expect } from "bun:test"; import { GLMModelDialect } from "./adapters/glm-model-dialect.js"; -import { DialectManager } from "./adapters/dialect-manager.js"; +import { resolveModelDialect } from "./providers/provider-profiles.js"; import { LiteLLMAPIFormat } from "./adapters/litellm-api-format.js"; import { DefaultAPIFormat } from "./adapters/base-api-format.js"; // ─── Group 1: GLMModelDialect unit tests ───────────────────────────────────── -describe("GLMModelDialect — Model Detection", () => { +describe("GLMModelDialect — Basics", () => { const adapter = new GLMModelDialect("glm-5"); - test("should handle glm-5", () => { - expect(adapter.shouldHandle("glm-5")).toBe(true); - }); - - test("should handle glm-4-plus", () => { - expect(adapter.shouldHandle("glm-4-plus")).toBe(true); - }); - - test("should handle glm-4-flash", () => { - expect(adapter.shouldHandle("glm-4-flash")).toBe(true); - }); - - test("should handle glm-4-long", () => { - expect(adapter.shouldHandle("glm-4-long")).toBe(true); - }); - - test("should handle glm-3-turbo", () => { - expect(adapter.shouldHandle("glm-3-turbo")).toBe(true); - }); - - test("should handle zhipu/ prefixed models", () => { - expect(adapter.shouldHandle("zhipu/glm-5")).toBe(true); - }); - - test("should NOT handle non-GLM models", () => { - expect(adapter.shouldHandle("gpt-4o")).toBe(false); - expect(adapter.shouldHandle("gemini-2.0-flash")).toBe(false); - expect(adapter.shouldHandle("deepseek-r1")).toBe(false); - expect(adapter.shouldHandle("grok-3")).toBe(false); - }); - test("should return correct adapter name", () => { expect(adapter.getName()).toBe("GLMModelDialect"); }); @@ -134,50 +103,46 @@ describe("GLMModelDialect — processTextContent", () => { }); }); -// ─── Group 2: DialectManager selects GLMModelDialect ───────────────────────── +// ─── Group 2: resolveModelDialect selects GLMModelDialect ───────────────────── -describe("DialectManager — GLM routing", () => { +describe("resolveModelDialect — GLM routing", () => { test("selects GLMModelDialect for glm-5", () => { - const manager = new DialectManager("glm-5"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("glm-5"); expect(adapter.getName()).toBe("GLMModelDialect"); }); test("selects GLMModelDialect for glm-4-long", () => { - const manager = new DialectManager("glm-4-long"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("glm-4-long"); expect(adapter.getName()).toBe("GLMModelDialect"); }); test("does NOT select GLMModelDialect for gpt-4o", () => { - const manager = new DialectManager("gpt-4o"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("gpt-4o"); - expect(adapter.getName()).not.toBe("GLMModelDialect"); + expect(adapter).toBeNull(); }); test("needsTransformation returns true for GLM models", () => { - const manager = new DialectManager("glm-5"); - expect(manager.needsTransformation()).toBe(true); + const adapter = resolveModelDialect("glm-5"); + expect(adapter).not.toBeNull(); }); }); // ─── Group 3: Three-layer adapter architecture ─────────────────────────────── // // When a format adapter (LiteLLMAPIFormat) is the explicit adapter, the model -// dialect (GLMModelDialect) should still be resolved by DialectManager for +// dialect (GLMModelDialect) should still be resolved by resolveModelDialect for // model-specific concerns. describe("Three-layer adapter — model dialect overrides format adapter", () => { - test("DialectManager resolves GLMModelDialect even when LiteLLMAPIFormat would be used", () => { + test("resolveModelDialect resolves GLMModelDialect even when LiteLLMAPIFormat would be used", () => { // Simulate what ComposedHandler does: // 1. Explicit adapter = LiteLLMAPIFormat (L1 wire format) - // 2. DialectManager.getAdapter() = GLMModelDialect (L2 model quirks) + // 2. resolveModelDialect() = GLMModelDialect (L2 model quirks) const litellmAdapter = new LiteLLMAPIFormat("glm-5", "https://example.com"); - const adapterManager = new DialectManager("glm-5"); - const modelAdapter = adapterManager.getAdapter(); + const modelAdapter = resolveModelDialect("glm-5"); // Format adapter handles wire format / transport expect(litellmAdapter.getName()).toBe("LiteLLMAPIFormat"); @@ -196,33 +161,29 @@ describe("Three-layer adapter — model dialect overrides format adapter", () => }); test("model dialect provides correct context window for glm-4-long via LiteLLM", () => { - const adapterManager = new DialectManager("glm-4-long"); - const modelAdapter = adapterManager.getAdapter(); + const modelAdapter = resolveModelDialect("glm-4-long"); expect(modelAdapter.getName()).toBe("GLMModelDialect"); expect(modelAdapter.getContextWindow()).toBe(1_000_000); }); test("model dialect correctly reports no vision for glm-4-flash via LiteLLM", () => { - const adapterManager = new DialectManager("glm-4-flash"); - const modelAdapter = adapterManager.getAdapter(); + const modelAdapter = resolveModelDialect("glm-4-flash"); expect(modelAdapter.getName()).toBe("GLMModelDialect"); expect(modelAdapter.supportsVision()).toBe(false); }); test("non-GLM model via LiteLLM falls back to DefaultAPIFormat", () => { - const adapterManager = new DialectManager("some-unknown-model"); - const modelAdapter = adapterManager.getAdapter(); + const modelAdapter = resolveModelDialect("some-unknown-model"); // Should be DefaultAPIFormat, not GLMModelDialect - expect(modelAdapter.getName()).toBe("DefaultAPIFormat"); + expect(modelAdapter).toBeNull(); }); test("model dialect strips thinking, format adapter does not", () => { const litellmAdapter = new LiteLLMAPIFormat("glm-5", "https://example.com"); - const adapterManager = new DialectManager("glm-5"); - const modelAdapter = adapterManager.getAdapter(); + const modelAdapter = resolveModelDialect("glm-5"); // Format adapter does not strip thinking (no override) const request1 = { model: "glm-5", thinking: { budget: 10000 }, messages: [] }; diff --git a/packages/cli/src/e2e-model-catalog.test.ts b/packages/cli/src/e2e-model-catalog.test.ts index 10399b4c..c5f5cfaf 100644 --- a/packages/cli/src/e2e-model-catalog.test.ts +++ b/packages/cli/src/e2e-model-catalog.test.ts @@ -15,7 +15,7 @@ import { lookupModel } from "./adapters/model-catalog.js"; import { MiniMaxModelDialect } from "./adapters/minimax-model-dialect.js"; import { GLMModelDialect } from "./adapters/glm-model-dialect.js"; import { GrokModelDialect } from "./adapters/grok-model-dialect.js"; -import { DialectManager } from "./adapters/dialect-manager.js"; +import { resolveModelDialect } from "./providers/provider-profiles.js"; import { AnthropicAPIFormat } from "./adapters/anthropic-api-format.js"; const MINIMAX_API_KEY = @@ -181,40 +181,34 @@ describe("Group 2: GrokModelDialect — catalog integration", () => { }); }); -describe("Group 2: DialectManager — correct dialect selection", () => { +describe("Group 2: resolveModelDialect — correct dialect selection", () => { test("selects MiniMaxModelDialect for MiniMax-M2.7", () => { - const manager = new DialectManager("MiniMax-M2.7"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("MiniMax-M2.7"); expect(adapter.getName()).toBe("MiniMaxModelDialect"); }); test("selects GLMModelDialect for glm-5", () => { - const manager = new DialectManager("glm-5"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("glm-5"); expect(adapter.getName()).toBe("GLMModelDialect"); }); test("selects GrokModelDialect for grok-4", () => { - const manager = new DialectManager("grok-4"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("grok-4"); expect(adapter.getName()).toBe("GrokModelDialect"); }); test("selects GrokModelDialect for x-ai/grok-4-fast", () => { - const manager = new DialectManager("x-ai/grok-4-fast"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("x-ai/grok-4-fast"); expect(adapter.getName()).toBe("GrokModelDialect"); }); test("selects MiniMaxModelDialect for minimax-m2.5", () => { - const manager = new DialectManager("minimax-m2.5"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("minimax-m2.5"); expect(adapter.getName()).toBe("MiniMaxModelDialect"); }); test("returns DefaultAPIFormat for unknown model", () => { - const manager = new DialectManager("totally-unknown-model-xyz"); - const adapter = manager.getAdapter(); + const adapter = resolveModelDialect("totally-unknown-model-xyz"); expect(adapter.getName()).toBe("DefaultAPIFormat"); }); }); diff --git a/packages/cli/src/format-translation.test.ts b/packages/cli/src/format-translation.test.ts index bdc4e858..05def4ef 100644 --- a/packages/cli/src/format-translation.test.ts +++ b/packages/cli/src/format-translation.test.ts @@ -582,15 +582,15 @@ describe("Model Adapter Quirks", () => { expect(request.thinking).toBeUndefined(); }); - test("AdapterManager selects correct adapter for model IDs", async () => { - const { DialectManager } = await import("./adapters/dialect-manager.js"); + test("resolveModelDialect selects correct adapter for model IDs", async () => { + const { resolveModelDialect } = await import("./providers/provider-profiles.js"); - expect(new DialectManager("glm-5").getAdapter().getName()).toBe("GLMModelDialect"); - expect(new DialectManager("grok-3").getAdapter().getName()).toBe("GrokModelDialect"); - expect(new DialectManager("minimax-m2.5").getAdapter().getName()).toBe("MiniMaxModelDialect"); - expect(new DialectManager("qwen3.5-plus").getAdapter().getName()).toBe("QwenModelDialect"); - expect(new DialectManager("deepseek-r1").getAdapter().getName()).toBe("DeepSeekModelDialect"); - expect(new DialectManager("unknown-model").getAdapter().getName()).toBe("DefaultAPIFormat"); + expect(resolveModelDialect("glm-5").getName()).toBe("GLMModelDialect"); + expect(resolveModelDialect("grok-3").getName()).toBe("GrokModelDialect"); + expect(resolveModelDialect("minimax-m2.5").getName()).toBe("MiniMaxModelDialect"); + expect(resolveModelDialect("qwen3.5-plus").getName()).toBe("QwenModelDialect"); + expect(resolveModelDialect("deepseek-r1").getName()).toBe("DeepSeekModelDialect"); + expect(resolveModelDialect("unknown-model")).toBeNull(); }); }); @@ -627,25 +627,13 @@ describe("APIFormat: getStreamFormat()", () => { expect(new CodexAPIFormat("codex-mini").getStreamFormat()).toBe("openai-responses-sse"); }); - test("GLMModelDialect inherits openai-sse (uses OpenAI-compat API)", async () => { + test("GLMModelDialect returns undefined (dialect has no stream format opinion)", async () => { const { GLMModelDialect } = await import("./adapters/glm-model-dialect.js"); - expect(new GLMModelDialect("glm-5").getStreamFormat()).toBe("openai-sse"); + expect(new GLMModelDialect("glm-5").getStreamFormat()).toBeUndefined(); }); }); describe("CodexAdapter", () => { - test("shouldHandle returns true for codex models", async () => { - const { CodexAPIFormat } = await import("./adapters/codex-api-format.js"); - expect(new CodexAPIFormat("codex-mini").shouldHandle("codex-mini")).toBe(true); - expect(new CodexAPIFormat("codex-mini").shouldHandle("codex-davinci-002")).toBe(true); - }); - - test("shouldHandle returns false for non-codex models", async () => { - const { CodexAPIFormat } = await import("./adapters/codex-api-format.js"); - expect(new CodexAPIFormat("gpt-5.4").shouldHandle("gpt-5.4")).toBe(false); - expect(new CodexAPIFormat("o3").shouldHandle("o3")).toBe(false); - }); - test("getStreamFormat returns openai-responses-sse", async () => { const { CodexAPIFormat } = await import("./adapters/codex-api-format.js"); expect(new CodexAPIFormat("codex-mini").getStreamFormat()).toBe("openai-responses-sse"); @@ -656,9 +644,9 @@ describe("CodexAdapter", () => { expect(new CodexAPIFormat("codex-mini").getName()).toBe("CodexAPIFormat"); }); - test("AdapterManager selects CodexAPIFormat for codex-mini", async () => { - const { DialectManager } = await import("./adapters/dialect-manager.js"); - expect(new DialectManager("codex-mini").getAdapter().getName()).toBe("CodexAPIFormat"); + test("codex-mini has no dialect (wire format only)", async () => { + const { resolveModelDialect } = await import("./providers/provider-profiles.js"); + expect(resolveModelDialect("codex-mini")).toBeNull(); }); }); @@ -669,45 +657,33 @@ describe("ModelDialect interface compliance", () => { expect(typeof t.getContextWindow()).toBe("number"); expect(typeof t.supportsVision()).toBe("boolean"); expect(typeof t.prepareRequest).toBe("function"); - expect(typeof t.shouldHandle).toBe("function"); expect(typeof t.getName).toBe("function"); }); }); // ─── ProviderProfile Table Tests ───────────────────────────────────────────── -describe("ProviderProfile table completeness", () => { - test("all expected providers are registered", async () => { - const { PROVIDER_PROFILES } = await import("./providers/provider-profiles.js"); +describe("Provider coverage completeness", () => { + test("all expected providers create handlers via createHandlerForProvider", async () => { + const { createHandlerForProvider } = await import("./providers/provider-profiles.js"); + const { getProviderByName } = await import("./providers/provider-definitions.js"); + // Providers wired through createHandlerForProvider (excludes vertex which needs env config) const expectedProviders = [ - "gemini", - "gemini-codeassist", - "openai", - "minimax", - "minimax-coding", - "kimi", - "kimi-coding", - "zai", - "glm", - "glm-coding", - "opencode-zen", - "opencode-zen-go", + "google", "gemini-codeassist", "openai", + "minimax", "minimax-coding", "kimi", "kimi-coding", "zai", + "glm", "glm-coding", "opencode-zen", "opencode-zen-go", "ollamacloud", - "litellm", - "vertex", ]; - for (const provider of expectedProviders) { - expect(PROVIDER_PROFILES).toHaveProperty(provider); - } - }); - - test("each profile has a createHandler function", async () => { - const { PROVIDER_PROFILES } = await import("./providers/provider-profiles.js"); - - for (const [name, profile] of Object.entries(PROVIDER_PROFILES)) { - expect(typeof profile.createHandler).toBe("function"); + for (const name of expectedProviders) { + const def = getProviderByName(name); + expect(def).toBeDefined(); + const handler = createHandlerForProvider( + def!, "test-model", "test-key", "test-target", 4000, + { isInteractive: false, invocationMode: "explicit-model" as const }, + ); + expect(handler).not.toBeNull(); } }); }); diff --git a/packages/cli/src/handlers/composed-handler.ts b/packages/cli/src/handlers/composed-handler.ts index 2b7c2de0..eb6ca622 100644 --- a/packages/cli/src/handlers/composed-handler.ts +++ b/packages/cli/src/handlers/composed-handler.ts @@ -20,10 +20,9 @@ import type { Context } from "hono"; import type { ModelHandler } from "./types.js"; import type { ProviderTransport } from "../providers/transport/types.js"; import type { BaseAPIFormat } from "../adapters/base-api-format.js"; -// Alias for readability within this file -type BaseModelAdapter = BaseAPIFormat; -import { DialectManager } from "../adapters/dialect-manager.js"; -import { MiddlewareManager, GeminiThoughtSignatureMiddleware } from "../middleware/index.js"; +import type { BaseModelDialect } from "../adapters/base-model-dialect.js"; +import { resolveModelDialect } from "../providers/provider-profiles.js"; +import { MiddlewareManager } from "../middleware/index.js"; import { TokenTracker } from "./shared/token-tracker.js"; import { transformOpenAIToClaude } from "../transform.js"; import { filterIdentity } from "./shared/openai-compat.js"; @@ -50,15 +49,17 @@ function extractAuthHeaders(c: Context): VisionProxyAuthHeaders { export interface ComposedHandlerOptions { /** Override format selection — use this specific APIFormat instance */ - adapter?: BaseAPIFormat; + formatAdapter?: BaseAPIFormat; + /** Model dialect (GLM, Grok, etc.) — resolved by the provider registry */ + modelDialect?: BaseModelDialect | null; + /** Middleware instances — resolved by the provider registry */ + middlewares?: import("../middleware/types.js").ModelMiddleware[]; /** Tool schemas for validation (enables buffered tool call validation) */ toolSchemas?: any[]; - /** Token tracking strategy */ + /** Token tracking strategy override (transport provides default via tokenStrategy) */ tokenStrategy?: "standard" | "accumulate-both" | "delta-aware" | "actual-cost" | "local"; /** Summarize tool descriptions (for models with small context) */ summarizeTools?: boolean; - /** Whether the Gemini SSE stream wraps chunks in {response: {...}} (CodeAssist) */ - unwrapGeminiResponse?: boolean; /** Whether the current session is interactive (gates consent prompt). */ isInteractive?: boolean; /** How this handler was invoked (for stats). */ @@ -67,10 +68,9 @@ export interface ComposedHandlerOptions { export class ComposedHandler implements ModelHandler { private provider: ProviderTransport; - private adapterManager: DialectManager; - private explicitAdapter?: BaseModelAdapter; + private formatAdapter?: BaseAPIFormat; /** Model-specific adapter (GLM, Grok, etc.) — handles model quirks independent of provider */ - private modelAdapter?: BaseModelAdapter; + private modelAdapter?: BaseModelDialect; private middlewareManager: MiddlewareManager; private tokenTracker: TokenTracker; private targetModel: string; @@ -89,23 +89,19 @@ export class ComposedHandler implements ModelHandler { this.provider = provider; this.targetModel = targetModel; this.options = options; - this.explicitAdapter = options.adapter; + this.formatAdapter = options.formatAdapter; this.isInteractive = options.isInteractive ?? false; - // Initialize dialect manager for automatic dialect/format selection - this.adapterManager = new DialectManager(targetModel); - - // Always resolve model-specific adapter (GLM, Grok, DeepSeek, etc.) - // This handles model quirks independent of provider transport (LiteLLM, OpenRouter, etc.) - const resolvedModelAdapter = this.adapterManager.getAdapter(); - if (resolvedModelAdapter.getName() !== "DefaultAPIFormat") { - this.modelAdapter = resolvedModelAdapter; + // Model dialect (GLM, Grok, DeepSeek, etc.) — passed from resolver or resolved here as fallback + const dialect = options.modelDialect !== undefined ? options.modelDialect : resolveModelDialect(targetModel); + if (dialect) { + this.modelAdapter = dialect; } - // Initialize middleware (only register model-specific middleware when applicable) + // Initialize middleware — resolved by the provider registry this.middlewareManager = new MiddlewareManager(); - if (targetModel.includes("gemini") || targetModel.includes("google/")) { - this.middlewareManager.register(new GeminiThoughtSignatureMiddleware()); + if (options.middlewares) { + for (const mw of options.middlewares) this.middlewareManager.register(mw); } this.middlewareManager .initialize() @@ -120,9 +116,15 @@ export class ComposedHandler implements ModelHandler { }); } + /** Token strategy: options override, then transport, then default */ + private getTokenStrategy() { + return this.options.tokenStrategy ?? this.provider.tokenStrategy ?? "standard"; + } + /** Provider adapter — handles transport format (messages, tools, payload) */ - private getAdapter(): BaseModelAdapter { - return this.explicitAdapter || this.adapterManager.getAdapter(); + private getAdapter(): BaseAPIFormat { + if (!this.formatAdapter) throw new Error("No format adapter set"); + return this.formatAdapter; } /** Model context window — model adapter wins over provider adapter */ @@ -281,6 +283,15 @@ export class ComposedHandler implements ModelHandler { // Model adapter may also need to post-process (e.g., strip unsupported thinking params) if (this.modelAdapter && this.modelAdapter !== adapter) { this.modelAdapter.prepareRequest(requestPayload, claudeRequest); + // If the model dialect specifies a stricter tool name limit than the format adapter, + // apply truncation through the format adapter so the mapping is centralized. + const dialectLimit = this.modelAdapter.getToolNameLimit(); + if (dialectLimit !== null) { + adapter.truncateToolNames(requestPayload, dialectLimit); + if (requestPayload.messages) { + adapter.truncateToolNamesInMessages(requestPayload.messages, dialectLimit); + } + } } const toolNameMap = adapter.getToolNameMap(); @@ -376,7 +387,7 @@ export class ComposedHandler implements ModelHandler { http_status: 0, error_class, error_code, - token_strategy: this.options.tokenStrategy ?? "standard", + token_strategy: this.getTokenStrategy(), adapter_name: this.getActiveAdapterName(), middleware_names: this.middlewareManager.getActiveNames(this.targetModel), fallback_used: fallbackMeta !== undefined, @@ -443,7 +454,7 @@ export class ComposedHandler implements ModelHandler { http_status: retryResp.status, error_class, error_code, - token_strategy: this.options.tokenStrategy ?? "standard", + token_strategy: this.getTokenStrategy(), adapter_name: this.getActiveAdapterName(), middleware_names: this.middlewareManager.getActiveNames(this.targetModel), fallback_used: fallbackMeta !== undefined, @@ -484,7 +495,7 @@ export class ComposedHandler implements ModelHandler { http_status: 401, error_class, error_code, - token_strategy: this.options.tokenStrategy ?? "standard", + token_strategy: this.getTokenStrategy(), adapter_name: this.getActiveAdapterName(), middleware_names: this.middlewareManager.getActiveNames(this.targetModel), fallback_used: fallbackMeta !== undefined, @@ -546,7 +557,7 @@ export class ComposedHandler implements ModelHandler { http_status: response.status, error_class, error_code, - token_strategy: this.options.tokenStrategy ?? "standard", + token_strategy: this.getTokenStrategy(), adapter_name: this.getActiveAdapterName(), middleware_names: this.middlewareManager.getActiveNames(this.targetModel), fallback_used: fallbackMeta !== undefined, @@ -596,7 +607,7 @@ export class ComposedHandler implements ModelHandler { output_tokens: this.tokenTracker.getOutputTokens(), estimated_cost: this.tokenTracker.getTotalCost(), is_free_model: isFreeModel, - token_strategy: this.options.tokenStrategy ?? "standard", + token_strategy: this.getTokenStrategy(), adapter_name: this.getActiveAdapterName(), middleware_names: this.middlewareManager.getActiveNames(this.targetModel), fallback_used: fallbackMeta !== undefined, @@ -615,13 +626,13 @@ export class ComposedHandler implements ModelHandler { private handleStream( c: Context, response: Response, - adapter: BaseModelAdapter, + adapter: BaseAPIFormat, claudeRequest: any, toolNameMap?: Map, onComplete?: () => void ): Response { const onTokenUpdate = (input: number, output: number) => { - const strategy = this.options.tokenStrategy || "standard"; + const strategy = this.getTokenStrategy(); switch (strategy) { case "accumulate-both": this.tokenTracker.accumulateBoth(input, output); @@ -668,7 +679,8 @@ export class ComposedHandler implements ModelHandler { this.middlewareManager, onTokenUpdate, claudeRequest.tools, - toolNameMap + toolNameMap, + this.modelAdapter, ); case "openai-responses-sse": @@ -694,10 +706,11 @@ export class ComposedHandler implements ModelHandler { return createGeminiSseStream(c, response, { modelName: this.targetModel, adapter, + modelDialect: this.modelAdapter, middlewareManager: this.middlewareManager, onTokenUpdate, onToolCall, - unwrapResponse: this.options.unwrapGeminiResponse, + unwrapResponse: this.provider.unwrapResponse, }); } diff --git a/packages/cli/src/handlers/shared/format/openai-messages.ts b/packages/cli/src/handlers/shared/format/openai-messages.ts index 66632431..399f7eaf 100644 --- a/packages/cli/src/handlers/shared/format/openai-messages.ts +++ b/packages/cli/src/handlers/shared/format/openai-messages.ts @@ -24,17 +24,6 @@ export function convertMessagesToOpenAI( messages.push({ role: "system", content }); } - // Add instruction for Grok models to use proper tool format - if (modelId.includes("grok") || modelId.includes("x-ai")) { - const msg = - "IMPORTANT: When calling tools, you MUST use the OpenAI tool_calls format with JSON. NEVER use XML format like ."; - if (messages.length > 0 && messages[0].role === "system") { - messages[0].content += "\n\n" + msg; - } else { - messages.unshift({ role: "system", content: msg }); - } - } - if (req.messages) { for (const msg of req.messages) { if (msg.role === "user") processUserMessage(msg, messages, simpleFormat); diff --git a/packages/cli/src/handlers/shared/request-queue.ts b/packages/cli/src/handlers/shared/request-queue.ts new file mode 100644 index 00000000..4efd4a80 --- /dev/null +++ b/packages/cli/src/handlers/shared/request-queue.ts @@ -0,0 +1,191 @@ +/** + * RequestQueue — unified request handler for all transports. + * + * One config covers timeout, rate limiting, concurrency, and retry. + * Hooks are optional functions on the config, not separate strategy objects. + */ + +import { getLogLevel, log } from "../../logger.js"; + +// ---- Config ---- + +export interface QueueState { + consecutiveErrors: number; + baseDelayMs: number; + maxDelayMs: number; + currentDelayMs: number; +} + +export interface RequestQueueConfig { + /** Max parallel requests. 1 = serial. Default: 1. */ + maxParallel?: number; + /** Max queued requests. Default: 100. */ + maxQueueSize?: number; + /** Per-request timeout in ms. 0 = no timeout. */ + timeoutMs?: number; + /** Base inter-request delay in ms. 0 = no delay. */ + baseDelayMs?: number; + /** Max inter-request delay in ms. Default: 10000. */ + maxDelayMs?: number; + /** Delay before retry in ms. Default: 2000. */ + retryDelayMs?: number; + /** Override delay calculation. */ + calculateDelay?(state: QueueState): number; + /** Called after each response. Return true if rate limited. */ + onResponse?(response: Response, state: QueueState): boolean; + /** Called on 5xx. Return true to retry once. */ + shouldRetry?(response: Response, errorBody: string): boolean; +} + +export interface RequestQueueStats { + queueLength: number; + processing: boolean; + activeRequests: number; + maxParallel: number; + consecutiveErrors: number; + currentDelayMs: number; + totalProcessed: number; + totalErrors: number; +} + +// ---- RequestQueue ---- + +interface QueuedRequest { + fetchFn: () => Promise; + resolve: (response: Response) => void; + reject: (error: Error) => void; +} + +function defaultCalculateDelay(state: QueueState): number { + let d = state.baseDelayMs; + if (state.consecutiveErrors > 0) d = Math.min(d * (1 + state.consecutiveErrors * 0.5), state.maxDelayMs); + return d; +} + +export class RequestQueue { + private queue: QueuedRequest[] = []; + private processing = false; + private activeRequests = 0; + private lastRequestTime = 0; + private totalProcessed = 0; + private totalErrors = 0; + + private readonly label: string; + private readonly maxQueueSize: number; + private readonly timeoutMs: number; + private readonly retryDelayMs: number; + private readonly calculateDelay: (state: QueueState) => number; + private readonly onResponse: (response: Response, state: QueueState) => boolean; + private readonly shouldRetry?: (response: Response, errorBody: string) => boolean; + private state: QueueState; + maxParallel: number; + + constructor(label: string, config: RequestQueueConfig = {}) { + this.label = label; + this.maxQueueSize = config.maxQueueSize ?? 100; + this.maxParallel = config.maxParallel ?? 1; + this.timeoutMs = config.timeoutMs ?? 0; + this.retryDelayMs = config.retryDelayMs ?? 2000; + this.calculateDelay = config.calculateDelay ?? defaultCalculateDelay; + this.onResponse = config.onResponse ?? ((r, _s) => r.status === 429); + this.shouldRetry = config.shouldRetry; + this.state = { + consecutiveErrors: 0, + baseDelayMs: config.baseDelayMs ?? 0, + maxDelayMs: config.maxDelayMs ?? 10000, + currentDelayMs: config.baseDelayMs ?? 0, + }; + } + + async enqueue(fetchFn: () => Promise): Promise { + if (this.queue.length >= this.maxQueueSize) { + throw new Error(`${this.label} queue full (${this.queue.length}/${this.maxQueueSize}).`); + } + return new Promise((resolve, reject) => { + this.queue.push({ fetchFn, resolve, reject }); + if (!this.processing) this.processQueue(); + }); + } + + getStats(): RequestQueueStats { + return { + queueLength: this.queue.length, processing: this.processing, + activeRequests: this.activeRequests, maxParallel: this.maxParallel, + consecutiveErrors: this.state.consecutiveErrors, currentDelayMs: this.state.currentDelayMs, + totalProcessed: this.totalProcessed, totalErrors: this.totalErrors, + }; + } + + private async processQueue(): Promise { + if (this.processing) return; + this.processing = true; + while (this.queue.length > 0 && this.activeRequests < this.maxParallel) { + const request = this.queue.shift()!; + this.activeRequests++; + if (this.maxParallel === 1) { + await this.executeRequest(request); + } else { + this.executeRequest(request).catch(() => {}); + if (this.queue.length > 0) await new Promise(r => setTimeout(r, 100)); + } + } + this.processing = false; + } + + private async executeRequest(request: QueuedRequest): Promise { + try { + await this.waitForNextSlot(); + const doFetch = this.timeoutMs > 0 + ? () => Promise.race([ + request.fetchFn(), + new Promise((_, reject) => + setTimeout(() => reject(new Error(`Request timed out after ${this.timeoutMs}ms`)), this.timeoutMs) + ), + ]) + : request.fetchFn; + const response = await doFetch(); + this.lastRequestTime = Date.now(); + + if (this.onResponse(response, this.state)) { + this.totalErrors++; + this.state.consecutiveErrors++; + } else { + if (this.state.consecutiveErrors > 0) this.state.consecutiveErrors = 0; + if (this.state.currentDelayMs > this.state.baseDelayMs) { + this.state.currentDelayMs = Math.max(this.state.baseDelayMs, this.state.currentDelayMs * 0.9); + } + } + + if (this.shouldRetry && response.status >= 500) { + const errorBody = await response.clone().text(); + if (this.shouldRetry(response, errorBody)) { + await new Promise(r => setTimeout(r, this.retryDelayMs)); + const retryResponse = await doFetch(); + this.totalProcessed++; + request.resolve(retryResponse); + return; + } + } + + this.totalProcessed++; + request.resolve(response); + } catch (error) { + this.totalErrors++; + this.state.consecutiveErrors++; + request.reject(error instanceof Error ? error : new Error(String(error))); + } finally { + this.activeRequests--; + if (this.maxParallel > 1 && this.queue.length > 0) this.processQueue(); + } + } + + private async waitForNextSlot(): Promise { + const delayMs = this.calculateDelay(this.state); + this.state.currentDelayMs = delayMs; + if (delayMs <= 0) return; + const elapsed = Date.now() - this.lastRequestTime; + if (elapsed < delayMs) { + await new Promise(resolve => setTimeout(resolve, delayMs - elapsed)); + } + } +} diff --git a/packages/cli/src/handlers/shared/stream-parsers/gemini-sse.ts b/packages/cli/src/handlers/shared/stream-parsers/gemini-sse.ts index c5cc7ff6..31bd31a2 100644 --- a/packages/cli/src/handlers/shared/stream-parsers/gemini-sse.ts +++ b/packages/cli/src/handlers/shared/stream-parsers/gemini-sse.ts @@ -20,6 +20,8 @@ export interface GeminiSseOptions { onToolCall?: (toolId: string, name: string, thoughtSignature?: string) => void; /** CodeAssist wraps chunks in {response: {...}} */ unwrapResponse?: boolean; + /** Model dialect for additional text processing */ + modelDialect?: any; } export function createGeminiSseStream( @@ -196,7 +198,10 @@ export function createGeminiSseStream( let cleanedText = part.text; if (opts.adapter) { - const res = opts.adapter.processTextContent(part.text, accumulatedText); + let res = opts.adapter.processTextContent(part.text, accumulatedText); + if (opts.modelDialect?.processTextContent && !res.wasTransformed) { + res = opts.modelDialect.processTextContent(res.cleanedText, accumulatedText); + } cleanedText = res.cleanedText || ""; accumulatedText += cleanedText; } else { diff --git a/packages/cli/src/handlers/shared/stream-parsers/openai-sse.ts b/packages/cli/src/handlers/shared/stream-parsers/openai-sse.ts index 7a2f5ad1..b8834241 100644 --- a/packages/cli/src/handlers/shared/stream-parsers/openai-sse.ts +++ b/packages/cli/src/handlers/shared/stream-parsers/openai-sse.ts @@ -106,7 +106,8 @@ export function createStreamingResponseHandler( middlewareManager: any, onTokenUpdate?: (input: number, output: number) => void, toolSchemas?: any[], // Tool schemas for validation - toolNameMap?: Map // Truncated → original tool name mapping + toolNameMap?: Map, // Truncated → original tool name mapping + modelDialect?: any, // Model dialect for additional text processing ): Response { log(`[Streaming] ===== HANDLER STARTED for ${target} =====`); let isClosed = false; @@ -408,7 +409,10 @@ export function createStreamingResponseHandler( }); state.reasoningStarted = false; } - const res = adapter.processTextContent(txt, ""); + let res = adapter.processTextContent(txt, ""); + if (modelDialect?.processTextContent && !res.wasTransformed) { + res = modelDialect.processTextContent(res.cleanedText, ""); + } log( `[Streaming] After adapter: "${res.cleanedText.substring(0, 30).replace(/\n/g, "\\n")}" (${res.cleanedText.length} chars, transformed=${res.wasTransformed})` ); diff --git a/packages/cli/src/index.ts b/packages/cli/src/index.ts index 40d4ec15..0891a548 100644 --- a/packages/cli/src/index.ts +++ b/packages/cli/src/index.ts @@ -472,7 +472,7 @@ async function runCli() { // Run Claude Code with proxy let exitCode = 0; try { - exitCode = await runClaudeWithProxy(cliConfig, proxy.url, () => diag.cleanup(), mtmRunner); + exitCode = await runClaudeWithProxy(cliConfig, proxy.url, mtmRunner); } finally { // Clear diagOutput BEFORE cleanup to prevent write-after-end setDiagOutput(null); diff --git a/packages/cli/src/mcp-server.ts b/packages/cli/src/mcp-server.ts index 8899c106..389f4daa 100644 --- a/packages/cli/src/mcp-server.ts +++ b/packages/cli/src/mcp-server.ts @@ -3,7 +3,9 @@ /** * Claudish MCP Server * - * Exposes OpenRouter models as MCP tools for Claude Code. + * Exposes multi-provider model access as MCP tools for Claude Code. + * Supports provider@model syntax (e.g., g@gemini-2.0-flash, oai@gpt-5.1, gh@gpt-4o). + * Falls back to OpenRouter for unrecognized providers. * Run with: claudish-mcp (stdio transport) */ @@ -11,17 +13,17 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import { z } from "zod"; import { config } from "dotenv"; -import { readFileSync, existsSync, writeFileSync, mkdirSync, readdirSync } from "node:fs"; +import { readFileSync, existsSync, writeFileSync, mkdirSync } from "node:fs"; import { join, dirname } from "node:path"; import { homedir } from "node:os"; import { fileURLToPath } from "node:url"; +import { parseModelSpec } from "./providers/model-parser.js"; +import { PROVIDERS_BY_NAME } from "./providers/provider-definitions.js"; import { - setupSession, - runModels, - judgeResponses, - getStatus, - validateSessionPath, -} from "./team-orchestrator.js"; + resolveEffective, + resolveTransport, + resolveAPIFormat, +} from "./providers/provider-profiles.js"; // Load environment variables config(); @@ -30,7 +32,7 @@ config(); const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); -// Paths - use ~/.claudish/ for writable cache (binaries can't write to __dirname) +// Paths const RECOMMENDED_MODELS_PATH = join(__dirname, "../recommended-models.json"); const CLAUDISH_CACHE_DIR = join(homedir(), ".claudish"); const ALL_MODELS_CACHE_PATH = join(CLAUDISH_CACHE_DIR, "all-models.json"); @@ -53,22 +55,6 @@ interface ModelInfo { supportsVision?: boolean; } -interface OpenRouterResponse { - id: string; - choices: Array<{ - message: { - content: string; - role: string; - }; - finish_reason: string; - }>; - usage?: { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - }; -} - /** * Load recommended models from JSON */ @@ -111,7 +97,7 @@ async function loadAllModels(forceRefresh = false): Promise { const data = await response.json(); const models = data.data || []; - // Cache result - ensure directory exists + // Cache result mkdirSync(CLAUDISH_CACHE_DIR, { recursive: true }); writeFileSync( ALL_MODELS_CACHE_PATH, @@ -134,55 +120,126 @@ async function loadAllModels(forceRefresh = false): Promise { } /** - * Run a prompt through OpenRouter + * Resolve API key for a provider definition, checking primary env var and aliases. + */ +function resolveApiKey(def: { apiKeyEnvVar: string; apiKeyAliases?: string[] }): string { + if (def.apiKeyEnvVar && process.env[def.apiKeyEnvVar]) { + return process.env[def.apiKeyEnvVar]!; + } + if (def.apiKeyAliases) { + for (const alias of def.apiKeyAliases) { + if (process.env[alias]) return process.env[alias]!; + } + } + return ""; +} + +/** + * Run a prompt through the provider system. + * + * Parses model with provider@model syntax, resolves the provider definition, + * builds the request via the adapter, and fetches a non-streaming response. */ async function runPrompt( model: string, prompt: string, systemPrompt?: string, - maxTokens?: number + maxTokens?: number, + files?: string[], ): Promise<{ content: string; usage?: { input: number; output: number } }> { - const apiKey = process.env.OPENROUTER_API_KEY; - if (!apiKey) { - throw new Error("OPENROUTER_API_KEY environment variable not set"); + // Parse model spec (e.g., "g@gemini-2.0-flash", "openai/gpt-5.1") + const parsed = parseModelSpec(model); + + // Look up provider definition + let def = PROVIDERS_BY_NAME.get(parsed.provider); + + // Fall back to OpenRouter for unknown providers + if (!def) { + def = PROVIDERS_BY_NAME.get("openrouter"); + if (!def) { + throw new Error(`No provider found for "${parsed.provider}" and OpenRouter fallback unavailable`); + } + // For OpenRouter fallback, use the original model string as the model name + parsed.model = parsed.original; } - const messages: Array<{ role: string; content: string }> = []; + // Resolve effective definition (handles Zen minimax swap, publicKeyFallback, etc.) + const apiKey = resolveApiKey(def); + const effective = resolveEffective(def, parsed.model, apiKey); + + // Resolve transport and adapter + const transport = resolveTransport(effective.def, parsed.model, effective.apiKey); + if (!transport) { + throw new Error( + `Cannot create transport for provider "${effective.def.name}". ` + + `Check that ${effective.def.apiKeyEnvVar || "the required API key"} is set.` + ); + } + + const adapter = resolveAPIFormat(effective.def, parsed.model); + if (!adapter) { + throw new Error(`Cannot create format adapter for provider "${effective.def.name}"`); + } + + // Build a Claude-shaped request so the adapter can convert it + let fullPrompt = prompt; + if (files && files.length > 0) { + fullPrompt += "\n\n---\n"; + for (const filePath of files) { + try { + const content = readFileSync(filePath, "utf-8"); + fullPrompt += `\n### File: ${filePath}\n\`\`\`\n${content}\n\`\`\`\n`; + } catch (e) { + fullPrompt += `\n### File: ${filePath}\n[Error reading file: ${e instanceof Error ? e.message : String(e)}]\n`; + } + } + } + + const claudeRequest: any = { + model: parsed.model, + max_tokens: maxTokens || 4096, + messages: [{ role: "user", content: fullPrompt }], + }; if (systemPrompt) { - messages.push({ role: "system", content: systemPrompt }); + claudeRequest.system = systemPrompt; } - messages.push({ role: "user", content: prompt }); + // Convert through adapter + const messages = adapter.convertMessages(claudeRequest); + const tools: any[] = []; + const payload = adapter.buildPayload(claudeRequest, messages, tools); + + // Make it non-streaming + // Gemini uses a different endpoint for non-streaming + let endpoint: string; + if (transport.getNonStreamingEndpoint) { + endpoint = transport.getNonStreamingEndpoint(parsed.model); + } else { + // For OpenAI-compatible APIs, set stream=false and remove stream_options + payload.stream = false; + delete payload.stream_options; + endpoint = transport.getEndpoint(parsed.model); + } - const response = await fetch("https://openrouter.ai/api/v1/chat/completions", { + const headers = await transport.getHeaders(); + headers["Content-Type"] = "application/json"; + + const response = await fetch(endpoint, { method: "POST", - headers: { - Authorization: `Bearer ${apiKey}`, - "Content-Type": "application/json", - "HTTP-Referer": "https://claudish.com", - "X-Title": "Claudish MCP", - }, - body: JSON.stringify({ - model, - messages, - ...(maxTokens ? { max_tokens: maxTokens } : {}), - }), + headers, + body: JSON.stringify(payload), }); if (!response.ok) { - const error = await response.text(); - throw new Error(`OpenRouter API error: ${response.status} - ${error}`); + const errorText = await response.text(); + throw new Error( + `${effective.def.displayName} API error: ${response.status} - ${errorText}` + ); } - const data: OpenRouterResponse = await response.json(); - - const content = data.choices?.[0]?.message?.content || ""; - const usage = data.usage - ? { input: data.usage.prompt_tokens, output: data.usage.completion_tokens } - : undefined; - - return { content, usage }; + const data = await response.json(); + return adapter.parseResponse(data); } /** @@ -208,49 +265,6 @@ function fuzzyScore(text: string, query: string): number { return queryIndex === lowerQuery.length ? score / lowerText.length : 0; } -/** - * Format team run results with rich error context for failed models. - */ -function formatTeamResult( - status: import("./team-orchestrator.js").TeamStatus, - sessionPath: string -): string { - const entries = Object.entries(status.models); - const failed = entries.filter(([, m]) => m.state === "FAILED" || m.state === "TIMEOUT"); - const succeeded = entries.filter(([, m]) => m.state === "COMPLETED"); - - let result = JSON.stringify(status, null, 2); - - if (failed.length > 0) { - result += "\n\n---\n## Failures Detected\n\n"; - result += `${succeeded.length}/${entries.length} models succeeded, ${failed.length} failed.\n\n`; - - for (const [id, m] of failed) { - result += `### Model ${id}: ${m.state}\n`; - if (m.error) { - result += `- **Model:** ${m.error.model}\n`; - result += `- **Command:** \`${m.error.command}\`\n`; - result += `- **Exit code:** ${m.exitCode}\n`; - if (m.error.stderrSnippet) { - result += `- **Error output:**\n\`\`\`\n${m.error.stderrSnippet}\n\`\`\`\n`; - } - result += `- **Full error log:** ${m.error.errorLogPath}\n`; - result += `- **Working directory:** ${m.error.workDir}\n`; - } - result += "\n"; - } - - result += "---\n"; - result += "**To help claudish devs fix this**, use the `report_error` tool with:\n"; - result += '- `error_type`: "provider_failure" or "team_failure"\n'; - result += `- \`session_path\`: "${sessionPath}"\n`; - result += "- Copy the stderr snippet above into `stderr_snippet`\n"; - result += "- Set `auto_send: true` to suggest enabling automatic reporting\n"; - } - - return result; -} - /** * Create and start the MCP server */ @@ -260,433 +274,205 @@ async function main() { version: "2.5.0", }); - const toolMode = (process.env.CLAUDISH_MCP_TOOLS || "all").toLowerCase(); - const isLowLevel = toolMode === "all" || toolMode === "low-level"; - const isAgentic = toolMode === "all" || toolMode === "agentic"; - - console.error(`[claudish] MCP server started (tools: ${toolMode})`); - - if (isLowLevel) { - // Tool: run_prompt - Run a prompt through an OpenRouter model - server.tool( - "run_prompt", - "Run a prompt through an OpenRouter model (Grok, GPT-5, Gemini, etc.)", - { - model: z - .string() - .describe("OpenRouter model ID (e.g., 'x-ai/grok-code-fast-1', 'openai/gpt-5.1-codex')"), - prompt: z.string().describe("The prompt to send to the model"), - system_prompt: z.string().optional().describe("Optional system prompt"), - max_tokens: z - .number() - .optional() - .describe("Maximum tokens in response (omit to let model decide)"), - }, - async ({ model, prompt, system_prompt, max_tokens }) => { - try { - const result = await runPrompt(model, prompt, system_prompt, max_tokens); - - let response = result.content; - if (result.usage) { - response += `\n\n---\nTokens: ${result.usage.input} input, ${result.usage.output} output`; - } + console.error("[claudish] MCP server started"); + + // Tool: run_prompt - Run a prompt through any provider + server.tool( + "run_prompt", + "Run a prompt through a model using provider@model syntax (e.g., g@gemini-2.0-flash, oai@gpt-5.1, gh@gpt-4o, or@deepseek/deepseek-r1). Falls back to OpenRouter for unrecognized providers.", + { + model: z + .string() + .describe("Model in provider@model syntax (e.g., 'g@gemini-2.0-flash', 'oai@gpt-5.1', 'gh@gpt-4o') or OpenRouter model ID"), + prompt: z.string().describe("The prompt to send to the model"), + system_prompt: z.string().optional().describe("Optional system prompt"), + max_tokens: z + .number() + .optional() + .describe("Maximum tokens in response (default: 4096)"), + files: z + .array(z.string()) + .optional() + .describe("File paths to read and append after the prompt"), + }, + async ({ model, prompt, system_prompt, max_tokens, files }) => { + try { + const result = await runPrompt(model, prompt, system_prompt, max_tokens, files); - return { content: [{ type: "text", text: response }] }; - } catch (error) { - return { - content: [ - { - type: "text", - text: `Error: ${error instanceof Error ? error.message : String(error)}`, - }, - ], - isError: true, - }; + let response = result.content; + if (result.usage) { + response += `\n\n---\nTokens: ${result.usage.input} input, ${result.usage.output} output`; } - } - ); - - // Tool: list_models - List recommended models - server.tool("list_models", "List recommended models for coding tasks", {}, async () => { - const models = loadRecommendedModels(); - if (models.length === 0) { + return { content: [{ type: "text", text: response }] }; + } catch (error) { return { content: [ - { type: "text", text: "No recommended models found. Try search_models instead." }, + { + type: "text", + text: `Error: ${error instanceof Error ? error.message : String(error)}`, + }, ], + isError: true, }; } + } + ); + + // Tool: list_models - List recommended models + server.tool("list_models", "List recommended models for coding tasks", {}, async () => { + const models = loadRecommendedModels(); + + if (models.length === 0) { + return { + content: [ + { type: "text", text: "No recommended models found. Try search_models instead." }, + ], + }; + } - let output = "# Recommended Models\n\n"; - output += "| Model | Provider | Pricing | Context | Tools | Reasoning | Vision |\n"; - output += "|-------|----------|---------|---------|-------|-----------|--------|\n"; + let output = "# Recommended Models\n\n"; + output += "| Model | Provider | Pricing | Context | Tools | Reasoning | Vision |\n"; + output += "|-------|----------|---------|---------|-------|-----------|--------|\n"; - for (const model of models) { - const tools = model.supportsTools ? "✓" : "·"; - const reasoning = model.supportsReasoning ? "✓" : "·"; - const vision = model.supportsVision ? "✓" : "·"; - output += `| ${model.id} | ${model.provider} | ${model.pricing?.average || "N/A"} | ${model.context || "N/A"} | ${tools} | ${reasoning} | ${vision} |\n`; - } + for (const model of models) { + const tools = model.supportsTools ? "Y" : "."; + const reasoning = model.supportsReasoning ? "Y" : "."; + const vision = model.supportsVision ? "Y" : "."; + output += `| ${model.id} | ${model.provider} | ${model.pricing?.average || "N/A"} | ${model.context || "N/A"} | ${tools} | ${reasoning} | ${vision} |\n`; + } - output += "\n## Quick Picks\n"; - output += "- **Budget**: `minimax-m2.5` ($0.75/1M)\n"; - output += "- **Large context**: `gemini-3.1-pro-preview` (1M tokens)\n"; - output += "- **Most advanced**: `gpt-5.4` ($8.75/1M)\n"; - output += "- **Vision + coding**: `kimi-k2.5` ($1.32/1M)\n"; - output += "- **Agentic**: `glm-5` ($1.68/1M)\n"; - output += "- **Multimodal**: `qwen3.5-plus-02-15` ($1.40/1M)\n"; + output += "\n## Quick Picks\n"; + output += "- **Budget**: `minimax-m2.5` ($0.75/1M)\n"; + output += "- **Large context**: `gemini-3.1-pro-preview` (1M tokens)\n"; + output += "- **Most advanced**: `gpt-5.4` ($8.75/1M)\n"; + output += "- **Vision + coding**: `kimi-k2.5` ($1.32/1M)\n"; + output += "- **Agentic**: `glm-5` ($1.68/1M)\n"; + output += "- **Multimodal**: `qwen3.5-plus-02-15` ($1.40/1M)\n"; - return { content: [{ type: "text", text: output }] }; - }); - - // Tool: search_models - Search all OpenRouter models - server.tool( - "search_models", - "Search all OpenRouter models by name, provider, or capability", - { - query: z.string().describe("Search query (e.g., 'grok', 'vision', 'free')"), - limit: z.number().optional().describe("Maximum results to return (default: 10)"), - }, - async ({ query, limit }) => { - const maxResults = limit || 10; - const allModels = await loadAllModels(); - - if (allModels.length === 0) { - return { - content: [ - { type: "text", text: "Failed to load models. Check your internet connection." }, - ], - isError: true, - }; - } + return { content: [{ type: "text", text: output }] }; + }); - // Search with fuzzy matching - const results = allModels - .map((model) => { - const nameScore = fuzzyScore(model.name || "", query); - const idScore = fuzzyScore(model.id || "", query); - const descScore = fuzzyScore(model.description || "", query) * 0.5; - return { model, score: Math.max(nameScore, idScore, descScore) }; - }) - .filter((item) => item.score > 0.2) - .sort((a, b) => b.score - a.score) - .slice(0, maxResults); - - if (results.length === 0) { - return { - content: [{ type: "text", text: `No models found matching "${query}"` }], - }; - } + // Tool: search_models - Search all OpenRouter models + server.tool( + "search_models", + "Search all OpenRouter models by name, provider, or capability", + { + query: z.string().describe("Search query (e.g., 'grok', 'vision', 'free')"), + limit: z.number().optional().describe("Maximum results to return (default: 10)"), + }, + async ({ query, limit }) => { + const maxResults = limit || 10; + const allModels = await loadAllModels(); - let output = `# Search Results for "${query}"\n\n`; - output += "| Model | Provider | Pricing | Context |\n"; - output += "|-------|----------|---------|----------|\n"; - - for (const { model } of results) { - const provider = model.id.split("/")[0]; - const promptPrice = parseFloat(model.pricing?.prompt || "0") * 1000000; - const completionPrice = parseFloat(model.pricing?.completion || "0") * 1000000; - const avgPrice = (promptPrice + completionPrice) / 2; - const pricing = - avgPrice > 0 ? `$${avgPrice.toFixed(2)}/1M` : avgPrice < 0 ? "varies" : "FREE"; - const context = model.context_length - ? `${Math.round(model.context_length / 1000)}K` - : "N/A"; - - output += `| ${model.id} | ${provider} | ${pricing} | ${context} |\n`; - } + if (allModels.length === 0) { + return { + content: [ + { type: "text", text: "Failed to load models. Check your internet connection." }, + ], + isError: true, + }; + } - output += `\nUse with: run_prompt(model="${results[0].model.id}", prompt="your prompt")`; + // Search with fuzzy matching + const results = allModels + .map((model) => { + const nameScore = fuzzyScore(model.name || "", query); + const idScore = fuzzyScore(model.id || "", query); + const descScore = fuzzyScore(model.description || "", query) * 0.5; + return { model, score: Math.max(nameScore, idScore, descScore) }; + }) + .filter((item) => item.score > 0.2) + .sort((a, b) => b.score - a.score) + .slice(0, maxResults); + + if (results.length === 0) { + return { + content: [{ type: "text", text: `No models found matching "${query}"` }], + }; + } - return { content: [{ type: "text", text: output }] }; + let output = `# Search Results for "${query}"\n\n`; + output += "| Model | Provider | Pricing | Context |\n"; + output += "|-------|----------|---------|----------|\n"; + + for (const { model } of results) { + const provider = model.id.split("/")[0]; + const promptPrice = parseFloat(model.pricing?.prompt || "0") * 1000000; + const completionPrice = parseFloat(model.pricing?.completion || "0") * 1000000; + const avgPrice = (promptPrice + completionPrice) / 2; + const pricing = + avgPrice > 0 ? `$${avgPrice.toFixed(2)}/1M` : avgPrice < 0 ? "varies" : "FREE"; + const context = model.context_length + ? `${Math.round(model.context_length / 1000)}K` + : "N/A"; + + output += `| ${model.id} | ${provider} | ${pricing} | ${context} |\n`; } - ); - // Tool: compare_models - Run same prompt through multiple models - server.tool( - "compare_models", - "Run the same prompt through multiple models and compare responses", - { - models: z.array(z.string()).describe("List of model IDs to compare"), - prompt: z.string().describe("The prompt to send to all models"), - system_prompt: z.string().optional().describe("Optional system prompt"), - max_tokens: z - .number() - .optional() - .describe("Maximum tokens in response (omit to let model decide)"), - }, - async ({ models, prompt, system_prompt, max_tokens }) => { - const results: Array<{ - model: string; - response: string; - error?: string; - tokens?: { input: number; output: number }; - }> = []; - - for (const model of models) { - try { - const result = await runPrompt(model, prompt, system_prompt, max_tokens); - results.push({ - model, - response: result.content, - tokens: result.usage, - }); - } catch (error) { - results.push({ - model, - response: "", - error: error instanceof Error ? error.message : String(error), - }); - } - } + output += `\nUse with: run_prompt(model="${results[0].model.id}", prompt="your prompt")`; - let output = "# Model Comparison\n\n"; - output += `**Prompt:** ${prompt.slice(0, 100)}${prompt.length > 100 ? "..." : ""}\n\n`; - - for (const result of results) { - output += `## ${result.model}\n\n`; - if (result.error) { - output += `**Error:** ${result.error}\n\n`; - } else { - output += result.response + "\n\n"; - if (result.tokens) { - output += `*Tokens: ${result.tokens.input} in, ${result.tokens.output} out*\n\n`; - } - } - output += "---\n\n"; - } + return { content: [{ type: "text", text: output }] }; + } + ); + + // Tool: compare_models - Run same prompt through multiple models + server.tool( + "compare_models", + "Run the same prompt through multiple models and compare responses. Supports provider@model syntax.", + { + models: z.array(z.string()).describe("List of model IDs to compare (provider@model syntax supported)"), + prompt: z.string().describe("The prompt to send to all models"), + system_prompt: z.string().optional().describe("Optional system prompt"), + files: z + .array(z.string()) + .optional() + .describe("File paths to read and append after the prompt"), + }, + async ({ models, prompt, system_prompt, files }) => { + const results: Array<{ + model: string; + response: string; + error?: string; + tokens?: { input: number; output: number }; + }> = []; - return { content: [{ type: "text", text: output }] }; - } - ); - } // isLowLevel - - if (isAgentic) { - // Tool: team - Multi-model orchestration with anonymized blind evaluation - server.tool( - "team", - "Run AI models on a task with anonymized outputs and optional blind judging. Modes: 'run' (execute models), 'judge' (blind-vote on existing outputs), 'run-and-judge' (full pipeline), 'status' (check progress).", - { - mode: z.enum(["run", "judge", "run-and-judge", "status"]).describe("Operation mode"), - path: z - .string() - .describe("Session directory path (must be within current working directory)"), - models: z - .array(z.string()) - .optional() - .describe("Model IDs to run (required for 'run' and 'run-and-judge' modes)"), - judges: z - .array(z.string()) - .optional() - .describe("Model IDs to use as judges (default: same as runners)"), - input: z - .string() - .optional() - .describe("Task prompt text (or place input.md in the session directory before calling)"), - timeout: z.number().optional().describe("Per-model timeout in seconds (default: 300)"), - }, - async ({ mode, path, models, judges, input, timeout }) => { + for (const model of models) { try { - const resolved = validateSessionPath(path); - - switch (mode) { - case "run": { - if (!models?.length) throw new Error("'models' is required for 'run' mode"); - setupSession(resolved, models, input); - const status = await runModels(resolved, { timeout }); - return { content: [{ type: "text", text: formatTeamResult(status, resolved) }] }; - } - case "judge": { - const verdict = await judgeResponses(resolved, { judges }); - return { content: [{ type: "text", text: JSON.stringify(verdict, null, 2) }] }; - } - case "run-and-judge": { - if (!models?.length) throw new Error("'models' is required for 'run-and-judge' mode"); - setupSession(resolved, models, input); - await runModels(resolved, { timeout }); - const verdict = await judgeResponses(resolved, { judges }); - return { content: [{ type: "text", text: JSON.stringify(verdict, null, 2) }] }; - } - case "status": { - const status = getStatus(resolved); - return { content: [{ type: "text", text: JSON.stringify(status, null, 2) }] }; - } - } + const result = await runPrompt(model, prompt, system_prompt, 2048, files); + results.push({ + model, + response: result.content, + tokens: result.usage, + }); } catch (error) { - return { - content: [ - { - type: "text", - text: `Error: ${error instanceof Error ? error.message : String(error)}`, - }, - ], - isError: true, - }; + results.push({ + model, + response: "", + error: error instanceof Error ? error.message : String(error), + }); } } - ); - - // Tool: report_error - Send anonymized error data to claudish devs - server.tool( - "report_error", - "Report a claudish error to developers. IMPORTANT: Ask the user for consent BEFORE calling this tool. Show them what data will be sent (sanitized). All data is anonymized: API keys, user paths, and emails are stripped. Set auto_send=true to suggest the user enables automatic future reporting.", - { - error_type: z - .enum(["provider_failure", "team_failure", "stream_error", "adapter_error", "other"]) - .describe("Category of the error"), - model: z.string().optional().describe("Model ID that failed (anonymized in report)"), - command: z.string().optional().describe("Command that was run"), - stderr_snippet: z.string().optional().describe("First 500 chars of stderr output"), - exit_code: z.number().optional().describe("Process exit code"), - error_log_path: z.string().optional().describe("Path to full error log file"), - session_path: z.string().optional().describe("Path to team session directory"), - additional_context: z.string().optional().describe("Any extra context about the error"), - auto_send: z - .boolean() - .optional() - .describe("If true, suggest the user enable automatic error reporting"), - }, - async ({ - error_type, - model, - command, - stderr_snippet, - exit_code, - error_log_path, - session_path, - additional_context, - auto_send, - }) => { - // Sanitize: strip API keys, paths with usernames, env vars - function sanitize(text: string | undefined): string { - if (!text) return ""; - return text - .replace(/sk-[a-zA-Z0-9_-]{10,}/g, "sk-***REDACTED***") - .replace(/Bearer [a-zA-Z0-9_.-]+/g, "Bearer ***REDACTED***") - .replace(/\/Users\/[^/\s]+/g, "/Users/***") - .replace(/\/home\/[^/\s]+/g, "/home/***") - .replace(/[A-Z_]+_API_KEY=[^\s]+/g, "***_API_KEY=REDACTED") - .replace(/[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}/g, "***@***.***"); - } - - // Read stderr from log file — include full content, not just snippet - let stderrFull = stderr_snippet || ""; - if (error_log_path) { - try { - stderrFull = readFileSync(error_log_path, "utf-8"); - } catch { - // log file may not exist - } - } - - // Read session data if session_path provided (status, manifest, response outputs) - let sessionData: Record = {}; - if (session_path) { - const sp = session_path; - for (const file of ["status.json", "manifest.json", "input.md"]) { - try { - sessionData[file] = readFileSync(join(sp, file), "utf-8"); - } catch {} - } - try { - const errorDir = join(sp, "errors"); - if (existsSync(errorDir)) { - for (const f of readdirSync(errorDir)) { - if (f.endsWith(".log")) { - try { - sessionData[`errors/${f}`] = readFileSync(join(errorDir, f), "utf-8"); - } catch {} - } - } - } - } catch {} - try { - for (const f of readdirSync(sp)) { - if (f.startsWith("response-") && f.endsWith(".md")) { - try { - const content = readFileSync(join(sp, f), "utf-8"); - sessionData[f] = - content.slice(0, 200) + (content.length > 200 ? "... (truncated)" : ""); - } catch {} - } - } - } catch {} - } - // Read claudish version - let version = "unknown"; - try { - const pkgPath = join(__dirname, "../package.json"); - if (existsSync(pkgPath)) { - version = JSON.parse(readFileSync(pkgPath, "utf-8")).version; + let output = "# Model Comparison\n\n"; + output += `**Prompt:** ${prompt.slice(0, 100)}${prompt.length > 100 ? "..." : ""}\n\n`; + + for (const result of results) { + output += `## ${result.model}\n\n`; + if (result.error) { + output += `**Error:** ${result.error}\n\n`; + } else { + output += result.response + "\n\n"; + if (result.tokens) { + output += `*Tokens: ${result.tokens.input} in, ${result.tokens.output} out*\n\n`; } - } catch {} - - const report = { - version, - timestamp: new Date().toISOString(), - error_type, - model: model || "unknown", - command: sanitize(command), - stderr: sanitize(stderrFull), - exit_code: exit_code ?? null, - platform: process.platform, - arch: process.arch, - runtime: `bun ${process.version}`, - context: sanitize(additional_context), - session: Object.fromEntries( - Object.entries(sessionData).map(([k, v]) => [k, sanitize(v)]) - ), - }; - - const reportSummary = JSON.stringify(report, null, 2); - const autoSendHint = auto_send - ? "\n\n**Suggestion:** Enable automatic error reporting so future errors are sent without asking. Run `claudish config` → Privacy → toggle Telemetry, or set `CLAUDISH_TELEMETRY=1`." - : ""; - - const REPORT_URL = "https://api.claudish.com/v1/error-reports"; - - try { - const response = await fetch(REPORT_URL, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify(report), - signal: AbortSignal.timeout(5000), - }); - - if (response.ok) { - return { - content: [ - { - type: "text", - text: `Error report sent successfully.\n\n**Sanitized data sent:**\n\`\`\`json\n${reportSummary}\n\`\`\`${autoSendHint}`, - }, - ], - }; - } else { - return { - content: [ - { - type: "text", - text: `Error report endpoint returned ${response.status}. Report was NOT sent.\n\n**Data that would have been sent (all sanitized):**\n\`\`\`json\n${reportSummary}\n\`\`\`\n\nYou can manually report this at https://github.com/anthropics/claudish/issues${autoSendHint}`, - }, - ], - }; - } - } catch (err) { - return { - content: [ - { - type: "text", - text: `Could not reach error reporting endpoint (${err instanceof Error ? err.message : "network error"}).\n\n**Sanitized error data (for manual reporting):**\n\`\`\`json\n${reportSummary}\n\`\`\`\n\nReport manually at https://github.com/anthropics/claudish/issues${autoSendHint}`, - }, - ], - }; } + output += "---\n\n"; } - ); - } // isAgentic + + return { content: [{ type: "text", text: output }] }; + } + ); // Start server with stdio transport const transport = new StdioServerTransport(); diff --git a/packages/cli/src/middleware/gemini-thought-signature.ts b/packages/cli/src/middleware/gemini-thought-signature.ts index e2e076d2..64a34f0a 100644 --- a/packages/cli/src/middleware/gemini-thought-signature.ts +++ b/packages/cli/src/middleware/gemini-thought-signature.ts @@ -46,10 +46,6 @@ export class GeminiThoughtSignatureMiddleware implements ModelMiddleware { } >(); - shouldHandle(modelId: string): boolean { - return modelId.includes("gemini") || modelId.includes("google/"); - } - onInit(): void { log("[Gemini] Thought signature middleware initialized"); } diff --git a/packages/cli/src/middleware/manager.ts b/packages/cli/src/middleware/manager.ts index 20ad4e8e..73d0526b 100644 --- a/packages/cli/src/middleware/manager.ts +++ b/packages/cli/src/middleware/manager.ts @@ -64,25 +64,17 @@ export class MiddlewareManager { } /** - * Get active middlewares for a specific model + * Get names of registered middlewares (all are active, selection is done by resolver). */ - private getActiveMiddlewares(modelId: string): ModelMiddleware[] { - return this.middlewares.filter((m) => m.shouldHandle(modelId)); - } - - /** - * Get names of active middlewares for a specific model. - * Used by stats recording to capture middleware names without details. - */ - getActiveNames(modelId: string): string[] { - return this.getActiveMiddlewares(modelId).map((m) => m.name); + getActiveNames(_modelId: string): string[] { + return this.middlewares.map((m) => m.name); } /** * Execute beforeRequest hooks for all active middlewares */ async beforeRequest(context: RequestContext): Promise { - const active = this.getActiveMiddlewares(context.modelId); + const active = this.middlewares; if (active.length === 0) { return; // No middlewares for this model @@ -110,7 +102,7 @@ export class MiddlewareManager { * Execute afterResponse hooks for non-streaming responses */ async afterResponse(context: NonStreamingResponseContext): Promise { - const active = this.getActiveMiddlewares(context.modelId); + const active = this.middlewares; if (active.length === 0) { return; @@ -138,7 +130,7 @@ export class MiddlewareManager { * Execute afterStreamChunk hooks for each streaming chunk */ async afterStreamChunk(context: StreamChunkContext): Promise { - const active = this.getActiveMiddlewares(context.modelId); + const active = this.middlewares; if (active.length === 0) { return; @@ -168,7 +160,7 @@ export class MiddlewareManager { * Execute afterStreamComplete hooks after streaming finishes */ async afterStreamComplete(modelId: string, metadata: Map): Promise { - const active = this.getActiveMiddlewares(modelId); + const active = this.middlewares; if (active.length === 0) { return; diff --git a/packages/cli/src/middleware/types.ts b/packages/cli/src/middleware/types.ts index 7a60f969..7eb6d18a 100644 --- a/packages/cli/src/middleware/types.ts +++ b/packages/cli/src/middleware/types.ts @@ -63,12 +63,6 @@ export interface ModelMiddleware { /** Unique name for this middleware (for logging) */ readonly name: string; - /** - * Determines if this middleware should handle the given model - * Called once per request to filter active middlewares - */ - shouldHandle(modelId: string): boolean; - /** * Called once when the proxy server starts (optional) * Use for initialization, loading config, etc. diff --git a/packages/cli/src/providers/auto-route.ts b/packages/cli/src/providers/auto-route.ts index 0e3393b1..91ad5c0d 100644 --- a/packages/cli/src/providers/auto-route.ts +++ b/packages/cli/src/providers/auto-route.ts @@ -236,7 +236,6 @@ export interface FallbackRoute { } import { - getShortestPrefix, getDisplayName as _getDisplayName, getAllProviders, } from "./provider-definitions.js"; diff --git a/packages/cli/src/providers/index.ts b/packages/cli/src/providers/index.ts index 67d49966..24280fc4 100644 --- a/packages/cli/src/providers/index.ts +++ b/packages/cli/src/providers/index.ts @@ -23,11 +23,12 @@ export { type UrlParsedModel, } from "./provider-registry.js"; -// Remote provider registry +// Provider definitions export { - resolveRemoteProvider, - getRegisteredRemoteProviders, -} from "./remote-provider-registry.js"; + getAllProviders, + toRemoteProvider, + getProviderByName, +} from "./provider-definitions.js"; // Model parser - unified syntax for provider@model[:concurrency] export { diff --git a/packages/cli/src/providers/model-parser.ts b/packages/cli/src/providers/model-parser.ts index 55cf8c97..949db29a 100644 --- a/packages/cli/src/providers/model-parser.ts +++ b/packages/cli/src/providers/model-parser.ts @@ -65,14 +65,14 @@ export interface ParsedModel { * Re-exported for backward compatibility. */ import { - getShortcuts as _getShortcuts, - getLegacyPrefixPatterns as _getLegacyPrefixPatterns, - getNativeModelPatterns as _getNativeModelPatterns, + PROVIDER_SHORTCUTS as _PROVIDER_SHORTCUTS, + LEGACY_PREFIX_PATTERNS as _LEGACY_PREFIX_PATTERNS, + NATIVE_MODEL_PATTERNS as _NATIVE_MODEL_PATTERNS, isLocalTransport, isDirectApiProvider as _isDirectApiProvider, } from "./provider-definitions.js"; -export const PROVIDER_SHORTCUTS: Record = _getShortcuts(); +export const PROVIDER_SHORTCUTS: Record = _PROVIDER_SHORTCUTS; /** * Local providers (no API key needed) — derived from BUILTIN_PROVIDERS. @@ -95,12 +95,12 @@ export const DIRECT_API_PROVIDERS = { /** * Native model prefixes — derived from BUILTIN_PROVIDERS. */ -export const NATIVE_MODEL_PATTERNS = _getNativeModelPatterns(); +export const NATIVE_MODEL_PATTERNS = _NATIVE_MODEL_PATTERNS; /** * Legacy prefix patterns — derived from BUILTIN_PROVIDERS. */ -export const LEGACY_PREFIX_PATTERNS = _getLegacyPrefixPatterns(); +export const LEGACY_PREFIX_PATTERNS = _LEGACY_PREFIX_PATTERNS; /** * Parse a model specification string diff --git a/packages/cli/src/providers/provider-definitions.test.ts b/packages/cli/src/providers/provider-definitions.test.ts index e0763d8a..ab635c7f 100644 --- a/packages/cli/src/providers/provider-definitions.test.ts +++ b/packages/cli/src/providers/provider-definitions.test.ts @@ -7,9 +7,9 @@ import { describe, test, expect } from "bun:test"; import { BUILTIN_PROVIDERS, - getShortcuts, - getLegacyPrefixPatterns, - getNativeModelPatterns, + PROVIDER_SHORTCUTS, + LEGACY_PREFIX_PATTERNS, + NATIVE_MODEL_PATTERNS, getProviderByName, getApiKeyInfo, getDisplayName, @@ -18,7 +18,6 @@ import { isDirectApiProvider, toRemoteProvider, getAllProviders, - getShortestPrefix, getApiKeyEnvVars, isProviderAvailable, type ProviderDefinition, @@ -91,11 +90,11 @@ describe("BUILTIN_PROVIDERS structural integrity", () => { }); // --------------------------------------------------------------------------- -// getShortcuts +// PROVIDER_SHORTCUTS // --------------------------------------------------------------------------- -describe("getShortcuts", () => { - const shortcuts = getShortcuts(); +describe("PROVIDER_SHORTCUTS", () => { + const shortcuts = PROVIDER_SHORTCUTS; test("maps 'g' to 'google'", () => { expect(shortcuts["g"]).toBe("google"); @@ -143,11 +142,11 @@ describe("getShortcuts", () => { }); // --------------------------------------------------------------------------- -// getLegacyPrefixPatterns +// LEGACY_PREFIX_PATTERNS // --------------------------------------------------------------------------- -describe("getLegacyPrefixPatterns", () => { - const patterns = getLegacyPrefixPatterns(); +describe("LEGACY_PREFIX_PATTERNS", () => { + const patterns = LEGACY_PREFIX_PATTERNS; test("includes 'g/' for google", () => { const gPattern = patterns.find((p) => p.prefix === "g/"); @@ -172,11 +171,11 @@ describe("getLegacyPrefixPatterns", () => { }); // --------------------------------------------------------------------------- -// getNativeModelPatterns +// NATIVE_MODEL_PATTERNS // --------------------------------------------------------------------------- -describe("getNativeModelPatterns", () => { - const patterns = getNativeModelPatterns(); +describe("NATIVE_MODEL_PATTERNS", () => { + const patterns = NATIVE_MODEL_PATTERNS; test("gemini-* matches google", () => { const match = patterns.find((p) => p.pattern.test("gemini-2.0-flash")); @@ -368,21 +367,9 @@ describe("toRemoteProvider", () => { }); // --------------------------------------------------------------------------- -// getShortestPrefix / getApiKeyEnvVars +// getApiKeyEnvVars // --------------------------------------------------------------------------- -describe("getShortestPrefix", () => { - test("returns shortest prefix for known providers", () => { - expect(getShortestPrefix("google")).toBe("g"); - expect(getShortestPrefix("minimax")).toBe("mm"); - expect(getShortestPrefix("openrouter")).toBe("or"); - }); - - test("falls back to provider name for unknown", () => { - expect(getShortestPrefix("unknown")).toBe("unknown"); - }); -}); - describe("getApiKeyEnvVars", () => { test("returns env var info for known providers", () => { const info = getApiKeyEnvVars("google"); diff --git a/packages/cli/src/providers/provider-definitions.ts b/packages/cli/src/providers/provider-definitions.ts index 68317685..364342ff 100644 --- a/packages/cli/src/providers/provider-definitions.ts +++ b/packages/cli/src/providers/provider-definitions.ts @@ -10,9 +10,9 @@ */ import type { RemoteProvider } from "../handlers/shared/remote-provider-types.js"; -import { existsSync } from "node:fs"; -import { join } from "node:path"; +import { readFileSync, existsSync } from "node:fs"; import { homedir } from "node:os"; +import { join } from "node:path"; // --------------------------------------------------------------------------- // Types @@ -83,6 +83,9 @@ export interface ProviderDefinition { oauthFallback?: string; /** Whether this is a local provider (no API key needed) */ isLocal?: boolean; + /** Whether model names for this provider may contain vendor prefixes (e.g., "google/gemini-2.0-flash") + * that should be preserved during catalog resolution. Used by aggregators like OpenRouter and LiteLLM. */ + preserveVendorPrefix?: boolean; /** Whether this provider supports direct API access (not just via OpenRouter) */ isDirectApi?: boolean; /** Shortest @ prefix for handler creation (reverse of shortcuts) */ @@ -161,6 +164,23 @@ export const BUILTIN_PROVIDERS: ProviderDefinition[] = [ description: "Direct OpenAI API (oai@)", }, + // ── GitHub Models (OpenAI-compatible) ──────────────────────────────── + { + name: "github-models", + displayName: "GitHub Models", + transport: "openai", + baseUrl: "https://models.github.ai/inference", + apiPath: "/chat/completions", + apiKeyEnvVar: "GITHUB_MODELS_TOKEN", + apiKeyDescription: "GitHub Fine-Grained PAT", + apiKeyUrl: "https://github.com/settings/tokens", + shortcuts: ["gh"], + shortestPrefix: "gh", + legacyPrefixes: [{ prefix: "gh/", stripPrefix: true }], + isDirectApi: true, + description: "GitHub Models (gh@)", + }, + // ── OpenRouter ───────────────────────────────────────────────────── { name: "openrouter", @@ -180,6 +200,7 @@ export const BUILTIN_PROVIDERS: ProviderDefinition[] = [ "X-Title": "Claudish - OpenRouter Proxy", }, isDirectApi: true, + preserveVendorPrefix: true, description: "580+ models, default backend (or@)", }, @@ -408,6 +429,21 @@ export const BUILTIN_PROVIDERS: ProviderDefinition[] = [ description: "OpenCode Zen (zen@) - free models", }, + // ── OpenCode Zen MiniMax (internal variant, swapped by resolveEffective) ── + { + name: "opencode-zen-minimax", + displayName: "OpenCode Zen (MiniMax)", + shortcuts: [], + legacyPrefixes: [], + baseUrl: process.env.OPENCODE_BASE_URL || "https://opencode.ai/zen", + apiPath: "/v1/messages", + apiKeyEnvVar: "OPENCODE_API_KEY", + apiKeyDescription: "OpenCode API Key", + apiKeyUrl: "https://opencode.ai", + transport: "anthropic", + publicKeyFallback: "public", + }, + // ── OpenCode Zen Go (lite plan) ──────────────────────────────────── { name: "opencode-zen-go", @@ -430,6 +466,21 @@ export const BUILTIN_PROVIDERS: ProviderDefinition[] = [ description: "OpenCode Zen Go plan (zengo@)", }, + // ── OpenCode Zen Go MiniMax (internal variant, swapped by resolveEffective) ── + { + name: "opencode-zen-go-minimax", + displayName: "OpenCode Zen Go (MiniMax)", + shortcuts: [], + legacyPrefixes: [], + baseUrl: process.env.OPENCODE_BASE_URL ? process.env.OPENCODE_BASE_URL.replace("/zen", "/zen/go") : "https://opencode.ai/zen/go", + apiPath: "/v1/messages", + apiKeyEnvVar: "OPENCODE_API_KEY", + apiKeyDescription: "OpenCode API Key", + apiKeyUrl: "https://opencode.ai", + transport: "anthropic", + publicKeyFallback: "public", + }, + // ── Vertex AI ────────────────────────────────────────────────────── { name: "vertex", @@ -469,6 +520,7 @@ export const BUILTIN_PROVIDERS: ProviderDefinition[] = [ { prefix: "ll/", stripPrefix: true }, ], isDirectApi: true, + preserveVendorPrefix: true, description: "LiteLLM proxy (ll@, litellm@)", }, @@ -608,100 +660,122 @@ export const BUILTIN_PROVIDERS: ProviderDefinition[] = [ ]; // --------------------------------------------------------------------------- -// Lazy-cached derived accessors +// Pre-computed derived maps (built once at module load) // --------------------------------------------------------------------------- -let _shortcutsCache: Record | null = null; -let _legacyPrefixCache: Array<{ - prefix: string; - provider: string; - stripPrefix: boolean; -}> | null = null; -let _nativeModelPatternsCache: Array<{ pattern: RegExp; provider: string }> | null = null; -let _providerByNameCache: Map | null = null; -let _directApiProvidersCache: Set | null = null; -let _localProvidersCache: Set | null = null; - -function ensureProviderByNameCache(): Map { - if (!_providerByNameCache) { - _providerByNameCache = new Map(); - for (const def of BUILTIN_PROVIDERS) { - _providerByNameCache.set(def.name, def); - } - } - return _providerByNameCache; -} +/** Provider name -> definition lookup. */ +export const PROVIDERS_BY_NAME: Map = (() => { + const m = new Map(); + for (const def of BUILTIN_PROVIDERS) m.set(def.name, def); + return m; +})(); -/** - * Get the shortcuts → canonical provider name mapping. - * Replaces PROVIDER_SHORTCUTS in model-parser.ts. - */ -export function getShortcuts(): Record { - if (!_shortcutsCache) { - _shortcutsCache = {}; - for (const def of BUILTIN_PROVIDERS) { - for (const shortcut of def.shortcuts) { - _shortcutsCache[shortcut] = def.name; - } - } +/** Shortcut -> canonical provider name. */ +export const PROVIDER_SHORTCUTS: Record = (() => { + const r: Record = {}; + for (const def of BUILTIN_PROVIDERS) { + for (const s of def.shortcuts) r[s] = def.name; } - return _shortcutsCache; -} + return r; +})(); -/** - * Get legacy prefix patterns for backwards compatibility. - * Replaces LEGACY_PREFIX_PATTERNS in model-parser.ts. - */ -export function getLegacyPrefixPatterns(): Array<{ +/** Legacy prefix patterns for backwards compatibility. */ +export const LEGACY_PREFIX_PATTERNS: Array<{ prefix: string; provider: string; stripPrefix: boolean; -}> { - if (!_legacyPrefixCache) { - _legacyPrefixCache = []; - for (const def of BUILTIN_PROVIDERS) { - for (const lp of def.legacyPrefixes) { - _legacyPrefixCache.push({ - prefix: lp.prefix, - provider: def.name, - stripPrefix: lp.stripPrefix, - }); - } +}> = (() => { + const a: Array<{ prefix: string; provider: string; stripPrefix: boolean }> = []; + for (const def of BUILTIN_PROVIDERS) { + for (const lp of def.legacyPrefixes) { + a.push({ prefix: lp.prefix, provider: def.name, stripPrefix: lp.stripPrefix }); } } - return _legacyPrefixCache; -} + return a; +})(); /** - * Get native model patterns for auto-detection. - * Replaces NATIVE_MODEL_PATTERNS in model-parser.ts. - * - * Order follows the definition order in BUILTIN_PROVIDERS. + * Native model patterns for auto-detection. + * Order follows BUILTIN_PROVIDERS definition order. * kimi-coding's pattern (kimi-for-coding$) comes before kimi's (kimi-*) because * kimi-coding is defined earlier in BUILTIN_PROVIDERS. */ -export function getNativeModelPatterns(): Array<{ pattern: RegExp; provider: string }> { - if (!_nativeModelPatternsCache) { - _nativeModelPatternsCache = []; - for (const def of BUILTIN_PROVIDERS) { - if (def.nativeModelPatterns) { - for (const np of def.nativeModelPatterns) { - _nativeModelPatternsCache.push({ - pattern: np.pattern, - provider: def.name, - }); - } +export const NATIVE_MODEL_PATTERNS: Array<{ pattern: RegExp; provider: string }> = (() => { + const a: Array<{ pattern: RegExp; provider: string }> = []; + for (const def of BUILTIN_PROVIDERS) { + if (def.nativeModelPatterns) { + for (const np of def.nativeModelPatterns) { + a.push({ pattern: np.pattern, provider: def.name }); } } } - return _nativeModelPatternsCache; + return a; +})(); + +/** API key env var + metadata per provider. */ +export const API_KEY_INFO: Record = (() => { + const r: Record = {}; + for (const def of BUILTIN_PROVIDERS) { + r[def.name] = { + envVar: def.apiKeyEnvVar, + description: def.apiKeyDescription, + url: def.apiKeyUrl, + aliases: def.apiKeyAliases, + oauthFallback: def.oauthFallback, + }; + } + return r; +})(); + +/** Display name per provider. */ +export const PROVIDER_DISPLAY_NAMES: Record = (() => { + const r: Record = {}; + for (const def of BUILTIN_PROVIDERS) r[def.name] = def.displayName; + return r; +})(); + +/** Set of local provider names. */ +const LOCAL_PROVIDERS: Set = (() => { + const s = new Set(); + for (const def of BUILTIN_PROVIDERS) { + if (def.isLocal) s.add(def.name); + } + return s; +})(); + +/** Set of direct API provider names. */ +const DIRECT_API_PROVIDERS: Set = (() => { + const s = new Set(); + for (const def of BUILTIN_PROVIDERS) { + if (def.isDirectApi) s.add(def.name); + } + return s; +})(); + +// --------------------------------------------------------------------------- +// Getter functions +// --------------------------------------------------------------------------- + +/** Get a provider definition by canonical name. */ +export function getProviderByName(name: string): ProviderDefinition | undefined { + return PROVIDERS_BY_NAME.get(name); } /** - * Get a provider definition by canonical name. + * Get a provider definition by RemoteProvider name (handles google/gemini reverse mapping). + * RemoteProvider.name uses "gemini" for the google provider; this looks up the original definition. */ -export function getProviderByName(name: string): ProviderDefinition | undefined { - return ensureProviderByNameCache().get(name); +export function getProviderDefinitionByRemoteName(remoteName: string): ProviderDefinition | undefined { + const direct = PROVIDERS_BY_NAME.get(remoteName); + if (direct) return direct; + if (remoteName === "gemini") return PROVIDERS_BY_NAME.get("google"); + return undefined; } /** @@ -753,15 +827,7 @@ export function getEffectiveBaseUrl(def: ProviderDefinition): string { * Replaces LOCAL_PROVIDERS set in model-parser.ts. */ export function isLocalTransport(providerName: string): boolean { - if (!_localProvidersCache) { - _localProvidersCache = new Set(); - for (const def of BUILTIN_PROVIDERS) { - if (def.isLocal) { - _localProvidersCache.add(def.name); - } - } - } - return _localProvidersCache.has(providerName.toLowerCase()); + return LOCAL_PROVIDERS.has(providerName.toLowerCase()); } /** @@ -769,15 +835,7 @@ export function isLocalTransport(providerName: string): boolean { * Replaces DIRECT_API_PROVIDERS set in model-parser.ts. */ export function isDirectApiProvider(providerName: string): boolean { - if (!_directApiProvidersCache) { - _directApiProvidersCache = new Set(); - for (const def of BUILTIN_PROVIDERS) { - if (def.isDirectApi) { - _directApiProvidersCache.add(def.name); - } - } - } - return _directApiProvidersCache.has(providerName.toLowerCase()); + return DIRECT_API_PROVIDERS.has(providerName.toLowerCase()); } /** @@ -813,15 +871,6 @@ export function getAllProviders(): ProviderDefinition[] { return BUILTIN_PROVIDERS; } -/** - * Get the shortest prefix for a provider (for @ syntax handler creation). - * Replaces PROVIDER_TO_PREFIX in auto-route.ts. - */ -export function getShortestPrefix(providerName: string): string { - const def = getProviderByName(providerName); - return def?.shortestPrefix || providerName; -} - /** * Get API key env var info for a provider (for auto-route). * Replaces API_KEY_ENV_VARS in auto-route.ts. @@ -889,3 +938,124 @@ export function isProviderAvailableByName(providerName: string): boolean { if (!def) return false; return isProviderAvailable(def); } + +// --------------------------------------------------------------------------- +// Dynamic provider registration +// --------------------------------------------------------------------------- + +/** + * Register a provider definition into all derived maps. + * Used for user-defined providers loaded from config. + */ +export function registerProvider(def: ProviderDefinition): void { + BUILTIN_PROVIDERS.push(def); + PROVIDERS_BY_NAME.set(def.name, def); + + for (const s of def.shortcuts) { + PROVIDER_SHORTCUTS[s] = def.name; + } + + for (const lp of def.legacyPrefixes) { + LEGACY_PREFIX_PATTERNS.push({ + prefix: lp.prefix, + provider: def.name, + stripPrefix: lp.stripPrefix, + }); + } + + if (def.nativeModelPatterns) { + for (const np of def.nativeModelPatterns) { + NATIVE_MODEL_PATTERNS.push({ pattern: np.pattern, provider: def.name }); + } + } + + API_KEY_INFO[def.name] = { + envVar: def.apiKeyEnvVar, + description: def.apiKeyDescription, + url: def.apiKeyUrl, + aliases: def.apiKeyAliases, + oauthFallback: def.oauthFallback, + }; + + PROVIDER_DISPLAY_NAMES[def.name] = def.displayName; + + if (def.isLocal) { + LOCAL_PROVIDERS.add(def.name); + } + + if (def.isDirectApi) { + DIRECT_API_PROVIDERS.add(def.name); + } +} + +// --------------------------------------------------------------------------- +// User-defined providers (loaded from ~/.claudish/config.json) +// --------------------------------------------------------------------------- + +/** Shape of a user-defined provider in config.json */ +export interface UserProviderConfig { + name: string; + displayName?: string; + transport: TransportType; + baseUrl: string; + apiPath?: string; + apiKeyEnvVar?: string; + apiKeyDescription?: string; + apiKeyUrl?: string; + shortcuts?: string[]; + shortestPrefix?: string; + legacyPrefixes?: Array<{ prefix: string; stripPrefix: boolean }>; + nativeModelPatterns?: string[]; + isDirectApi?: boolean; + isLocal?: boolean; + description?: string; + headers?: Record; +} + +/** + * Load user-defined providers from ~/.claudish/config.json and register them. + */ +export function loadUserProviders(): void { + const configPath = join(homedir(), ".claudish", "config.json"); + if (!existsSync(configPath)) return; + + try { + const raw = readFileSync(configPath, "utf-8"); + const config = JSON.parse(raw); + const providers: UserProviderConfig[] = config.providers; + if (!Array.isArray(providers)) return; + + for (const p of providers) { + if (!p.name || !p.transport || !p.baseUrl) continue; + if (PROVIDERS_BY_NAME.has(p.name)) continue; // do not override builtins + + const def: ProviderDefinition = { + name: p.name, + displayName: p.displayName || p.name, + transport: p.transport, + baseUrl: p.baseUrl, + apiPath: p.apiPath || "/v1/chat/completions", + apiKeyEnvVar: p.apiKeyEnvVar || "", + apiKeyDescription: p.apiKeyDescription || `${p.displayName || p.name} API Key`, + apiKeyUrl: p.apiKeyUrl || "", + shortcuts: p.shortcuts || [], + shortestPrefix: p.shortestPrefix || p.name, + legacyPrefixes: p.legacyPrefixes || [], + nativeModelPatterns: p.nativeModelPatterns + ? p.nativeModelPatterns.map((pat) => ({ pattern: new RegExp(pat, "i") })) + : undefined, + isDirectApi: p.isDirectApi, + isLocal: p.isLocal, + description: p.description, + headers: p.headers, + }; + + registerProvider(def); + } + } catch { + // Config parse error; silently skip user providers + } +} + +// Load user providers at module init +loadUserProviders(); diff --git a/packages/cli/src/providers/provider-profiles.ts b/packages/cli/src/providers/provider-profiles.ts index dfe2562e..6a9bb662 100644 --- a/packages/cli/src/providers/provider-profiles.ts +++ b/packages/cli/src/providers/provider-profiles.ts @@ -1,31 +1,27 @@ /** - * ProviderProfile — declares how to construct a ComposedHandler for a specific remote provider. + * createHandlerForProvider — construct a ComposedHandler for a ProviderDefinition. * - * Maps provider name → transport class + adapter class + handler options. - * Replaces the 250-line if/else chain in proxy-server.ts with a data-driven table. - * - * Design rules: - * - Exact behaviour match — every profile must produce the same transport+adapter+options as the - * original if/else branch. No behaviour changes. - * - Special cases (opencode-zen, vertex) keep their branching logic inside the profile's factory - * methods rather than cluttering the lookup code. - * - Resolution (looking up the profile and calling createHandlerForProvider) happens in - * proxy-server.ts. Profiles do not know about caching or invocationMode. + * Single function, flat switch on def.transport. All transport + format + handler + * construction in one place. No profile objects, no lookup tables, no indirection. */ import type { ComposedHandlerOptions } from "../handlers/composed-handler.js"; -import type { RemoteProvider } from "../handlers/shared/remote-provider-types.js"; import type { ProviderTransport } from "./transport/types.js"; import type { BaseAPIFormat } from "../adapters/base-api-format.js"; -// Alias for readability within this file -type BaseModelAdapter = BaseAPIFormat; +import type { BaseModelDialect } from "../adapters/base-model-dialect.js"; +import type { ProviderDefinition } from "./provider-definitions.js"; +import type { ModelHandler } from "../handlers/types.js"; import { ComposedHandler } from "../handlers/composed-handler.js"; +import { GeminiThoughtSignatureMiddleware } from "../middleware/gemini-thought-signature.js"; +import type { ModelMiddleware } from "../middleware/types.js"; +import { toRemoteProvider } from "./provider-definitions.js"; import { GeminiProviderTransport } from "./transport/gemini-apikey.js"; import { GeminiCodeAssistProviderTransport } from "./transport/gemini-codeassist.js"; import { GeminiAPIFormat } from "../adapters/gemini-api-format.js"; import { OpenAIProviderTransport } from "./transport/openai.js"; import { OpenAIAPIFormat } from "../adapters/openai-api-format.js"; import { AnthropicProviderTransport } from "./transport/anthropic-compat.js"; +import { KimiCodingTransport } from "./transport/kimi-coding.js"; import { AnthropicAPIFormat } from "../adapters/anthropic-api-format.js"; import { OllamaProviderTransport } from "./transport/ollamacloud.js"; import { OllamaAPIFormat } from "../adapters/ollama-api-format.js"; @@ -33,343 +29,402 @@ import { LiteLLMProviderTransport } from "./transport/litellm.js"; import { LiteLLMAPIFormat } from "../adapters/litellm-api-format.js"; import { CodexAPIFormat } from "../adapters/codex-api-format.js"; import { VertexProviderTransport, parseVertexModel } from "./transport/vertex-oauth.js"; -import { DefaultAPIFormat } from "../adapters/base-api-format.js"; -import { OpenRouterProvider } from "./transport/openrouter.js"; -import { getRegisteredRemoteProviders } from "./remote-provider-registry.js"; +import { DefaultAPIFormat, matchesModelFamily } from "../adapters/base-api-format.js"; +import { GrokModelDialect } from "../adapters/grok-model-dialect.js"; +import { GeminiModelDialect } from "../adapters/gemini-model-dialect.js"; +import { QwenModelDialect } from "../adapters/qwen-model-dialect.js"; +import { MiniMaxModelDialect } from "../adapters/minimax-model-dialect.js"; +import { DeepSeekModelDialect } from "../adapters/deepseek-model-dialect.js"; +import { GLMModelDialect } from "../adapters/glm-model-dialect.js"; +import { XiaomiModelDialect } from "../adapters/xiaomi-model-dialect.js"; +import { OpenRouterProviderTransport } from "./transport/openrouter.js"; +import { OpenRouterAPIFormat } from "../adapters/openrouter-api-format.js"; +import { ZenTransport } from "./transport/zen.js"; +import { PoeProvider } from "./transport/poe.js"; +import { LocalTransport } from "./transport/local.js"; +import { OllamaTransport } from "./transport/ollama.js"; +import { LMStudioTransport } from "./transport/lmstudio.js"; +import { LocalModelAdapter } from "../adapters/local-adapter.js"; +import { LocalQwenFormatAdapter } from "../adapters/local-qwen-format-adapter.js"; +import { LocalDeepSeekFormatAdapter } from "../adapters/local-deepseek-format-adapter.js"; +import { LocalLlamaFormatAdapter } from "../adapters/local-llama-format-adapter.js"; +import { LocalMistralFormatAdapter } from "../adapters/local-mistral-format-adapter.js"; +import { + resolveProvider, + parseUrlModel, + createUrlProvider, + type LocalProvider as LocalProviderConfig, +} from "./provider-registry.js"; +import { getProviderByName } from "./provider-definitions.js"; import { getVertexConfig, validateVertexOAuthConfig } from "../auth/vertex-auth.js"; import { log, logStderr } from "../logger.js"; import { resolveApiKeyProvenance, formatProvenanceLog } from "./api-key-provenance.js"; -import type { ModelHandler } from "../handlers/types.js"; // --------------------------------------------------------------------------- -// Types +// Shared helpers // --------------------------------------------------------------------------- -/** - * Context passed to profile factory methods at handler-creation time. - * All values come from the already-resolved provider and the outer createProxyServer closure. - */ -export interface ProfileContext { - /** The resolved RemoteProvider config (baseUrl, headers, authScheme, etc.) */ - provider: RemoteProvider; - /** The model name after stripping the provider prefix (e.g. "gemini-2.5-flash") */ - modelName: string; - /** The API key resolved from env (empty string for auth-less providers) */ - apiKey: string; - /** The original targetModel string passed by the caller */ - targetModel: string; - /** The listening port of the proxy server */ - port: number; - /** Shared ComposedHandler options from the outer scope */ - sharedOpts: Pick; -} - -/** - * ProviderProfile — describes how to construct a ModelHandler for a provider. - * - * The simplest profiles just implement createHandler() and log a message. - * Complex ones (opencode-zen, vertex) may contain branching logic internally. - */ -export interface ProviderProfile { - /** - * Attempt to create a ModelHandler for this provider. - * - * Returns null if the provider config is invalid (e.g. missing LITELLM_BASE_URL). - * Returning null causes proxy-server.ts to skip caching and fall through. - */ - createHandler(ctx: ProfileContext): ModelHandler | null; +function isZenProvider(def: ProviderDefinition): boolean { + return def.name === "opencode-zen" || def.name === "opencode-zen-go"; } // --------------------------------------------------------------------------- -// Profile implementations +// Transport resolution // --------------------------------------------------------------------------- -const geminiProfile: ProviderProfile = { - createHandler(ctx) { - const transport = new GeminiProviderTransport(ctx.provider, ctx.modelName, ctx.apiKey); - const adapter = new GeminiAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - ...ctx.sharedOpts, - }); - log(`[Proxy] Created Gemini handler (composed): ${ctx.modelName}`); - return handler; - }, -}; - -const geminiCodeAssistProfile: ProviderProfile = { - createHandler(ctx) { - const transport = new GeminiCodeAssistProviderTransport(ctx.modelName); - const adapter = new GeminiAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - unwrapGeminiResponse: true, - ...ctx.sharedOpts, - }); - log(`[Proxy] Created Gemini Code Assist handler (composed): ${ctx.modelName}`); - return handler; - }, -}; - -const openaiProfile: ProviderProfile = { - createHandler(ctx) { - const transport = new OpenAIProviderTransport(ctx.provider, ctx.modelName, ctx.apiKey); - const adapter = new OpenAIAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - tokenStrategy: "delta-aware", - ...ctx.sharedOpts, - }); - log(`[Proxy] Created OpenAI handler (composed): ${ctx.modelName}`); - return handler; - }, -}; - -/** Shared profile for MiniMax, Kimi, Kimi Coding, and Z.AI (all Anthropic-compatible APIs) */ -const anthropicCompatProfile: ProviderProfile = { - createHandler(ctx) { - const transport = new AnthropicProviderTransport(ctx.provider, ctx.apiKey); - const adapter = new AnthropicAPIFormat(ctx.modelName, ctx.provider.name); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - ...ctx.sharedOpts, - }); - log(`[Proxy] Created ${ctx.provider.name} handler (composed): ${ctx.modelName}`); - return handler; - }, -}; - -/** GLM and GLM Coding Plan use the OpenAI-compatible API */ -const glmProfile: ProviderProfile = { - createHandler(ctx) { - const transport = new OpenAIProviderTransport(ctx.provider, ctx.modelName, ctx.apiKey); - const adapter = new OpenAIAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - tokenStrategy: "delta-aware", - ...ctx.sharedOpts, - }); - log(`[Proxy] Created ${ctx.provider.name} handler (composed): ${ctx.modelName}`); - return handler; - }, -}; +export function resolveTransport( + def: ProviderDefinition, + modelName: string, + apiKey: string, +): ProviderTransport | null { + const rp = toRemoteProvider(def); -/** - * OpenCode Zen / Zen Go — two tiers: - * zen/ (opencode-zen): free anonymous models + full paid access (OPENCODE_API_KEY) - * zgo/ (opencode-zen-go): go-plan models (glm-5, minimax-m2.5, kimi-k2.5) via zen/go/v1/ - * - * Free anonymous models work without a key; uses "public" as fallback for consistent - * rate-limit bucketing. - * - * Model routing inside the profile: - * - MiniMax models → AnthropicProviderTransport + AnthropicAPIFormat - * - GPT-* models → OpenAIProviderTransport (/v1/responses) + CodexAPIFormat (Responses API) - * - All other models → OpenAIProviderTransport (/v1/chat/completions) + OpenAIAPIFormat (delta-aware) - */ -const openCodeZenProfile: ProviderProfile = { - createHandler(ctx) { - const zenApiKey = ctx.apiKey || "public"; - const isGoProvider = ctx.provider.name === "opencode-zen-go"; - - if (ctx.modelName.toLowerCase().includes("minimax")) { - const transport = new AnthropicProviderTransport(ctx.provider, zenApiKey); - const adapter = new AnthropicAPIFormat(ctx.modelName, ctx.provider.name); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - ...ctx.sharedOpts, - }); - log( - `[Proxy] Created OpenCode Zen${isGoProvider ? " Go" : ""} (Anthropic composed): ${ctx.modelName}` - ); - return handler; - } + switch (def.transport) { + case "gemini": + return new GeminiProviderTransport(rp, modelName, apiKey); - // GPT models are served via the OpenAI Responses API (/v1/responses), not /v1/chat/completions. - if (ctx.modelName.toLowerCase().startsWith("gpt-")) { - const responsesProvider = { ...ctx.provider, apiPath: "/v1/responses" }; - const transport = new OpenAIProviderTransport(responsesProvider, ctx.modelName, zenApiKey); - const adapter = new CodexAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - tokenStrategy: "delta-aware", - ...ctx.sharedOpts, - }); - log( - `[Proxy] Created OpenCode Zen${isGoProvider ? " Go" : ""} (Responses API composed): ${ctx.modelName}` - ); - return handler; - } + case "gemini-oauth": + return new GeminiCodeAssistProviderTransport(modelName); - const transport = new OpenAIProviderTransport(ctx.provider, ctx.modelName, zenApiKey); - const adapter = new OpenAIAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - tokenStrategy: "delta-aware", - ...ctx.sharedOpts, - }); - log(`[Proxy] Created OpenCode Zen${isGoProvider ? " Go" : ""} (composed): ${ctx.modelName}`); - return handler; - }, -}; - -const ollamaCloudProfile: ProviderProfile = { - createHandler(ctx) { - const transport = new OllamaProviderTransport(ctx.provider, ctx.apiKey); - const adapter = new OllamaAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - tokenStrategy: "accumulate-both", - ...ctx.sharedOpts, - }); - log(`[Proxy] Created OllamaCloud handler (composed): ${ctx.modelName}`); - return handler; - }, -}; - -const litellmProfile: ProviderProfile = { - createHandler(ctx) { - if (!ctx.provider.baseUrl) { - logStderr("Error: LITELLM_BASE_URL or --litellm-url is required for LiteLLM provider."); - logStderr("Set it with: export LITELLM_BASE_URL='https://your-litellm-instance.com'"); - logStderr( - "Or use: claudish --litellm-url https://your-instance.com --model litellm@model 'task'" - ); - return null; + case "openai": { + if (isZenProvider(def)) return new ZenTransport(rp, modelName, apiKey); + return new OpenAIProviderTransport(rp, modelName, apiKey); } - const transport = new LiteLLMProviderTransport(ctx.provider.baseUrl, ctx.apiKey, ctx.modelName); - const adapter = new LiteLLMAPIFormat(ctx.modelName, ctx.provider.baseUrl); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - ...ctx.sharedOpts, - }); - log(`[Proxy] Created LiteLLM handler (composed): ${ctx.modelName} (${ctx.provider.baseUrl})`); - return handler; - }, -}; -/** - * Vertex AI — supports two modes: - * 1. Express Mode (VERTEX_API_KEY) — uses the Gemini API endpoint with a Vertex key. - * Uses GeminiProviderTransport (with the gemini provider config) + GeminiAPIFormat. - * 2. OAuth Mode (VERTEX_PROJECT) — full project-based access with OAuth tokens. - * Uses VertexProviderTransport + publisher-specific format (Gemini/Anthropic/Default). - * - * Returns null if neither key nor project config is available. - */ -const vertexProfile: ProviderProfile = { - createHandler(ctx) { - const hasApiKey = !!process.env.VERTEX_API_KEY; - const vertexConfig = getVertexConfig(); - - if (hasApiKey) { - // Express Mode — Vertex Express uses the standard Gemini API endpoint - // but with VERTEX_API_KEY instead of GEMINI_API_KEY. - // Must use the Gemini provider config (which has the correct baseUrl/apiPath) - // because the vertex provider config has empty baseUrl/apiPath (designed for OAuth mode). - const geminiConfig = getRegisteredRemoteProviders().find((p) => p.name === "gemini"); - const expressProvider = geminiConfig || ctx.provider; - const transport = new GeminiProviderTransport( - expressProvider, - ctx.modelName, - process.env.VERTEX_API_KEY! - ); - const adapter = new GeminiAPIFormat(ctx.modelName); - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - ...ctx.sharedOpts, - }); - log(`[Proxy] Created Vertex AI Express handler (composed): ${ctx.modelName}`); - return handler; - } + case "anthropic": + return new AnthropicProviderTransport(rp, apiKey); + + case "kimi-coding": + return new KimiCodingTransport(rp, apiKey); + + case "openrouter": + return new OpenRouterProviderTransport(apiKey, modelName); + + case "ollamacloud": + return new OllamaProviderTransport(rp, apiKey); + + case "litellm": + if (!rp.baseUrl) { + logStderr("Error: LITELLM_BASE_URL or --litellm-url is required for LiteLLM provider."); + return null; + } + return new LiteLLMProviderTransport(rp.baseUrl, apiKey, modelName); - if (vertexConfig) { - // OAuth Mode — ComposedHandler with publisher-specific adapter + case "vertex": { + if (process.env.VERTEX_API_KEY) { + // Express mode: use Gemini transport with VERTEX_API_KEY + const geminiDef = getProviderByName("google"); + const geminiConfig = geminiDef ? toRemoteProvider(geminiDef) : rp; + return new GeminiProviderTransport(geminiConfig, modelName, process.env.VERTEX_API_KEY); + } + const vertexConfig = getVertexConfig(); + if (!vertexConfig) { + log("[Proxy] Vertex AI requires either VERTEX_API_KEY or VERTEX_PROJECT"); + return null; + } const oauthError = validateVertexOAuthConfig(); if (oauthError) { log(`[Proxy] Vertex OAuth config error: ${oauthError}`); return null; } - const parsed = parseVertexModel(ctx.modelName); - const transport = new VertexProviderTransport(vertexConfig, parsed); - - let adapter: BaseModelAdapter; - if (parsed.publisher === "google") { - adapter = new GeminiAPIFormat(ctx.modelName); - } else if (parsed.publisher === "anthropic") { - adapter = new AnthropicAPIFormat(parsed.model, "vertex"); - } else { - // Mistral/Meta use OpenAI format; Mistral rawPredict uses bare model name - const modelId = - parsed.publisher === "mistralai" ? parsed.model : `${parsed.publisher}/${parsed.model}`; - adapter = new DefaultAPIFormat(modelId); + const parsed = parseVertexModel(modelName); + return new VertexProviderTransport(vertexConfig, parsed); + } + + case "poe": { + const poeApiKey = process.env.POE_API_KEY; + if (!poeApiKey) { + log(`[Proxy] POE_API_KEY not set, cannot use Poe model: ${modelName}`); + return null; } + return new PoeProvider(poeApiKey); + } + + case "local": + // Handled separately in createHandlerForProvider (needs LocalProvider config) + return null; + + default: + return null; + } +} + +// --------------------------------------------------------------------------- +// Format adapter resolution +// --------------------------------------------------------------------------- - const handler = new ComposedHandler(transport, ctx.targetModel, ctx.modelName, ctx.port, { - adapter, - ...ctx.sharedOpts, - }); - log( - `[Proxy] Created Vertex AI OAuth handler (composed): ${ctx.modelName} [${parsed.publisher}] (project: ${vertexConfig.projectId})` - ); - return handler; +export function resolveAPIFormat( + def: ProviderDefinition, + modelName: string, +): BaseAPIFormat | null { + switch (def.transport) { + case "gemini": + case "gemini-oauth": + return new GeminiAPIFormat(modelName); + + case "openai": { + if (isZenProvider(def) && modelName.toLowerCase().startsWith("gpt-")) + return new CodexAPIFormat(modelName); + return new OpenAIAPIFormat(modelName); } - log(`[Proxy] Vertex AI requires either VERTEX_API_KEY or VERTEX_PROJECT`); - return null; - }, -}; + case "anthropic": + return new AnthropicAPIFormat(modelName, def.name); + + case "kimi-coding": + return new AnthropicAPIFormat(modelName, def.name); + + case "ollamacloud": + return new OllamaAPIFormat(modelName); + + case "litellm": { + const rp = toRemoteProvider(def); + return new LiteLLMAPIFormat(modelName, rp.baseUrl); + } + + case "vertex": { + if (process.env.VERTEX_API_KEY) { + // Express mode uses Gemini format + return new GeminiAPIFormat(modelName); + } + const parsed = parseVertexModel(modelName); + if (parsed.publisher === "google") return new GeminiAPIFormat(modelName); + if (parsed.publisher === "anthropic") return new AnthropicAPIFormat(parsed.model, "vertex"); + const id = parsed.publisher === "mistralai" + ? parsed.model + : `${parsed.publisher}/${parsed.model}`; + return new DefaultAPIFormat(id); + } + + case "openrouter": + return new OpenRouterAPIFormat(modelName); + + case "poe": + return new OpenAIAPIFormat(modelName); + + case "local": + return null; // Handled separately in createHandlerForProvider + + default: + return null; + } +} // --------------------------------------------------------------------------- -// Profile table +// Middleware resolution +// --------------------------------------------------------------------------- + +function resolveMiddlewares(modelId: string): ModelMiddleware[] { + const mw: ModelMiddleware[] = []; + if (matchesModelFamily(modelId, "gemini") || modelId.toLowerCase().includes("google/")) { + mw.push(new GeminiThoughtSignatureMiddleware()); + } + return mw; +} + +// --------------------------------------------------------------------------- +// Model dialect resolution +// --------------------------------------------------------------------------- + +export function resolveModelDialect(modelId: string): BaseModelDialect | null { + const m = matchesModelFamily; + + if (m(modelId, "grok") || modelId.toLowerCase().includes("x-ai/")) + return new GrokModelDialect(modelId); + if (m(modelId, "gemini") || modelId.toLowerCase().includes("google/")) + return new GeminiModelDialect(modelId); + if (m(modelId, "qwen") || m(modelId, "alibaba")) + return new QwenModelDialect(modelId); + if (m(modelId, "minimax")) + return new MiniMaxModelDialect(modelId); + if (m(modelId, "deepseek")) + return new DeepSeekModelDialect(modelId); + if (m(modelId, "glm-") || m(modelId, "chatglm-") || modelId.toLowerCase().includes("zhipu/")) + return new GLMModelDialect(modelId); + if (m(modelId, "xiaomi") || m(modelId, "mimo")) + return new XiaomiModelDialect(modelId); + + return null; +} + +// --------------------------------------------------------------------------- +// Public API // --------------------------------------------------------------------------- /** - * Maps provider name (as returned by resolveRemoteProvider().provider.name) to its profile. + * Resolve the effective definition and API key for special cases. * - * Lookup is O(1). Add new providers here — no changes to proxy-server.ts needed. + * Handles: + * - OpenCode Zen minimax variant swap (zen + minimax model -> zen-minimax def) + * - publicKeyFallback for providers that allow auth-less access (e.g. Zen "public") */ -export const PROVIDER_PROFILES: Record = { - gemini: geminiProfile, - "gemini-codeassist": geminiCodeAssistProfile, - openai: openaiProfile, - minimax: anthropicCompatProfile, - "minimax-coding": anthropicCompatProfile, - kimi: anthropicCompatProfile, - "kimi-coding": anthropicCompatProfile, - zai: anthropicCompatProfile, - glm: glmProfile, - "glm-coding": glmProfile, - "opencode-zen": openCodeZenProfile, - "opencode-zen-go": openCodeZenProfile, - ollamacloud: ollamaCloudProfile, - litellm: litellmProfile, - vertex: vertexProfile, -}; +export function resolveEffective( + def: ProviderDefinition, + modelName: string, + apiKey: string, +): { def: ProviderDefinition; apiKey: string } { + let effectiveDef = def; -// --------------------------------------------------------------------------- -// Public factory -// --------------------------------------------------------------------------- + // Zen + minimax: swap to the anthropic-transport variant definition. + // The zen-minimax defs have transport: "anthropic", so resolveTransport and + // resolveAPIFormat handle them via the standard anthropic case branches. + if (def.name === "opencode-zen" && modelName.toLowerCase().includes("minimax")) { + effectiveDef = getProviderByName("opencode-zen-minimax") ?? def; + } else if (def.name === "opencode-zen-go" && modelName.toLowerCase().includes("minimax")) { + effectiveDef = getProviderByName("opencode-zen-go-minimax") ?? def; + } + + if (effectiveDef.publicKeyFallback && !apiKey) { + return { def: effectiveDef, apiKey: effectiveDef.publicKeyFallback }; + } + return { def: effectiveDef, apiKey }; +} /** - * Create a ModelHandler for the given resolved provider using the profile table. + * Create a ModelHandler for a given ProviderDefinition. + * + * Resolves transport + format adapter + model dialect, composes into ComposedHandler. + * Returns null if the provider cannot be wired (missing config, unsupported transport). * - * Returns null when: - * - The provider name is not in PROVIDER_PROFILES (unknown provider) - * - The profile's createHandler() returns null (e.g. missing config) + * Handles all provider types including OpenRouter, Poe, and local providers. */ -export function createHandlerForProvider(ctx: ProfileContext): ModelHandler | null { - const profile = PROVIDER_PROFILES[ctx.provider.name]; - if (!profile) { - return null; // Unknown provider — caller should fall through to OpenRouter or return null +export function createHandlerForProvider( + def: ProviderDefinition, + modelName: string, + apiKey: string, + targetModel: string, + port: number, + opts?: Pick & { + concurrency?: number; + }, +): ModelHandler | null { + const effective = resolveEffective(def, modelName, apiKey); + + // Strip "poe:" prefix from model name (native pattern match preserves it) + let effectiveModelName = modelName; + if (effective.def.transport === "poe") { + effectiveModelName = modelName.replace(/^poe:/, ""); } - // Log API key provenance so debug logs show exactly which key is used and where it came from - if (ctx.provider.apiKeyEnvVar) { - const provenance = resolveApiKeyProvenance(ctx.provider.apiKeyEnvVar); + if (effective.def.apiKeyEnvVar) { + const provenance = resolveApiKeyProvenance(effective.def.apiKeyEnvVar); log(`[Proxy] API key: ${formatProvenanceLog(provenance)}`); } - log(`[Proxy] Handler: provider=${ctx.provider.name}, model=${ctx.modelName}`); + log(`[Proxy] Handler: provider=${effective.def.name}, model=${effectiveModelName}`); + + // Local providers need special handling: LocalTransport takes a LocalProvider config + if (effective.def.transport === "local") { + return createLocalHandler(effective.def, effectiveModelName, targetModel, port, opts); + } + + const transport = resolveTransport(effective.def, effectiveModelName, effective.apiKey); + if (!transport) return null; + + const adapter = resolveAPIFormat(effective.def, effectiveModelName); + const dialect = resolveModelDialect(effectiveModelName); + const middlewares = resolveMiddlewares(effectiveModelName); + + return new ComposedHandler(transport, targetModel, effectiveModelName, port, { + formatAdapter: adapter ?? undefined, + modelDialect: dialect, + middlewares, + isInteractive: opts?.isInteractive, + invocationMode: opts?.invocationMode, + }); +} + +/** + * Create a handler for a local provider (ollama, lmstudio, vllm, mlx, custom URL). + * + * Local providers use the provider-registry to resolve the LocalProvider config, + * which includes the correct base URL (with env var overrides) and API path. + */ +/** + * Select the right LocalModelAdapter subclass based on model family. + */ +function createLocalAdapter(modelName: string, providerName: string): LocalModelAdapter { + const m = matchesModelFamily; + if (m(modelName, "qwen")) return new LocalQwenFormatAdapter(modelName, providerName); + if (m(modelName, "deepseek")) return new LocalDeepSeekFormatAdapter(modelName, providerName); + if (m(modelName, "llama")) return new LocalLlamaFormatAdapter(modelName, providerName); + if (m(modelName, "mistral")) return new LocalMistralFormatAdapter(modelName, providerName); + return new LocalModelAdapter(modelName, providerName); +} + +/** + * Create the right LocalTransport subclass based on provider name. + * Ollama and LM Studio have provider-specific health checks, context window + * detection, and payload fields. Everything else uses the generic LocalTransport. + */ +function createLocalTransport( + config: LocalProviderConfig, + modelName: string, + options?: { concurrency?: number }, +): LocalTransport { + switch (config.name) { + case "ollama": + return new OllamaTransport(config, modelName, options); + case "lmstudio": + return new LMStudioTransport(config, modelName, options); + default: + return new LocalTransport(config, modelName, options); + } +} + +function createLocalHandler( + def: ProviderDefinition, + modelName: string, + targetModel: string, + port: number, + opts?: Pick & { + concurrency?: number; + }, +): ModelHandler | null { + // Try prefix-based local provider resolution (ollama/, lmstudio/, etc.) + const resolved = resolveProvider(targetModel); + if (resolved) { + const provider = createLocalTransport(resolved.provider, resolved.modelName, { + concurrency: opts?.concurrency ?? resolved.concurrency, + }); + const adapter = createLocalAdapter(resolved.modelName, resolved.provider.name); + // For local models, targetModel = modelName (no provider prefix to preserve) + const handler = new ComposedHandler(provider, resolved.modelName, resolved.modelName, port, { + formatAdapter: adapter, + tokenStrategy: "local", + summarizeTools: opts?.summarizeTools, + isInteractive: opts?.isInteractive, + invocationMode: opts?.invocationMode, + }); + log( + `[Proxy] Created local provider handler: ${resolved.provider.name}/${resolved.modelName}${(opts?.concurrency ?? resolved.concurrency) !== undefined ? ` (concurrency: ${opts?.concurrency ?? resolved.concurrency})` : ""}` + ); + return handler; + } + + // Try URL-based model (http://localhost:11434/llama3) + const urlParsed = parseUrlModel(targetModel); + if (urlParsed) { + const providerConfig = createUrlProvider(urlParsed); + const provider = createLocalTransport(providerConfig, urlParsed.modelName); + const adapter = createLocalAdapter(urlParsed.modelName, providerConfig.name); + const handler = new ComposedHandler( + provider, + urlParsed.modelName, + urlParsed.modelName, + port, + { + formatAdapter: adapter, + tokenStrategy: "local", + summarizeTools: opts?.summarizeTools, + isInteractive: opts?.isInteractive, + invocationMode: opts?.invocationMode, + } + ); + log( + `[Proxy] Created URL-based local provider handler: ${urlParsed.baseUrl}/${urlParsed.modelName}` + ); + return handler; + } - return profile.createHandler(ctx); + return null; } diff --git a/packages/cli/src/providers/provider-resolver.ts b/packages/cli/src/providers/provider-resolver.ts index 00d28592..3fe32dd6 100644 --- a/packages/cli/src/providers/provider-resolver.ts +++ b/packages/cli/src/providers/provider-resolver.ts @@ -28,7 +28,6 @@ import { existsSync } from "node:fs"; import { join } from "node:path"; import { homedir } from "node:os"; import { resolveProvider, parseUrlModel } from "./provider-registry.js"; -import { resolveRemoteProvider } from "./remote-provider-registry.js"; import { autoRoute, getAutoRouteHint } from "./auto-route.js"; import { parseModelSpec, @@ -40,6 +39,8 @@ import { import { getApiKeyInfo as getApiKeyInfoFromDefs, getDisplayName as getDisplayNameFromDefs, + getProviderByName, + getProviderDefinitionByRemoteName, } from "./provider-definitions.js"; /** @@ -84,6 +85,10 @@ export interface ProviderResolution { wasAutoRouted?: boolean; /** Human-readable auto-routing decision message */ autoRouteMessage?: string; + /** The resolved ProviderDefinition (for direct handoff to createHandlerForProvider) */ + def?: import("./provider-definitions.js").ProviderDefinition; + /** The resolved API key value (for direct handoff to createHandlerForProvider) */ + apiKey?: string; } /** @@ -295,6 +300,7 @@ export function resolveModelProvider(modelId: string | undefined): ProviderResol // 4. Check for explicit OpenRouter routing if (parsed.provider === "openrouter") { const info = API_KEY_INFO.openrouter; + const orDef = getProviderByName("openrouter"); return addCommonFields({ category: "openrouter", providerName: "OpenRouter", @@ -304,6 +310,8 @@ export function resolveModelProvider(modelId: string | undefined): ProviderResol apiKeyAvailable: isApiKeyAvailable(info), apiKeyDescription: info.description, apiKeyUrl: info.url, + def: orDef, + apiKey: info.envVar ? (process.env[info.envVar] || "") : "", }); } @@ -315,6 +323,7 @@ export function resolveModelProvider(modelId: string | undefined): ProviderResol if (autoResult) { if (autoResult.provider === "litellm") { const info = API_KEY_INFO.litellm; + const llDef = getProviderByName("litellm"); return addCommonFields({ category: "direct-api", providerName: "LiteLLM", @@ -326,11 +335,14 @@ export function resolveModelProvider(modelId: string | undefined): ProviderResol apiKeyUrl: info.url, wasAutoRouted: true, autoRouteMessage: autoResult.displayMessage, + def: llDef, + apiKey: info.envVar ? (process.env[info.envVar] || "") : "", }); } if (autoResult.provider === "openrouter") { const info = API_KEY_INFO.openrouter; + const orDef = getProviderByName("openrouter"); return addCommonFields({ category: "openrouter", providerName: "OpenRouter", @@ -342,6 +354,8 @@ export function resolveModelProvider(modelId: string | undefined): ProviderResol apiKeyUrl: info.url, wasAutoRouted: true, autoRouteMessage: autoResult.displayMessage, + def: orDef, + apiKey: info.envVar ? (process.env[info.envVar] || "") : "", }); } @@ -350,29 +364,42 @@ export function resolveModelProvider(modelId: string | undefined): ProviderResol } } - // 6. Try to resolve as direct API provider - const remoteResolved = resolveRemoteProvider(modelId); - if (remoteResolved) { - const provider = remoteResolved.provider; - + // 6. Try to resolve as direct API provider using the parsed provider name + // Handle google -> gemini naming difference used by RemoteProvider + const remoteName = parsed.provider === "google" ? "gemini" : parsed.provider; + const remoteDef = getProviderByName(parsed.provider) || getProviderDefinitionByRemoteName(remoteName); + if (remoteDef && !remoteDef.isLocal && remoteDef.baseUrl !== "" && remoteDef.name !== "qwen" && remoteDef.name !== "native-anthropic") { // Provider-specific prefix found - check if provider's API key is available - const info = API_KEY_INFO[provider.name] || { - envVar: provider.apiKeyEnvVar, - description: `${provider.name} API Key`, + const info = API_KEY_INFO[remoteName] || { + envVar: remoteDef.apiKeyEnvVar, + description: `${remoteName} API Key`, url: "", }; const providerDisplayName = - PROVIDER_DISPLAY_NAMES[provider.name] || - provider.name.charAt(0).toUpperCase() + provider.name.slice(1); + PROVIDER_DISPLAY_NAMES[remoteName] || + remoteName.charAt(0).toUpperCase() + remoteName.slice(1); const wasAutoRouted = !parsed.isExplicitProvider; - // Return direct-api resolution — report missing key instead of silent fallback + // Resolve the actual API key value for direct handoff + let resolvedApiKey = ""; + if (info.envVar && process.env[info.envVar]) { + resolvedApiKey = process.env[info.envVar]!; + } else if (info.aliases) { + for (const alias of info.aliases) { + if (process.env[alias]) { + resolvedApiKey = process.env[alias]!; + break; + } + } + } + + // Return direct-api resolution with parsed model name (prefix already stripped by parser) return addCommonFields({ category: "direct-api", providerName: providerDisplayName, - modelName: remoteResolved.modelName, + modelName: parsed.model, fullModelId: modelId, requiredApiKeyEnvVar: info.envVar || null, apiKeyAvailable: isApiKeyAvailable(info), @@ -382,6 +409,8 @@ export function resolveModelProvider(modelId: string | undefined): ProviderResol autoRouteMessage: wasAutoRouted ? (pendingAutoRouteMessage ?? `Auto-routed: ${parsed.model} -> ${providerDisplayName}`) : undefined, + def: remoteDef, + apiKey: resolvedApiKey, }); } diff --git a/packages/cli/src/providers/provider-routing.test.ts b/packages/cli/src/providers/provider-routing.test.ts index 7fdf9c13..43985e43 100644 --- a/packages/cli/src/providers/provider-routing.test.ts +++ b/packages/cli/src/providers/provider-routing.test.ts @@ -9,10 +9,10 @@ import { describe, test, expect } from "bun:test"; import { parseModelSpec } from "./model-parser.js"; -import { BUILTIN_PROVIDERS, getShortcuts } from "./provider-definitions.js"; -import { DialectManager } from "../adapters/dialect-manager.js"; +import { BUILTIN_PROVIDERS, PROVIDER_SHORTCUTS } from "./provider-definitions.js"; +import { resolveModelDialect } from "../providers/provider-profiles.js"; import { GrokModelDialect } from "../adapters/grok-model-dialect.js"; -import { GeminiAPIFormat } from "../adapters/gemini-api-format.js"; +import { GeminiModelDialect } from "../adapters/gemini-model-dialect.js"; import { QwenModelDialect } from "../adapters/qwen-model-dialect.js"; import { DeepSeekModelDialect } from "../adapters/deepseek-model-dialect.js"; import { GLMModelDialect } from "../adapters/glm-model-dialect.js"; @@ -21,7 +21,8 @@ import { XiaomiModelDialect } from "../adapters/xiaomi-model-dialect.js"; import { CodexAPIFormat } from "../adapters/codex-api-format.js"; import { OpenAIAPIFormat } from "../adapters/openai-api-format.js"; import { DefaultAPIFormat } from "../adapters/base-api-format.js"; -import { PROVIDER_PROFILES, createHandlerForProvider } from "./provider-profiles.js"; +import { createHandlerForProvider } from "./provider-profiles.js"; +import { getProviderByName } from "./provider-definitions.js"; import { OpenAIProviderTransport } from "./transport/openai.js"; // --------------------------------------------------------------------------- @@ -29,7 +30,7 @@ import { OpenAIProviderTransport } from "./transport/openai.js"; // --------------------------------------------------------------------------- describe("parseModelSpec — shortcut resolution", () => { - const shortcuts = getShortcuts(); + const shortcuts = PROVIDER_SHORTCUTS; test("every shortcut in BUILTIN_PROVIDERS resolves to the correct provider", () => { for (const def of BUILTIN_PROVIDERS) { @@ -161,127 +162,127 @@ describe("parseModelSpec — native model auto-detection", () => { // Section 2: Adapter selection // --------------------------------------------------------------------------- -describe("DialectManager — correct dialect selection", () => { +describe("resolveModelDialect — correct dialect selection", () => { test("grok-beta → GrokModelDialect", () => { - const adapter = new DialectManager("grok-beta").getAdapter(); + const adapter = resolveModelDialect("grok-beta"); expect(adapter).toBeInstanceOf(GrokModelDialect); }); test("x-ai/grok-beta → GrokModelDialect", () => { - const adapter = new DialectManager("x-ai/grok-beta").getAdapter(); + const adapter = resolveModelDialect("x-ai/grok-beta"); expect(adapter).toBeInstanceOf(GrokModelDialect); }); - test("gemini-2.0-flash → GeminiAPIFormat", () => { - const adapter = new DialectManager("gemini-2.0-flash").getAdapter(); - expect(adapter).toBeInstanceOf(GeminiAPIFormat); + test("gemini-2.0-flash → GeminiModelDialect", () => { + const adapter = resolveModelDialect("gemini-2.0-flash"); + expect(adapter).toBeInstanceOf(GeminiModelDialect); }); - test("google/gemini-2.5-pro → GeminiAPIFormat", () => { - const adapter = new DialectManager("google/gemini-2.5-pro").getAdapter(); - expect(adapter).toBeInstanceOf(GeminiAPIFormat); + test("google/gemini-2.5-pro → GeminiModelDialect", () => { + const adapter = resolveModelDialect("google/gemini-2.5-pro"); + expect(adapter).toBeInstanceOf(GeminiModelDialect); }); test("deepseek-r1 → DeepSeekModelDialect", () => { - const adapter = new DialectManager("deepseek-r1").getAdapter(); + const adapter = resolveModelDialect("deepseek-r1"); expect(adapter).toBeInstanceOf(DeepSeekModelDialect); }); test("glm-5 → GLMModelDialect", () => { - const adapter = new DialectManager("glm-5").getAdapter(); + const adapter = resolveModelDialect("glm-5"); expect(adapter).toBeInstanceOf(GLMModelDialect); }); test("zhipu/glm-4 → GLMModelDialect", () => { - const adapter = new DialectManager("zhipu/glm-4").getAdapter(); + const adapter = resolveModelDialect("zhipu/glm-4"); expect(adapter).toBeInstanceOf(GLMModelDialect); }); test("minimax-m2.5 → MiniMaxModelDialect", () => { - const adapter = new DialectManager("minimax-m2.5").getAdapter(); + const adapter = resolveModelDialect("minimax-m2.5"); expect(adapter).toBeInstanceOf(MiniMaxModelDialect); }); test("qwen3-coder → QwenModelDialect", () => { - const adapter = new DialectManager("qwen3-coder").getAdapter(); + const adapter = resolveModelDialect("qwen3-coder"); expect(adapter).toBeInstanceOf(QwenModelDialect); }); test("xiaomi/mimo-vl-2b → XiaomiModelDialect", () => { - const adapter = new DialectManager("xiaomi/mimo-vl-2b").getAdapter(); + const adapter = resolveModelDialect("xiaomi/mimo-vl-2b"); expect(adapter).toBeInstanceOf(XiaomiModelDialect); }); - test("codex-mini → CodexAPIFormat", () => { - const adapter = new DialectManager("codex-mini").getAdapter(); - expect(adapter).toBeInstanceOf(CodexAPIFormat); + test("codex-mini → null (no dialect, wire format only)", () => { + const adapter = resolveModelDialect("codex-mini"); + expect(adapter).toBeNull(); }); - test("gpt-4o → DefaultAPIFormat (GPT models use default OpenAI format)", () => { - const adapter = new DialectManager("gpt-4o").getAdapter(); - expect(adapter).toBeInstanceOf(DefaultAPIFormat); + test("gpt-4o → null (no dialect quirks)", () => { + const adapter = resolveModelDialect("gpt-4o"); + expect(adapter).toBeNull(); }); - test("o3-mini → OpenAIAPIFormat (o-series needs reasoning_effort mapping)", () => { - const adapter = new DialectManager("o3-mini").getAdapter(); - expect(adapter).toBeInstanceOf(OpenAIAPIFormat); + test("o3-mini → null (no dialect quirks)", () => { + const adapter = resolveModelDialect("o3-mini"); + expect(adapter).toBeNull(); }); - test("unknown-model → DefaultAPIFormat", () => { - const adapter = new DialectManager("unknown-model").getAdapter(); - expect(adapter).toBeInstanceOf(DefaultAPIFormat); + test("unknown-model → null", () => { + const adapter = resolveModelDialect("unknown-model"); + expect(adapter).toBeNull(); }); }); -describe("DialectManager — false positive prevention", () => { +describe("resolveModelDialect — false positive prevention", () => { test("qwen-grok-hybrid → QwenModelDialect (NOT GrokModelDialect)", () => { - const adapter = new DialectManager("qwen-grok-hybrid").getAdapter(); + const adapter = resolveModelDialect("qwen-grok-hybrid"); expect(adapter).toBeInstanceOf(QwenModelDialect); expect(adapter).not.toBeInstanceOf(GrokModelDialect); }); test("deepseek-glm-test → DeepSeekModelDialect (NOT GLMModelDialect)", () => { - const adapter = new DialectManager("deepseek-glm-test").getAdapter(); + const adapter = resolveModelDialect("deepseek-glm-test"); expect(adapter).toBeInstanceOf(DeepSeekModelDialect); expect(adapter).not.toBeInstanceOf(GLMModelDialect); }); - test("my-grok-clone → DefaultAPIFormat (not GrokModelDialect — grok is mid-string)", () => { - const adapter = new DialectManager("my-grok-clone").getAdapter(); + test("my-grok-clone → null (not GrokModelDialect — grok is mid-string)", () => { + const adapter = resolveModelDialect("my-grok-clone"); expect(adapter).not.toBeInstanceOf(GrokModelDialect); // Should fall to default since none of the specific families match - expect(adapter).toBeInstanceOf(DefaultAPIFormat); + expect(adapter).toBeNull(); }); - test("my-minimax-clone → DefaultAPIFormat (not MiniMaxModelDialect)", () => { - const adapter = new DialectManager("my-minimax-clone").getAdapter(); + test("my-minimax-clone → null (not MiniMaxModelDialect)", () => { + const adapter = resolveModelDialect("my-minimax-clone"); expect(adapter).not.toBeInstanceOf(MiniMaxModelDialect); - expect(adapter).toBeInstanceOf(DefaultAPIFormat); + expect(adapter).toBeNull(); }); - test("test-deepseek-model → DefaultAPIFormat (not DeepSeekModelDialect — deepseek is mid-string)", () => { - const adapter = new DialectManager("test-deepseek-model").getAdapter(); + test("test-deepseek-model → null (not DeepSeekModelDialect — deepseek is mid-string)", () => { + const adapter = resolveModelDialect("test-deepseek-model"); expect(adapter).not.toBeInstanceOf(DeepSeekModelDialect); - expect(adapter).toBeInstanceOf(DefaultAPIFormat); + expect(adapter).toBeNull(); }); test("vendor/grok-beta uses GrokModelDialect (vendor prefix is fine)", () => { - const adapter = new DialectManager("vendor/grok-beta").getAdapter(); + const adapter = resolveModelDialect("vendor/grok-beta"); expect(adapter).toBeInstanceOf(GrokModelDialect); }); test("vendor/deepseek-r1 uses DeepSeekModelDialect (vendor prefix)", () => { - const adapter = new DialectManager("vendor/deepseek-r1").getAdapter(); + const adapter = resolveModelDialect("vendor/deepseek-r1"); expect(adapter).toBeInstanceOf(DeepSeekModelDialect); }); test("vendor/minimax-m2.5 uses MiniMaxModelDialect (vendor prefix)", () => { - const adapter = new DialectManager("vendor/minimax-m2.5").getAdapter(); + const adapter = resolveModelDialect("vendor/minimax-m2.5"); expect(adapter).toBeInstanceOf(MiniMaxModelDialect); }); test("openrouter/x-ai/grok-beta uses GrokModelDialect (double vendor prefix)", () => { - const adapter = new DialectManager("openrouter/x-ai/grok-beta").getAdapter(); + const adapter = resolveModelDialect("openrouter/x-ai/grok-beta"); expect(adapter).toBeInstanceOf(GrokModelDialect); }); }); @@ -290,34 +291,32 @@ describe("DialectManager — false positive prevention", () => { // Section 3: Provider profiles // --------------------------------------------------------------------------- -describe("PROVIDER_PROFILES — coverage", () => { - test("every entry in PROVIDER_PROFILES has a matching BUILTIN_PROVIDER", () => { - for (const profileName of Object.keys(PROVIDER_PROFILES)) { - // Profile names match RemoteProvider.name which maps google→gemini - const builtinName = profileName === "gemini" ? "google" : profileName; - const def = BUILTIN_PROVIDERS.find((d) => d.name === builtinName || d.name === profileName); - expect(def).toBeDefined(); - } - }); - - test("all remote BUILTIN_PROVIDERS have a profile (except openrouter, poe, qwen, native-anthropic)", () => { - // openrouter has its own dedicated handler (not ComposedHandler), poe has transport but no profile yet - const skipProviders = new Set([ - "qwen", - "native-anthropic", - "poe", - "openrouter", - "xai", // auto-routed through OpenRouter - "ollama", - "lmstudio", - "vllm", - "mlx", - ]); +describe("createHandlerForProvider — coverage", () => { + // Providers that are wired through createHandlerForProvider (ComposedHandler). + // Excludes: openrouter (dedicated handler), poe (dedicated handler), + // local providers (dedicated handler), qwen/native-anthropic (virtual, auto-routed). + const skipProviders = new Set([ + "qwen", + "native-anthropic", + "poe", + "openrouter", + "xai", // auto-routed through OpenRouter + "ollama", + "lmstudio", + "vllm", + "mlx", + "vertex", // requires VERTEX_PROJECT or VERTEX_API_KEY env + "litellm", // requires LITELLM_BASE_URL env + ]); + + test("all remote BUILTIN_PROVIDERS create a handler via createHandlerForProvider", () => { for (const def of BUILTIN_PROVIDERS) { if (skipProviders.has(def.name)) continue; - const profileName = def.name === "google" ? "gemini" : def.name; - const profile = PROVIDER_PROFILES[profileName]; - expect(profile).toBeDefined(); + const handler = createHandlerForProvider( + def, "test-model", "test-key", "test-target", 4000, + { isInteractive: false, invocationMode: "explicit-model" }, + ); + expect(handler).not.toBeNull(); } }); }); @@ -405,24 +404,17 @@ describe("matchesModelFamily", () => { // --------------------------------------------------------------------------- describe("OpenCode Zen — model routing", () => { + const zenDef = getProviderByName("opencode-zen")!; const zenBaseProvider = { name: "opencode-zen" as const, baseUrl: "https://opencode.ai/zen", apiPath: "/v1/chat/completions", apiKeyEnvVar: "OPENCODE_API_KEY", - prefixes: [], + prefixes: [] as string[], headers: undefined, authScheme: undefined, }; - const sharedCtx = { - provider: zenBaseProvider, - apiKey: "test-key", - targetModel: "placeholder", - port: 4000, - sharedOpts: { isInteractive: false as const, invocationMode: "explicit-model" as const }, - }; - test("GPT model routes to Responses API endpoint (/v1/responses)", () => { // The transport for GPT models via Zen must point to /v1/responses, not /v1/chat/completions. const responsesProvider = { ...zenBaseProvider, apiPath: "/v1/responses" }; @@ -436,20 +428,26 @@ describe("OpenCode Zen — model routing", () => { }); test("GPT model createHandler returns non-null", () => { - const profile = PROVIDER_PROFILES["opencode-zen"]; - const handler = profile.createHandler({ ...sharedCtx, modelName: "gpt-4o" }); + const handler = createHandlerForProvider( + zenDef, "gpt-4o", "test-key", "placeholder", 4000, + { isInteractive: false, invocationMode: "explicit-model" }, + ); expect(handler).not.toBeNull(); }); test("MiniMax model createHandler returns non-null", () => { - const profile = PROVIDER_PROFILES["opencode-zen"]; - const handler = profile.createHandler({ ...sharedCtx, modelName: "minimax-m2.5" }); + const handler = createHandlerForProvider( + zenDef, "minimax-m2.5", "test-key", "placeholder", 4000, + { isInteractive: false, invocationMode: "explicit-model" }, + ); expect(handler).not.toBeNull(); }); test("GLM model createHandler returns non-null (default OpenAI path)", () => { - const profile = PROVIDER_PROFILES["opencode-zen"]; - const handler = profile.createHandler({ ...sharedCtx, modelName: "glm-5" }); + const handler = createHandlerForProvider( + zenDef, "glm-5", "test-key", "placeholder", 4000, + { isInteractive: false, invocationMode: "explicit-model" }, + ); expect(handler).not.toBeNull(); }); diff --git a/packages/cli/src/providers/remote-provider-registry.ts b/packages/cli/src/providers/remote-provider-registry.ts deleted file mode 100644 index dd7d5a62..00000000 --- a/packages/cli/src/providers/remote-provider-registry.ts +++ /dev/null @@ -1,159 +0,0 @@ -/** - * Remote Provider Registry - * - * Handles resolution of remote cloud API providers (Gemini, OpenAI, MiniMax, Kimi, GLM, GLM Coding, OllamaCloud, OpenCode Zen) - * based on model ID specifications. - * - * New syntax: provider@model - * Examples: - * google@gemini-3-pro-preview - Direct Google API - * openrouter@google/gemini-3-pro - Explicit OpenRouter - * oai@gpt-5.3 - Direct OpenAI API (shortcut) - * - * Legacy prefix patterns (deprecated, still supported): - * - g/, gemini/ -> Google Gemini API (direct) - * - go/ -> Google Gemini Code Assist (OAuth) - * - oai/ -> OpenAI API (openai/ routes to OpenRouter) - * - mmax/, mm/ -> MiniMax API (Anthropic-compatible) - * - mmc/ -> MiniMax Coding Plan API (Anthropic-compatible) - * - kimi/, moonshot/ -> Kimi/Moonshot API (Anthropic-compatible) - * - glm/, zhipu/ -> GLM/Zhipu API (OpenAI-compatible) - * - gc/ -> GLM Coding Plan API (OpenAI-compatible) - * - zai/ -> Z.AI API (Anthropic-compatible) - * - oc/ -> OllamaCloud API (OpenAI-compatible) - * - zen/ -> OpenCode Zen API (OpenAI-compatible + Anthropic for MiniMax) - * - or/, no prefix with "/" -> OpenRouter (existing handler) - */ - -import type { - RemoteProvider, - ResolvedRemoteProvider, -} from "../handlers/shared/remote-provider-types.js"; -import { parseModelSpec, isLocalProviderName } from "./model-parser.js"; -import { getAllProviders, toRemoteProvider } from "./provider-definitions.js"; - -/** - * Remote provider configurations — derived from BUILTIN_PROVIDERS. - * Filters out local-only and virtual providers (qwen, native-anthropic). - */ -const getRemoteProviders = (): RemoteProvider[] => { - return getAllProviders() - .filter( - (def) => - !def.isLocal && def.baseUrl !== "" && def.name !== "qwen" && def.name !== "native-anthropic" - ) - .map(toRemoteProvider); -}; - -/** - * Resolve a model ID to a remote provider - * - * Supports both new syntax (provider@model) and legacy syntax (prefix/model) - * Returns null if no provider matches (falls through to OpenRouter default) - */ -export function resolveRemoteProvider(modelId: string): ResolvedRemoteProvider | null { - const providers = getRemoteProviders(); - - // Try new model parser first - const parsed = parseModelSpec(modelId); - - // Skip local providers - they're handled by provider-registry.ts - if (isLocalProviderName(parsed.provider)) { - return null; - } - - // Skip custom URL providers - if (parsed.provider === "custom-url") { - return null; - } - - // Look up provider by canonical name (toRemoteProvider maps "google" → "gemini" for compat) - // Try both the parsed provider name and the RemoteProvider name (which may differ, e.g. google→gemini) - const mappedName = parsed.provider === "google" ? "gemini" : parsed.provider; - const provider = providers.find((p) => p.name === mappedName || p.name === parsed.provider); - if (provider) { - return { - provider, - modelName: parsed.model, - isLegacySyntax: parsed.isLegacySyntax, - }; - } - - // Legacy: check prefix patterns for backwards compatibility - for (const provider of providers) { - for (const prefix of provider.prefixes) { - if (modelId.startsWith(prefix)) { - return { - provider, - modelName: modelId.slice(prefix.length), - isLegacySyntax: true, - }; - } - } - } - - return null; -} - -/** - * Check if a model ID explicitly routes to a remote provider (has a known prefix) - */ -export function hasRemoteProviderPrefix(modelId: string): boolean { - return resolveRemoteProvider(modelId) !== null; -} - -/** - * Get the provider type for a model ID - * Returns "gemini", "openai", "openrouter", or null - */ -export function getRemoteProviderType(modelId: string): string | null { - const resolved = resolveRemoteProvider(modelId); - return resolved?.provider.name || null; -} - -/** - * Validate that the required API key is set for a provider - * Returns error message if validation fails, null if OK - */ -export function validateRemoteProviderApiKey(provider: RemoteProvider): string | null { - // Skip validation for OAuth-based providers (empty apiKeyEnvVar) - if (provider.apiKeyEnvVar === "") { - return null; - } - - const apiKey = process.env[provider.apiKeyEnvVar]; - - if (!apiKey) { - const examples: Record = { - GEMINI_API_KEY: - "export GEMINI_API_KEY='your-key' (get from https://aistudio.google.com/app/apikey)", - OPENAI_API_KEY: - "export OPENAI_API_KEY='sk-...' (get from https://platform.openai.com/api-keys)", - OPENROUTER_API_KEY: - "export OPENROUTER_API_KEY='sk-or-...' (get from https://openrouter.ai/keys)", - MINIMAX_API_KEY: "export MINIMAX_API_KEY='your-key' (get from https://www.minimaxi.com/)", - MINIMAX_CODING_API_KEY: - "export MINIMAX_CODING_API_KEY='your-key' (get from https://platform.minimax.io/user-center/basic-information/interface-key)", - MOONSHOT_API_KEY: - "export MOONSHOT_API_KEY='your-key' (get from https://platform.moonshot.cn/)", - KIMI_CODING_API_KEY: - "export KIMI_CODING_API_KEY='sk-kimi-...' (get from https://kimi.com/code membership page, or run: claudish --kimi-login)", - ZHIPU_API_KEY: "export ZHIPU_API_KEY='your-key' (get from https://open.bigmodel.cn/)", - GLM_CODING_API_KEY: "export GLM_CODING_API_KEY='your-key' (get from https://z.ai/subscribe)", - OLLAMA_API_KEY: "export OLLAMA_API_KEY='your-key' (get from https://ollama.com/account)", - OPENCODE_API_KEY: "export OPENCODE_API_KEY='your-key' (get from https://opencode.ai/)", - }; - - const example = examples[provider.apiKeyEnvVar] || `export ${provider.apiKeyEnvVar}='your-key'`; - return `Missing ${provider.apiKeyEnvVar} environment variable.\n\nSet it with:\n ${example}`; - } - - return null; -} - -/** - * Get all registered remote providers - */ -export function getRegisteredRemoteProviders(): RemoteProvider[] { - return getRemoteProviders(); -} diff --git a/packages/cli/src/providers/transport/anthropic-compat.ts b/packages/cli/src/providers/transport/anthropic-compat.ts index 28760cd7..2f42e041 100644 --- a/packages/cli/src/providers/transport/anthropic-compat.ts +++ b/packages/cli/src/providers/transport/anthropic-compat.ts @@ -6,19 +6,21 @@ * anthropic-version, plus Kimi OAuth fallback for kimi-coding. */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; import type { RemoteProvider } from "../../handlers/shared/remote-provider-types.js"; -import { log } from "../../logger.js"; +import { BaseTransport } from "./base.js"; -export class AnthropicProviderTransport implements ProviderTransport { +export class AnthropicProviderTransport extends BaseTransport { readonly name: string; readonly displayName: string; readonly streamFormat: StreamFormat = "anthropic-sse"; + readonly tokenStrategy = "standard" as const; private provider: RemoteProvider; - private apiKey: string; + protected apiKey: string; constructor(provider: RemoteProvider, apiKey: string) { + super(provider.name); this.provider = provider; this.apiKey = apiKey; this.name = provider.name; @@ -45,35 +47,6 @@ export class AnthropicProviderTransport implements ProviderTransport { Object.assign(headers, this.provider.headers); } - // Kimi Coding: prefer API key auth, fall back to OAuth if no key provided - if (this.provider.name === "kimi-coding" && !this.apiKey) { - try { - const { existsSync, readFileSync } = await import("node:fs"); - const { join } = await import("node:path"); - const { homedir } = await import("node:os"); - - const credPath = join(homedir(), ".claudish", "kimi-oauth.json"); - if (existsSync(credPath)) { - const data = JSON.parse(readFileSync(credPath, "utf-8")); - if (data.access_token && data.refresh_token) { - const { KimiOAuth } = await import("../../auth/kimi-oauth.js"); - const oauth = KimiOAuth.getInstance(); - const accessToken = await oauth.getAccessToken(); - - // Replace API key auth with Bearer token - delete headers["x-api-key"]; - headers["Authorization"] = `Bearer ${accessToken}`; - - // Add Kimi-specific platform headers - const platformHeaders = oauth.getPlatformHeaders(); - Object.assign(headers, platformHeaders); - } - } - } catch (e: any) { - log(`[${this.displayName}] OAuth fallback failed: ${e.message}`); - } - } - return headers; } diff --git a/packages/cli/src/providers/transport/base.ts b/packages/cli/src/providers/transport/base.ts new file mode 100644 index 00000000..0750ebc9 --- /dev/null +++ b/packages/cli/src/providers/transport/base.ts @@ -0,0 +1,97 @@ +/** + * BaseTransport — abstract base for all ProviderTransport implementations. + * + * Owns a RequestQueue for rate limiting. Queue behavior is configured + * via method overrides, not external hook functions: + * - onResponse(): rate limit detection (override for Gemini, OpenRouter) + * - shouldRetry(): 5xx retry logic (override for local OOM) + * + * Subclass ApiKeyTransport for Bearer/x-api-key auth. + * Subclass OAuthTransport for token-based auth with refresh. + */ + +import type { ProviderTransport, StreamFormat } from "./types.js"; +import { RequestQueue, type QueueState, type RequestQueueConfig } from "../../handlers/shared/request-queue.js"; + +export abstract class BaseTransport implements ProviderTransport { + abstract readonly name: string; + abstract readonly displayName: string; + abstract readonly streamFormat: StreamFormat; + readonly tokenStrategy?: ProviderTransport["tokenStrategy"]; + + protected queue: RequestQueue; + + constructor(label: string, queueConfig?: Omit) { + this.queue = new RequestQueue(label, { + ...queueConfig, + onResponse: (response, state) => this.onResponse(response, state), + shouldRetry: (response, errorBody) => this.shouldRetry(response, errorBody), + calculateDelay: (state) => this.calculateDelay(state), + }); + } + + abstract getEndpoint(model?: string): string; + abstract getHeaders(): Promise>; + + async enqueueRequest(fetchFn: () => Promise): Promise { + return this.queue.enqueue(fetchFn); + } + + /** Called after each response. Return true if rate limited. Default: 429 detection. */ + protected onResponse(response: Response, _state: QueueState): boolean { + return response.status === 429; + } + + /** Called on 5xx. Return true to retry once. Default: no retry. */ + protected shouldRetry(_response: Response, _errorBody: string): boolean { + return false; + } + + /** Override delay calculation. Default: base delay with error backoff. */ + protected calculateDelay(state: QueueState): number { + let d = state.baseDelayMs; + if (state.consecutiveErrors > 0) d = Math.min(d * (1 + state.consecutiveErrors * 0.5), state.maxDelayMs); + return d; + } +} + +export abstract class ApiKeyTransport extends BaseTransport { + protected apiKey: string; + protected baseUrl: string; + protected apiPath: string; + + constructor( + label: string, + baseUrl: string, + apiPath: string, + apiKey: string, + queueConfig?: Omit, + ) { + super(label, queueConfig); + this.apiKey = apiKey; + this.baseUrl = baseUrl; + this.apiPath = apiPath; + } + + getEndpoint(_model?: string): string { + return `${this.baseUrl}${this.apiPath}`; + } + + async getHeaders(): Promise> { + return { + "Content-Type": "application/json", + "Authorization": `Bearer ${this.apiKey}`, + }; + } +} + +export abstract class OAuthTransport extends BaseTransport { + protected accessToken = ""; + + async getHeaders(): Promise> { + return { + "Content-Type": "application/json", + "Authorization": `Bearer ${this.accessToken}`, + }; + } +} diff --git a/packages/cli/src/providers/transport/gemini-apikey.ts b/packages/cli/src/providers/transport/gemini-apikey.ts index e6105884..ebed62d0 100644 --- a/packages/cli/src/providers/transport/gemini-apikey.ts +++ b/packages/cli/src/providers/transport/gemini-apikey.ts @@ -4,25 +4,30 @@ * Transport concerns: * - x-goog-api-key header * - Endpoint URL with {model} substitution - * - GeminiRequestQueue for rate limiting + * - Rate limiting via onResponse override (parses quotaResetDelay from 429) * - gemini-sse stream format */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; import type { RemoteProvider } from "../../handlers/shared/remote-provider-types.js"; -import { GeminiRequestQueue } from "../../handlers/shared/gemini-queue.js"; -import { log } from "../../logger.js"; +import { BaseTransport } from "./base.js"; +import type { QueueState } from "../../handlers/shared/request-queue.js"; -export class GeminiProviderTransport implements ProviderTransport { +export class GeminiProviderTransport extends BaseTransport { readonly name = "gemini"; readonly displayName = "Gemini API"; readonly streamFormat: StreamFormat = "gemini-sse"; + readonly tokenStrategy = "standard" as const; private provider: RemoteProvider; private apiKey: string; private modelName: string; constructor(provider: RemoteProvider, modelName: string, apiKey: string) { + super("gemini", { + baseDelayMs: 1000, + maxDelayMs: 10000, + }); this.provider = provider; this.modelName = modelName; this.apiKey = apiKey; @@ -39,13 +44,33 @@ export class GeminiProviderTransport implements ProviderTransport { }; } - /** - * Rate-limited request via GeminiRequestQueue singleton. - * Serializes all Gemini requests to prevent quota exhaustion. - */ - async enqueueRequest(fetchFn: () => Promise): Promise { - const queue = GeminiRequestQueue.getInstance(); - return queue.enqueue(fetchFn); + protected override onResponse(response: Response, state: QueueState): boolean { + if (response.status !== 429) return false; + try { + response.clone().text().then(text => { + const data = JSON.parse(text); + const detail = data?.error?.details?.find((d: any) => d.quotaResetDelay); + if (detail?.quotaResetDelay) { + const match = detail.quotaResetDelay.match(/(\d+(?:\.\d+)?)/); + if (match) { + state.currentDelayMs = Math.min( + Math.max(Math.ceil(parseFloat(match[1]) * 1000), state.baseDelayMs), + state.maxDelayMs, + ); + } + } + }).catch(() => {}); + } catch {} + return true; + } + + getNonStreamingEndpoint(_model?: string): string { + const apiPath = this.provider.apiPath + .replace("{model}", this.modelName) + .replace(":streamGenerateContent", ":generateContent"); + const url = new URL(apiPath, this.provider.baseUrl); + url.searchParams.delete("alt"); + return url.toString(); } } diff --git a/packages/cli/src/providers/transport/gemini-codeassist.ts b/packages/cli/src/providers/transport/gemini-codeassist.ts index 43b24591..61a6afde 100644 --- a/packages/cli/src/providers/transport/gemini-codeassist.ts +++ b/packages/cli/src/providers/transport/gemini-codeassist.ts @@ -6,28 +6,34 @@ * - Project ID via setupGeminiUser() * - Fixed endpoint: cloudcode-pa.googleapis.com/v1internal:streamGenerateContent?alt=sse * - Wraps payload in CodeAssist envelope: {model, project, user_prompt_id, request: } - * - GeminiRequestQueue for rate limiting + * - Rate limiting via onResponse override (parses quotaResetDelay from 429) * - gemini-sse stream format (with response wrapper) */ import { randomUUID } from "node:crypto"; -import type { ProviderTransport, StreamFormat } from "./types.js"; -import { GeminiRequestQueue } from "../../handlers/shared/gemini-queue.js"; +import type { StreamFormat } from "./types.js"; +import { OAuthTransport } from "./base.js"; +import type { QueueState } from "../../handlers/shared/request-queue.js"; import { log } from "../../logger.js"; const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com/v1internal:streamGenerateContent?alt=sse"; -export class GeminiCodeAssistProviderTransport implements ProviderTransport { +export class GeminiCodeAssistProviderTransport extends OAuthTransport { readonly name = "gemini-codeassist"; readonly displayName = "Gemini Free"; readonly streamFormat: StreamFormat = "gemini-sse"; + readonly tokenStrategy = "standard" as const; + readonly unwrapResponse = true; private modelName: string; - private accessToken: string | null = null; private projectId: string | null = null; constructor(modelName: string) { + super("gemini-codeassist", { + baseDelayMs: 1000, + maxDelayMs: 10000, + }); this.modelName = modelName; } @@ -35,10 +41,24 @@ export class GeminiCodeAssistProviderTransport implements ProviderTransport { return CODE_ASSIST_ENDPOINT; } - async getHeaders(): Promise> { - return { - Authorization: `Bearer ${this.accessToken}`, - }; + protected override onResponse(response: Response, state: QueueState): boolean { + if (response.status !== 429) return false; + try { + response.clone().text().then(text => { + const data = JSON.parse(text); + const detail = data?.error?.details?.find((d: any) => d.quotaResetDelay); + if (detail?.quotaResetDelay) { + const match = detail.quotaResetDelay.match(/(\d+(?:\.\d+)?)/); + if (match) { + state.currentDelayMs = Math.min( + Math.max(Math.ceil(parseFloat(match[1]) * 1000), state.baseDelayMs), + state.maxDelayMs, + ); + } + } + }).catch(() => {}); + } catch {} + return true; } /** @@ -66,14 +86,6 @@ export class GeminiCodeAssistProviderTransport implements ProviderTransport { request: payload, }; } - - /** - * Rate-limited request via GeminiRequestQueue singleton. - */ - async enqueueRequest(fetchFn: () => Promise): Promise { - const queue = GeminiRequestQueue.getInstance(); - return queue.enqueue(fetchFn); - } } // Backward-compatible alias diff --git a/packages/cli/src/providers/transport/kimi-coding.ts b/packages/cli/src/providers/transport/kimi-coding.ts new file mode 100644 index 00000000..ae585765 --- /dev/null +++ b/packages/cli/src/providers/transport/kimi-coding.ts @@ -0,0 +1,55 @@ +/** + * KimiCodingTransport — Kimi Coding Plan transport. + * + * Extends AnthropicProviderTransport with OAuth fallback: when no API key + * is configured, attempts to load stored OAuth credentials from + * ~/.claudish/kimi-oauth.json and use Bearer token auth instead. + */ + +import { AnthropicProviderTransport } from "./anthropic-compat.js"; +import type { RemoteProvider } from "../../handlers/shared/remote-provider-types.js"; +import { log } from "../../logger.js"; + +export class KimiCodingTransport extends AnthropicProviderTransport { + constructor(provider: RemoteProvider, apiKey: string) { + super(provider, apiKey); + } + + override async getHeaders(): Promise> { + const headers = await super.getHeaders(); + + // If an API key was provided, the base class already set auth headers + if (this.apiKey) { + return headers; + } + + // OAuth fallback: load stored credentials when no API key is set + try { + const { existsSync, readFileSync } = await import("node:fs"); + const { join } = await import("node:path"); + const { homedir } = await import("node:os"); + + const credPath = join(homedir(), ".claudish", "kimi-oauth.json"); + if (existsSync(credPath)) { + const data = JSON.parse(readFileSync(credPath, "utf-8")); + if (data.access_token && data.refresh_token) { + const { KimiOAuth } = await import("../../auth/kimi-oauth.js"); + const oauth = KimiOAuth.getInstance(); + const accessToken = await oauth.getAccessToken(); + + // Replace API key auth with Bearer token + delete headers["x-api-key"]; + headers["Authorization"] = `Bearer ${accessToken}`; + + // Add Kimi-specific platform headers + const platformHeaders = oauth.getPlatformHeaders(); + Object.assign(headers, platformHeaders); + } + } + } catch (e: any) { + log(`[${this.displayName}] OAuth fallback failed: ${e.message}`); + } + + return headers; + } +} diff --git a/packages/cli/src/providers/transport/litellm.ts b/packages/cli/src/providers/transport/litellm.ts index 93a109df..ea04e0fc 100644 --- a/packages/cli/src/providers/transport/litellm.ts +++ b/packages/cli/src/providers/transport/litellm.ts @@ -5,7 +5,8 @@ * LiteLLM uses OpenAI-compatible /v1/chat/completions endpoint. */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; +import { ApiKeyTransport } from "./base.js"; /** * Extra headers that LiteLLM should forward to specific providers. @@ -18,18 +19,15 @@ const MODEL_EXTRA_HEADERS: Array<{ pattern: string; headers: Record> { - const headers: Record = { - Authorization: `Bearer ${this.apiKey}`, - }; - return headers; - } - getExtraPayloadFields(): Record { const fields: Record = {}; diff --git a/packages/cli/src/providers/transport/lmstudio.ts b/packages/cli/src/providers/transport/lmstudio.ts new file mode 100644 index 00000000..aed83770 --- /dev/null +++ b/packages/cli/src/providers/transport/lmstudio.ts @@ -0,0 +1,69 @@ +/** + * LMStudioTransport — transport for local LM Studio instances. + * + * Overrides from LocalTransport: + * - Context window: fetches from /v1/models (context_length, max_context_length, etc.) + * - Error message: LM Studio-specific connection guidance + */ + +import type { LocalProvider as LocalProviderConfig } from "../provider-registry.js"; +import { LocalTransport } from "./local.js"; +import { log } from "../../logger.js"; + +export class LMStudioTransport extends LocalTransport { + constructor(config: LocalProviderConfig, modelName: string, options?: { concurrency?: number }) { + super(config, modelName, options); + } + + protected async fetchContextWindow(): Promise { + if (process.env.CLAUDISH_CONTEXT_WINDOW) return; + + log(`[${this.displayName}] Fetching context window...`); + const config = this.getConfig(); + + try { + const response = await fetch(`${config.baseUrl}/v1/models`, { + method: "GET", + signal: AbortSignal.timeout(3000), + }); + + if (response.ok) { + const data = (await response.json()) as any; + log(`[${this.displayName}] Models response: ${JSON.stringify(data).slice(0, 500)}`); + + const models = data.data || []; + const modelName = this.getModelName(); + const targetModel = + models.find((m: any) => m.id === modelName) || + models.find((m: any) => m.id?.endsWith(`/${modelName}`)) || + models.find((m: any) => modelName.includes(m.id)); + + if (targetModel) { + const ctxLength = + targetModel.context_length || + targetModel.max_context_length || + targetModel.context_window || + targetModel.max_tokens; + if (ctxLength && typeof ctxLength === "number") { + this.setContextWindow(ctxLength); + log(`[${this.displayName}] Context window from model: ${this.getContextWindow()}`); + return; + } + } + + this.setContextWindow(32768); + log(`[${this.displayName}] Using default context window: ${this.getContextWindow()}`); + } + } catch (e: any) { + this.setContextWindow(32768); + log( + `[${this.displayName}] Failed to fetch model info: ${e?.message || e}. Using default: ${this.getContextWindow()}` + ); + } + } + + protected getConnectionErrorMessage(): string { + const config = this.getConfig(); + return `Cannot connect to LM Studio at ${config.baseUrl}. Make sure LM Studio server is running.`; + } +} diff --git a/packages/cli/src/providers/transport/local.ts b/packages/cli/src/providers/transport/local.ts index c48184cf..68c791c9 100644 --- a/packages/cli/src/providers/transport/local.ts +++ b/packages/cli/src/providers/transport/local.ts @@ -1,18 +1,20 @@ /** - * LocalProvider — transport for local OpenAI-compatible providers. + * LocalTransport — base transport for local OpenAI-compatible providers. * - * Supports Ollama, LM Studio, vLLM, MLX, and custom local endpoints. + * Generic base for vLLM, MLX, custom URLs, and any OpenAI-compatible local endpoint. + * Provider-specific behavior lives in subclasses: + * - OllamaTransport: /api/tags health check, /api/show context window, num_ctx injection + * - LMStudioTransport: /v1/models context window detection * * Transport concerns: - * - Health checks (Ollama /api/tags → /v1/models fallback) - * - Context window auto-detection (Ollama /api/show, LM Studio /v1/models) + * - Health checks (/v1/models for generic providers) * - Custom undici agent with 10-minute timeouts for slow local inference * - LocalModelQueue for GPU concurrency control - * - Provider-specific error messages */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; import type { LocalProvider as LocalProviderConfig } from "../../providers/provider-registry.js"; +import { BaseTransport } from "./base.js"; import { LocalModelQueue } from "../../handlers/shared/local-queue.js"; import { log } from "../../logger.js"; import { Agent } from "undici"; @@ -34,19 +36,20 @@ const DISPLAY_NAMES: Record = { custom: "Custom", }; -export class LocalTransport implements ProviderTransport { +export class LocalTransport extends BaseTransport { readonly name: string; readonly displayName: string; readonly streamFormat: StreamFormat = "openai-sse"; + readonly tokenStrategy = "local" as const; - private config: LocalProviderConfig; - private modelName: string; + protected config: LocalProviderConfig; + protected modelName: string; private concurrency?: number; private healthChecked = false; - private isHealthy = false; private _contextWindow = 32768; constructor(config: LocalProviderConfig, modelName: string, options?: { concurrency?: number }) { + super(config.name); this.config = config; this.modelName = modelName; this.name = config.name; @@ -87,12 +90,6 @@ export class LocalTransport implements ProviderTransport { } getExtraPayloadFields(): Record { - // Ollama defaults to 2048 context and silently truncates — set it explicitly - if (this.config.name === "ollama") { - const numCtx = Math.max(this._contextWindow, 32768); - log(`[${this.displayName}] Setting num_ctx: ${numCtx} (detected: ${this._contextWindow})`); - return { options: { num_ctx: numCtx } }; - } return {}; } @@ -113,6 +110,7 @@ export class LocalTransport implements ProviderTransport { throw new Error(this.getConnectionErrorMessage()); } + this.healthChecked = true; await this.fetchContextWindow(); } @@ -125,32 +123,24 @@ export class LocalTransport implements ProviderTransport { return this.config; } - // ─── Health checks ────────────────────────────────────────────────── - - private async checkHealth(): Promise { - if (this.healthChecked) return this.isHealthy; + /** Expose model name for subclass access */ + getModelName(): string { + return this.modelName; + } - // Try Ollama-specific health check first - try { - const healthUrl = `${this.config.baseUrl}/api/tags`; - log(`[${this.displayName}] Trying health check: ${healthUrl}`); - const response = await fetch(healthUrl, { - method: "GET", - signal: AbortSignal.timeout(5000), - }); + /** Allow subclasses to update the context window */ + protected setContextWindow(value: number): void { + this._contextWindow = value; + } - if (response.ok) { - this.isHealthy = true; - this.healthChecked = true; - log(`[${this.displayName}] Health check passed (/api/tags)`); - return true; - } - log(`[${this.displayName}] /api/tags returned ${response.status}, trying /v1/models`); - } catch (e: any) { - log(`[${this.displayName}] /api/tags failed: ${e?.message || e}, trying /v1/models`); - } + // --- Health checks --- - // Try generic OpenAI-compatible health check + /** + * Generic health check: tries /v1/models. + * Subclasses can override to try provider-specific endpoints first, + * then call super.checkHealth() as fallback. + */ + protected async checkHealth(): Promise { try { const modelsUrl = `${this.config.baseUrl}/v1/models`; log(`[${this.displayName}] Trying health check: ${modelsUrl}`); @@ -159,8 +149,6 @@ export class LocalTransport implements ProviderTransport { signal: AbortSignal.timeout(5000), }); if (response.ok) { - this.isHealthy = true; - this.healthChecked = true; log(`[${this.displayName}] Health check passed (/v1/models)`); return true; } @@ -169,119 +157,29 @@ export class LocalTransport implements ProviderTransport { log(`[${this.displayName}] /v1/models failed: ${e?.message || e}`); } - this.healthChecked = true; - this.isHealthy = false; log(`[${this.displayName}] Health check FAILED - provider not available`); return false; } - // ─── Context window auto-detection ────────────────────────────────── + // --- Context window auto-detection --- - private async fetchContextWindow(): Promise { - // Skip if env var already set + /** + * Fetch context window from the provider API. + * Base implementation is a no-op (uses default). Subclasses override for + * provider-specific detection (Ollama /api/show, LM Studio /v1/models). + */ + protected async fetchContextWindow(): Promise { if (process.env.CLAUDISH_CONTEXT_WINDOW) return; - log(`[${this.displayName}] Fetching context window...`); - if (this.config.name === "ollama") { - await this.fetchOllamaContextWindow(); - } else if (this.config.name === "lmstudio") { - await this.fetchLMStudioContextWindow(); - } else { - log( - `[${this.displayName}] No context window fetch for this provider, using default: ${this._contextWindow}` - ); - } - } - - private async fetchOllamaContextWindow(): Promise { - try { - const response = await fetch(`${this.config.baseUrl}/api/show`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ name: this.modelName }), - signal: AbortSignal.timeout(3000), - }); - - if (response.ok) { - const data = (await response.json()) as any; - let ctxFromInfo = data.model_info?.["general.context_length"]; - - // Search for {arch}.context_length if not found at general.context_length - if (!ctxFromInfo && data.model_info) { - for (const key of Object.keys(data.model_info)) { - if (key.endsWith(".context_length")) { - ctxFromInfo = data.model_info[key]; - break; - } - } - } - - const ctxFromParams = data.parameters?.match(/num_ctx\s+(\d+)/)?.[1]; - if (ctxFromInfo) { - this._contextWindow = parseInt(String(ctxFromInfo), 10); - } else if (ctxFromParams) { - this._contextWindow = parseInt(ctxFromParams, 10); - } else { - log(`[${this.displayName}] No context info found, using default: ${this._contextWindow}`); - } - if (ctxFromInfo || ctxFromParams) { - log(`[${this.displayName}] Context window: ${this._contextWindow}`); - } - } - } catch { - // Use default context window - } - } - - private async fetchLMStudioContextWindow(): Promise { - try { - const response = await fetch(`${this.config.baseUrl}/v1/models`, { - method: "GET", - signal: AbortSignal.timeout(3000), - }); - - if (response.ok) { - const data = (await response.json()) as any; - log(`[${this.displayName}] Models response: ${JSON.stringify(data).slice(0, 500)}`); - - const models = data.data || []; - const targetModel = - models.find((m: any) => m.id === this.modelName) || - models.find((m: any) => m.id?.endsWith(`/${this.modelName}`)) || - models.find((m: any) => this.modelName.includes(m.id)); - - if (targetModel) { - const ctxLength = - targetModel.context_length || - targetModel.max_context_length || - targetModel.context_window || - targetModel.max_tokens; - if (ctxLength && typeof ctxLength === "number") { - this._contextWindow = ctxLength; - log(`[${this.displayName}] Context window from model: ${this._contextWindow}`); - return; - } - } - - this._contextWindow = 32768; - log(`[${this.displayName}] Using default context window: ${this._contextWindow}`); - } - } catch (e: any) { - this._contextWindow = 32768; - log( - `[${this.displayName}] Failed to fetch model info: ${e?.message || e}. Using default: ${this._contextWindow}` - ); - } + log( + `[${this.displayName}] No context window fetch for this provider, using default: ${this._contextWindow}` + ); } - // ─── Error messages ───────────────────────────────────────────────── + // --- Error messages --- - private getConnectionErrorMessage(): string { + protected getConnectionErrorMessage(): string { switch (this.config.name) { - case "ollama": - return `Cannot connect to Ollama at ${this.config.baseUrl}. Make sure Ollama is running with: ollama serve`; - case "lmstudio": - return `Cannot connect to LM Studio at ${this.config.baseUrl}. Make sure LM Studio server is running.`; case "vllm": return `Cannot connect to vLLM at ${this.config.baseUrl}. Make sure vLLM server is running.`; default: diff --git a/packages/cli/src/providers/transport/ollama.ts b/packages/cli/src/providers/transport/ollama.ts new file mode 100644 index 00000000..8e693a8a --- /dev/null +++ b/packages/cli/src/providers/transport/ollama.ts @@ -0,0 +1,101 @@ +/** + * OllamaTransport — transport for local Ollama instances. + * + * Overrides from LocalTransport: + * - Health check: tries /api/tags first (Ollama-native), then /v1/models fallback + * - Context window: fetches from /api/show (model_info, parameters) + * - Extra payload: injects num_ctx to prevent Ollama's silent 2048 truncation + * - Error message: Ollama-specific connection guidance + */ + +import type { LocalProvider as LocalProviderConfig } from "../provider-registry.js"; +import { LocalTransport } from "./local.js"; +import { log } from "../../logger.js"; + +export class OllamaTransport extends LocalTransport { + constructor(config: LocalProviderConfig, modelName: string, options?: { concurrency?: number }) { + super(config, modelName, options); + } + + getExtraPayloadFields(): Record { + const ctxWindow = this.getContextWindow(); + const numCtx = Math.max(ctxWindow, 32768); + log(`[${this.displayName}] Setting num_ctx: ${numCtx} (detected: ${ctxWindow})`); + return { options: { num_ctx: numCtx } }; + } + + protected async checkHealth(): Promise { + const config = this.getConfig(); + + // Try Ollama-native /api/tags first + try { + const healthUrl = `${config.baseUrl}/api/tags`; + log(`[${this.displayName}] Trying health check: ${healthUrl}`); + const response = await fetch(healthUrl, { + method: "GET", + signal: AbortSignal.timeout(5000), + }); + + if (response.ok) { + log(`[${this.displayName}] Health check passed (/api/tags)`); + return true; + } + log(`[${this.displayName}] /api/tags returned ${response.status}, trying /v1/models`); + } catch (e: any) { + log(`[${this.displayName}] /api/tags failed: ${e?.message || e}, trying /v1/models`); + } + + // Fall back to generic OpenAI-compatible check + return super.checkHealth(); + } + + protected async fetchContextWindow(): Promise { + if (process.env.CLAUDISH_CONTEXT_WINDOW) return; + + log(`[${this.displayName}] Fetching context window...`); + const config = this.getConfig(); + + try { + const response = await fetch(`${config.baseUrl}/api/show`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ name: this.getModelName() }), + signal: AbortSignal.timeout(3000), + }); + + if (response.ok) { + const data = (await response.json()) as any; + let ctxFromInfo = data.model_info?.["general.context_length"]; + + // Search for {arch}.context_length if not found at general.context_length + if (!ctxFromInfo && data.model_info) { + for (const key of Object.keys(data.model_info)) { + if (key.endsWith(".context_length")) { + ctxFromInfo = data.model_info[key]; + break; + } + } + } + + const ctxFromParams = data.parameters?.match(/num_ctx\s+(\d+)/)?.[1]; + if (ctxFromInfo) { + this.setContextWindow(parseInt(String(ctxFromInfo), 10)); + } else if (ctxFromParams) { + this.setContextWindow(parseInt(ctxFromParams, 10)); + } else { + log(`[${this.displayName}] No context info found, using default: ${this.getContextWindow()}`); + } + if (ctxFromInfo || ctxFromParams) { + log(`[${this.displayName}] Context window: ${this.getContextWindow()}`); + } + } + } catch { + // Use default context window + } + } + + protected getConnectionErrorMessage(): string { + const config = this.getConfig(); + return `Cannot connect to Ollama at ${config.baseUrl}. Make sure Ollama is running with: ollama serve`; + } +} diff --git a/packages/cli/src/providers/transport/ollamacloud.ts b/packages/cli/src/providers/transport/ollamacloud.ts index 831182a5..253cca67 100644 --- a/packages/cli/src/providers/transport/ollamacloud.ts +++ b/packages/cli/src/providers/transport/ollamacloud.ts @@ -5,27 +5,22 @@ * Uses Bearer token auth and Ollama's native JSONL streaming format. */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; import type { RemoteProvider } from "../../handlers/shared/remote-provider-types.js"; +import { ApiKeyTransport } from "./base.js"; -export class OllamaProviderTransport implements ProviderTransport { +export class OllamaProviderTransport extends ApiKeyTransport { readonly name = "ollamacloud"; readonly displayName = "OllamaCloud"; readonly streamFormat: StreamFormat = "ollama-jsonl"; - - private provider: RemoteProvider; - private apiKey: string; + readonly tokenStrategy = "accumulate-both" as const; constructor(provider: RemoteProvider, apiKey: string) { - this.provider = provider; - this.apiKey = apiKey; - } - - getEndpoint(): string { - return `${this.provider.baseUrl}${this.provider.apiPath}`; + super("ollamacloud", provider.baseUrl, provider.apiPath, apiKey); } async getHeaders(): Promise> { + // Only include Authorization if key is present (matches original behavior) const headers: Record = {}; if (this.apiKey) { headers["Authorization"] = `Bearer ${this.apiKey}`; diff --git a/packages/cli/src/providers/transport/openai.ts b/packages/cli/src/providers/transport/openai.ts index 47d2cabb..ba66723b 100644 --- a/packages/cli/src/providers/transport/openai.ts +++ b/packages/cli/src/providers/transport/openai.ts @@ -3,23 +3,26 @@ * * Handles communication with OpenAI's API (and OpenAI-compatible providers * like GLM, Zen). Supports both Chat Completions and Codex Responses API. - * Includes 30-second timeout with detailed error reporting. + * Includes custom 429 retry logic with exponential backoff (up to 5 retries). */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; import type { RemoteProvider } from "../../handlers/shared/remote-provider-types.js"; +import { BaseTransport } from "./base.js"; import { log } from "../../logger.js"; -export class OpenAIProviderTransport implements ProviderTransport { +export class OpenAIProviderTransport extends BaseTransport { readonly name: string; readonly displayName: string; readonly streamFormat: StreamFormat; + readonly tokenStrategy = "delta-aware" as const; private provider: RemoteProvider; - private apiKey: string; + protected apiKey: string; private modelName: string; constructor(provider: RemoteProvider, modelName: string, apiKey: string) { + super(provider.name); this.provider = provider; this.modelName = modelName; this.apiKey = apiKey; @@ -48,8 +51,8 @@ export class OpenAIProviderTransport implements ProviderTransport { } /** - * Override fetch with 30-second timeout, 429 retry with exponential backoff, - * and detailed error handling. + * Override fetch with 429 retry with exponential backoff and detailed error handling. + * Bypasses the base RequestQueue to implement its own retry logic. */ async enqueueRequest(fetchFn: () => Promise): Promise { const maxRetries = 5; diff --git a/packages/cli/src/providers/transport/openrouter.ts b/packages/cli/src/providers/transport/openrouter.ts index 424f270b..96f6c21c 100644 --- a/packages/cli/src/providers/transport/openrouter.ts +++ b/packages/cli/src/providers/transport/openrouter.ts @@ -4,36 +4,40 @@ * Transport concerns: * - Bearer token auth * - OpenRouter-specific headers (HTTP-Referer, X-Title) - * - OpenRouterRequestQueue for rate limiting + * - Rate limiting via onResponse/calculateDelay overrides (X-RateLimit headers) * - openai-sse stream format * - Context window lookup from cached OpenRouter model catalog */ -import type { ProviderTransport, StreamFormat } from "./types.js"; -import { OpenRouterRequestQueue } from "../../handlers/shared/openrouter-queue.js"; +import type { StreamFormat } from "./types.js"; +import { BaseTransport } from "./base.js"; +import type { QueueState } from "../../handlers/shared/request-queue.js"; import { getCachedOpenRouterModels } from "../../model-loader.js"; const OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"; -export class OpenRouterProviderTransport implements ProviderTransport { +export class OpenRouterProviderTransport extends BaseTransport { readonly name = "openrouter"; readonly displayName = "OpenRouter"; readonly streamFormat: StreamFormat = "openai-sse"; private apiKey: string; - private queue: OpenRouterRequestQueue; private modelId: string; + // Rate limit state from X-RateLimit headers + private limitRequests: number | null = null; + private remainingRequests: number | null = null; + private resetTime: number | null = null; + constructor(apiKey: string, modelId?: string) { + super("openrouter", { + baseDelayMs: 1000, + maxDelayMs: 10000, + }); this.apiKey = apiKey; this.modelId = modelId || ""; - this.queue = OpenRouterRequestQueue.getInstance(); } - /** - * OpenRouter normalizes all responses to OpenAI SSE format server-side, - * regardless of the underlying model (even if the adapter declares anthropic-sse). - */ overrideStreamFormat(): StreamFormat { return "openai-sse"; } @@ -50,19 +54,46 @@ export class OpenRouterProviderTransport implements ProviderTransport { }; } - async enqueueRequest(fetchFn: () => Promise): Promise { - return this.queue.enqueue(fetchFn); - } - - /** - * Look up context window from the cached OpenRouter model catalog. - * The catalog is pre-warmed at startup via warmAllCatalogs(). - */ getContextWindow(): number { const models = this.modelId ? getCachedOpenRouterModels() : null; const model = models?.find((m: any) => m.id === this.modelId); return model?.context_length || model?.top_provider?.context_length || 200_000; } + + protected override onResponse(response: Response, state: QueueState): boolean { + const lr = response.headers.get("X-RateLimit-Limit-Requests"); + if (lr) this.limitRequests = parseInt(lr, 10); + const rr = response.headers.get("X-RateLimit-Remaining-Requests"); + if (rr) this.remainingRequests = parseInt(rr, 10); + const r = response.headers.get("X-RateLimit-Reset-Requests"); + if (r) this.resetTime = parseFloat(r); + + if (response.status === 429) { + this.remainingRequests = 0; + const ra = response.headers.get("Retry-After"); + if (ra) { + const s = parseInt(ra, 10); + if (!isNaN(s)) state.currentDelayMs = Math.max(state.currentDelayMs, Math.min(s * 1000, state.maxDelayMs)); + } + return true; + } + return false; + } + + protected override calculateDelay(state: QueueState): number { + let delayMs = state.baseDelayMs; + if (this.remainingRequests !== null && this.limitRequests !== null && this.limitRequests > 0) { + const pct = this.remainingRequests / this.limitRequests; + if (pct < 0.2) delayMs = Math.max(delayMs, 3000); + else if (pct < 0.5) delayMs = Math.max(delayMs, 2000); + } + if (this.resetTime !== null && this.remainingRequests !== null && this.remainingRequests > 0) { + const until = this.resetTime - Date.now() / 1000; + if (until > 0) delayMs = Math.max(delayMs, Math.min((until * 1000) / this.remainingRequests, state.maxDelayMs)); + } + if (state.consecutiveErrors > 0) delayMs *= 1 + state.consecutiveErrors * 0.5; + return Math.min(delayMs, state.maxDelayMs); + } } // Backward-compatible alias diff --git a/packages/cli/src/providers/transport/poe.ts b/packages/cli/src/providers/transport/poe.ts index c77caec3..0fec0d52 100644 --- a/packages/cli/src/providers/transport/poe.ts +++ b/packages/cli/src/providers/transport/poe.ts @@ -7,28 +7,18 @@ * - Standard OpenAI SSE format */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; +import { ApiKeyTransport } from "./base.js"; -const POE_API_URL = "https://api.poe.com/v1/chat/completions"; +const POE_API_URL = "https://api.poe.com"; +const POE_API_PATH = "/v1/chat/completions"; -export class PoeProvider implements ProviderTransport { +export class PoeProvider extends ApiKeyTransport { readonly name = "poe"; readonly displayName = "Poe"; readonly streamFormat: StreamFormat = "openai-sse"; - private apiKey: string; - constructor(apiKey: string) { - this.apiKey = apiKey; - } - - getEndpoint(): string { - return POE_API_URL; - } - - async getHeaders(): Promise> { - return { - Authorization: `Bearer ${this.apiKey}`, - }; + super("poe", POE_API_URL, POE_API_PATH, apiKey); } } diff --git a/packages/cli/src/providers/transport/types.ts b/packages/cli/src/providers/transport/types.ts index eb24058a..4a7bbd3a 100644 --- a/packages/cli/src/providers/transport/types.ts +++ b/packages/cli/src/providers/transport/types.ts @@ -30,9 +30,20 @@ export interface ProviderTransport { /** Which stream parser to use for this provider's responses */ readonly streamFormat: StreamFormat; + /** Token counting strategy for this transport */ + readonly tokenStrategy?: "standard" | "accumulate-both" | "delta-aware" | "actual-cost" | "local"; + /** Get the full API endpoint URL for a request */ getEndpoint(model?: string): string; + /** Get the non-streaming endpoint URL, if different from streaming. + * Used by Gemini (:generateContent vs :streamGenerateContent). */ + getNonStreamingEndpoint?(model?: string): string; + + /** Whether the response body is wrapped in an envelope that must be unwrapped. + * Used by CodeAssist ({response: {...}}) and similar. */ + unwrapResponse?: boolean; + /** Get HTTP headers (may be async for OAuth token refresh) */ getHeaders(): Promise>; diff --git a/packages/cli/src/providers/transport/vertex-oauth.ts b/packages/cli/src/providers/transport/vertex-oauth.ts index 34b566ec..9c127086 100644 --- a/packages/cli/src/providers/transport/vertex-oauth.ts +++ b/packages/cli/src/providers/transport/vertex-oauth.ts @@ -12,7 +12,8 @@ * - 30s request timeout */ -import type { ProviderTransport, StreamFormat } from "./types.js"; +import type { StreamFormat } from "./types.js"; +import { OAuthTransport } from "./base.js"; import { getVertexAuthManager, buildVertexOAuthEndpoint, @@ -27,8 +28,8 @@ export interface ParsedVertexModel { /** * Parse vertex model string into publisher and model. - * "gemini-2.5-flash" → { publisher: "google", model: "gemini-2.5-flash" } - * "anthropic/claude-3-5-sonnet" → { publisher: "anthropic", model: "claude-3-5-sonnet" } + * "gemini-2.5-flash" -> { publisher: "google", model: "gemini-2.5-flash" } + * "anthropic/claude-3-5-sonnet" -> { publisher: "anthropic", model: "claude-3-5-sonnet" } */ export function parseVertexModel(modelId: string): ParsedVertexModel { const parts = modelId.split("/"); @@ -38,16 +39,16 @@ export function parseVertexModel(modelId: string): ParsedVertexModel { return { publisher: parts[0], model: parts.slice(1).join("/") }; } -export class VertexProviderTransport implements ProviderTransport { +export class VertexProviderTransport extends OAuthTransport { readonly name = "vertex"; readonly displayName = "Vertex AI"; readonly streamFormat: StreamFormat; private config: VertexConfig; private parsed: ParsedVertexModel; - private accessToken?: string; constructor(config: VertexConfig, parsed: ParsedVertexModel) { + super("vertex"); this.config = config; this.parsed = parsed; @@ -70,12 +71,6 @@ export class VertexProviderTransport implements ProviderTransport { ); } - async getHeaders(): Promise> { - return { - Authorization: `Bearer ${this.accessToken}`, - }; - } - getRequestInit(): Record { return { signal: AbortSignal.timeout(30000), // 30s timeout for Vertex diff --git a/packages/cli/src/providers/transport/zen.ts b/packages/cli/src/providers/transport/zen.ts new file mode 100644 index 00000000..19956945 --- /dev/null +++ b/packages/cli/src/providers/transport/zen.ts @@ -0,0 +1,27 @@ +/** + * ZenTransport — OpenCode Zen transport. + * Extends OpenAIProviderTransport. Defaults API key to "public" for free tier. + * GPT models on Zen use the Responses API endpoint (/v1/responses). + */ + +import { OpenAIProviderTransport } from "./openai.js"; +import type { RemoteProvider } from "../../handlers/shared/remote-provider-types.js"; + +export class ZenTransport extends OpenAIProviderTransport { + private isGptModel: boolean; + private baseUrl: string; + + constructor(provider: RemoteProvider, modelName: string, apiKey: string) { + // Default to "public" key for free zen tier + super(provider, modelName, apiKey || "public"); + this.isGptModel = modelName.toLowerCase().startsWith("gpt-"); + this.baseUrl = provider.baseUrl; + } + + override getEndpoint(): string { + if (this.isGptModel) { + return `${this.baseUrl}/v1/responses`; + } + return super.getEndpoint(); + } +} diff --git a/packages/cli/src/proxy-server.ts b/packages/cli/src/proxy-server.ts index a35530ea..7eae9dce 100644 --- a/packages/cli/src/proxy-server.ts +++ b/packages/cli/src/proxy-server.ts @@ -1,41 +1,13 @@ import { Hono } from "hono"; import { cors } from "hono/cors"; import { serve } from "@hono/node-server"; -import { log, logStderr, isLoggingEnabled } from "./logger.js"; +import { log, logStderr } from "./logger.js"; import type { ProxyServer } from "./types.js"; import { NativeHandler } from "./handlers/native-handler.js"; -import { OpenRouterProviderTransport } from "./providers/transport/openrouter.js"; -import { OpenRouterAPIFormat } from "./adapters/openrouter-api-format.js"; -import { LocalTransport } from "./providers/transport/local.js"; -import { LocalModelAdapter } from "./adapters/local-adapter.js"; -import { GeminiProviderTransport } from "./providers/transport/gemini-apikey.js"; -import { GeminiCodeAssistProviderTransport } from "./providers/transport/gemini-codeassist.js"; -import { GeminiAPIFormat } from "./adapters/gemini-api-format.js"; -import { VertexProviderTransport, parseVertexModel } from "./providers/transport/vertex-oauth.js"; -import { DefaultAPIFormat } from "./adapters/base-api-format.js"; -import { PoeProvider } from "./providers/transport/poe.js"; import type { ModelHandler } from "./handlers/types.js"; -import { ComposedHandler, type ComposedHandlerOptions } from "./handlers/composed-handler.js"; -import { LiteLLMProviderTransport } from "./providers/transport/litellm.js"; -import { LiteLLMAPIFormat } from "./adapters/litellm-api-format.js"; -import { OpenAIProviderTransport } from "./providers/transport/openai.js"; -import { OpenAIAPIFormat } from "./adapters/openai-api-format.js"; -import { AnthropicProviderTransport } from "./providers/transport/anthropic-compat.js"; -import { AnthropicAPIFormat } from "./adapters/anthropic-api-format.js"; -import { OllamaProviderTransport } from "./providers/transport/ollamacloud.js"; -import { OllamaAPIFormat } from "./adapters/ollama-api-format.js"; -import { - resolveProvider, - parseUrlModel, - createUrlProvider, -} from "./providers/provider-registry.js"; +import type { ComposedHandlerOptions } from "./handlers/composed-handler.js"; import { parseModelSpec } from "./providers/model-parser.js"; -import { - resolveRemoteProvider, - validateRemoteProviderApiKey, - getRegisteredRemoteProviders, -} from "./providers/remote-provider-registry.js"; -import { getVertexConfig, validateVertexOAuthConfig } from "./auth/vertex-auth.js"; +import { getProviderDefinitionByRemoteName } from "./providers/provider-definitions.js"; import { resolveModelProvider } from "./providers/provider-resolver.js"; import { warmPricingCache } from "./services/pricing-cache.js"; import { fetchLiteLLMModels } from "./model-loader.js"; @@ -69,138 +41,21 @@ export async function createProxyServer( modelMap?: { opus?: string; sonnet?: string; haiku?: string; subagent?: string }, options: ProxyServerOptions = {} ): Promise { - // Define handlers for different roles + // Define handlers const nativeHandler = new NativeHandler(anthropicApiKey); - const openRouterHandlers = new Map(); // Map from Target Model ID -> OpenRouter Handler - const localProviderHandlers = new Map(); // Map from Target Model ID -> Local Provider Handler - const remoteProviderHandlers = new Map(); // Map from Target Model ID -> Gemini/OpenAI Handler - const poeHandlers = new Map(); // Map from Target Model ID -> Poe Handler - - // Helper to get or create OpenRouter handler for a target model - const getOpenRouterHandler = ( - targetModel: string, - invocationMode?: ComposedHandlerOptions["invocationMode"] - ): ModelHandler => { - // For explicit @ syntax: strip provider prefix (openrouter@google/gemini → google/gemini) - // For already-resolved vendor/model IDs (qwen/qwen3.5-plus-02-15): use as-is to preserve - // the vendor prefix that OpenRouter requires. parseModelSpec() would otherwise strip it - // (e.g. "qwen/" is a native pattern match → model becomes "qwen3.5-plus-02-15"). - const parsed = parseModelSpec(targetModel); - const modelId = targetModel.includes("@") ? parsed.model : targetModel; - - if (!openRouterHandlers.has(modelId)) { - const orProvider = new OpenRouterProviderTransport(openrouterApiKey || "", modelId); - const orAdapter = new OpenRouterAPIFormat(modelId); - openRouterHandlers.set( - modelId, - new ComposedHandler(orProvider, modelId, modelId, port, { - adapter: orAdapter, - isInteractive: options.isInteractive, - invocationMode, - }) - ); - } - return openRouterHandlers.get(modelId)!; - }; - - // Helper to get or create Poe handler for a target model - const getPoeHandler = ( - targetModel: string, - invocationMode?: ComposedHandlerOptions["invocationMode"] - ): ModelHandler | null => { - const poeApiKey = process.env.POE_API_KEY; - if (!poeApiKey) { - log(`[Proxy] POE_API_KEY not set, cannot use Poe model: ${targetModel}`); - return null; - } - // Strip "poe:" prefix to get the actual model name for the API - const modelId = targetModel.replace(/^poe:/, ""); - if (!poeHandlers.has(modelId)) { - const poeTransport = new PoeProvider(poeApiKey); - poeHandlers.set( - modelId, - new ComposedHandler(poeTransport, modelId, modelId, port, { - isInteractive: options.isInteractive, - invocationMode, - }) - ); - } - return poeHandlers.get(modelId)!; - }; - - // Check if model is a Poe model (has poe: prefix) - const isPoeModel = (model: string): boolean => { - return model.startsWith("poe:"); - }; - - // Helper to get or create Local Provider handler for a target model - const getLocalProviderHandler = ( - targetModel: string, - invocationMode?: ComposedHandlerOptions["invocationMode"] - ): ModelHandler | null => { - if (localProviderHandlers.has(targetModel)) { - return localProviderHandlers.get(targetModel)!; - } - - // Check for prefix-based local provider (ollama/, lmstudio/, etc.) - const resolved = resolveProvider(targetModel); - if (resolved) { - const provider = new LocalTransport(resolved.provider, resolved.modelName, { - concurrency: resolved.concurrency, - }); - const adapter = new LocalModelAdapter(resolved.modelName, resolved.provider.name); - const handler = new ComposedHandler(provider, resolved.modelName, resolved.modelName, port, { - adapter, - tokenStrategy: "local", - summarizeTools: options.summarizeTools, - isInteractive: options.isInteractive, - invocationMode, - }); - localProviderHandlers.set(targetModel, handler); - log( - `[Proxy] Created local provider handler: ${resolved.provider.name}/${resolved.modelName}${resolved.concurrency !== undefined ? ` (concurrency: ${resolved.concurrency})` : ""}` - ); - return handler; - } - - // Check for URL-based model (http://localhost:11434/llama3) - const urlParsed = parseUrlModel(targetModel); - if (urlParsed) { - const providerConfig = createUrlProvider(urlParsed); - const provider = new LocalTransport(providerConfig, urlParsed.modelName); - const adapter = new LocalModelAdapter(urlParsed.modelName, providerConfig.name); - const handler = new ComposedHandler( - provider, - urlParsed.modelName, - urlParsed.modelName, - port, - { - adapter, - tokenStrategy: "local", - summarizeTools: options.summarizeTools, - isInteractive: options.isInteractive, - invocationMode, - } - ); - localProviderHandlers.set(targetModel, handler); - log( - `[Proxy] Created URL-based local provider handler: ${urlParsed.baseUrl}/${urlParsed.modelName}` - ); - return handler; - } - - return null; - }; - - // Helper to get or create remote provider handler (Gemini, OpenAI) - // TODO: Consolidate src/ and packages/core/src/ - they're manually synced duplicates - const getRemoteProviderHandler = ( + const handlerCache = new Map(); + + /** + * Get or create a handler for a target model. Single entry point for all providers. + * + * Delegates routing to resolveModelProvider, construction to createHandlerForProvider. + * Handles auto-routing recursion (bare model -> provider@model) and caching. + */ + const getHandler = ( targetModel: string, invocationMode?: ComposedHandlerOptions["invocationMode"] ): ModelHandler | null => { - if (remoteProviderHandlers.has(targetModel)) { - return remoteProviderHandlers.get(targetModel)!; - } + if (handlerCache.has(targetModel)) return handlerCache.get(targetModel)!; // Use centralized resolver with fallback logic const resolution = resolveModelProvider(targetModel); @@ -212,63 +67,125 @@ export async function createProxyServer( log(`[Auto-route] ${resolution.autoRouteMessage}`); } - // If resolver says use OpenRouter (including fallback cases), create the handler - // directly here so we can use the correctly-formatted fullModelId (e.g. "google/gemini-2.0-flash") - // rather than the raw targetModel string. - if (resolution.category === "openrouter") { - if (resolution.wasAutoRouted && resolution.fullModelId) { - return getOpenRouterHandler(resolution.fullModelId); + // Auto-routed to a different target (e.g. bare model -> openrouter vendor/model or litellm@model). + // Recurse with the resolved fullModelId so the correct provider handler is created. + if (resolution.wasAutoRouted && resolution.fullModelId && resolution.fullModelId !== targetModel) { + const handler = getHandler(resolution.fullModelId, invocationMode); + if (handler) { + // Cache under both the original and resolved target + handlerCache.set(targetModel, handler); } - return null; + return handler; } - // When auto-routed (e.g. to LiteLLM), use the resolved fullModelId so that - // resolveRemoteProvider() receives "litellm@gemini-2.0-flash" instead of the - // original bare model name which would match the wrong (native) provider. - const resolveTarget = - resolution.wasAutoRouted && resolution.fullModelId ? resolution.fullModelId : targetModel; + // OpenRouter category from auto-routing: the fullModelId has the vendor prefix + // (e.g. "google/gemini-2.0-flash"). Create via openrouter definition. + if (resolution.category === "openrouter") { + const orModelId = resolution.fullModelId || targetModel; + return createAndCache( + "openrouter", orModelId, targetModel, invocationMode, + ); + } - // If resolver says use direct-api and key is available, create handler + // Direct-api: look up the ProviderDefinition and create a handler if (resolution.category === "direct-api" && resolution.apiKeyAvailable) { - const resolved = resolveRemoteProvider(resolveTarget); - if (!resolved) return null; - - // Skip 'openrouter' provider here - it uses the existing OpenRouterHandler - if (resolved.provider.name === "openrouter") { - return null; // Will fall through to OpenRouterHandler - } - - // Get API key - empty string for providers that don't require auth (like zen/ free models) - const apiKey = resolved.provider.apiKeyEnvVar - ? process.env[resolved.provider.apiKeyEnvVar] || "" + // When auto-routed, use fullModelId for parseModelSpec so it sees + // "litellm@gemini-2.0-flash" instead of bare "gemini-2.0-flash" + const resolveTarget = + resolution.wasAutoRouted && resolution.fullModelId ? resolution.fullModelId : targetModel; + + const parsedTarget = parseModelSpec(resolveTarget); + // Map google -> gemini for RemoteProvider name compat + const remoteName = parsedTarget.provider === "google" ? "gemini" : parsedTarget.provider; + const def = getProviderDefinitionByRemoteName(remoteName); + if (!def) return null; + + // Get API key (empty for providers without auth requirement, like zen free models) + const apiKey = def.apiKeyEnvVar + ? process.env[def.apiKeyEnvVar] || "" : ""; - const handler = createHandlerForProvider({ - provider: resolved.provider, - modelName: resolved.modelName, + const handler = createHandlerForProvider( + def, + parsedTarget.model, apiKey, targetModel, port, - sharedOpts: { isInteractive: options.isInteractive, invocationMode }, - }); - if (!handler) { - return null; // Profile returned null (missing config) or unknown provider - } + { isInteractive: options.isInteractive, invocationMode, summarizeTools: options.summarizeTools }, + ); + if (!handler) return null; - // Cache under both the original targetModel and the resolveTarget (if different) - // so subsequent lookups with either key are served from cache. - remoteProviderHandlers.set(resolveTarget, handler); + // Cache under both keys so subsequent lookups hit the cache + handlerCache.set(targetModel, handler); if (resolveTarget !== targetModel) { - remoteProviderHandlers.set(targetModel, handler); + handlerCache.set(resolveTarget, handler); } return handler; } - // If we get here, either category is not direct-api or key is not available - // Both cases should fall through to OpenRouter or return null + // Local and Poe: resolve via definition + if (resolution.category === "local") { + const def = getProviderDefinitionByRemoteName(resolution.parsed?.provider || ""); + if (!def) return null; + + const handler = createHandlerForProvider( + def, + resolution.modelName, + "", + targetModel, + port, + { + isInteractive: options.isInteractive, + invocationMode, + summarizeTools: options.summarizeTools, + concurrency: resolution.concurrency, + }, + ); + if (handler) handlerCache.set(targetModel, handler); + return handler; + } + return null; }; + /** + * Create a handler for a named provider and cache it. + * Used for OpenRouter (and any future provider that needs creation by name). + */ + const createAndCache = ( + providerName: string, + modelId: string, + cacheKey: string, + invocationMode?: ComposedHandlerOptions["invocationMode"], + ): ModelHandler | null => { + // For OpenRouter: strip @ prefix if present, preserve vendor/ prefix + const parsed = parseModelSpec(modelId); + const effectiveModelId = modelId.includes("@") ? parsed.model : modelId; + + const def = getProviderDefinitionByRemoteName(providerName); + if (!def) return null; + + const apiKey = def.apiKeyEnvVar + ? process.env[def.apiKeyEnvVar] || openrouterApiKey || "" + : ""; + + const handler = createHandlerForProvider( + def, + effectiveModelId, + apiKey, + effectiveModelId, + port, + { isInteractive: options.isInteractive, invocationMode }, + ); + if (handler) { + handlerCache.set(cacheKey, handler); + if (cacheKey !== effectiveModelId) { + handlerCache.set(effectiveModelId, handler); + } + } + return handler; + }; + // Pre-warm LiteLLM model cache for auto-routing (non-blocking) if (process.env.LITELLM_BASE_URL && process.env.LITELLM_API_KEY) { fetchLiteLLMModels(process.env.LITELLM_BASE_URL, process.env.LITELLM_API_KEY) @@ -347,12 +264,14 @@ export async function createProxyServer( const invocationMode = detectInvocationMode(target, wasFromModelMap); - // 2b. Catalog resolution — resolve vendor prefix for OpenRouter and LiteLLM + // 2b. Catalog resolution — resolve vendor prefix for aggregator providers + // (e.g., OpenRouter, LiteLLM) that use vendor-prefixed model names. // This must happen after target is determined but before handler construction. // resolveModelNameSync is synchronous (uses in-memory cache + readFileSync). { const parsedTarget = parseModelSpec(target); - if (parsedTarget.provider === "openrouter" || parsedTarget.provider === "litellm") { + const targetDef = getProviderDefinitionByRemoteName(parsedTarget.provider); + if (targetDef?.preserveVendorPrefix) { const resolution = resolveModelNameSync(parsedTarget.model, parsedTarget.provider); logResolution(parsedTarget.model, resolution, options.quiet); if (resolution.wasResolved) { @@ -371,7 +290,7 @@ export async function createProxyServer( if ( !parsedForFallback.isExplicitProvider && parsedForFallback.provider !== "native-anthropic" && - !isPoeModel(target) + !target.startsWith("poe:") ) { const cacheKey = `fallback:${target}`; if (fallbackHandlerCache.has(cacheKey)) { @@ -387,12 +306,7 @@ export async function createProxyServer( if (chain.length > 0) { const candidates: FallbackCandidate[] = []; for (const route of chain) { - let handler: ModelHandler | null = null; - if (route.provider === "openrouter") { - handler = getOpenRouterHandler(route.modelSpec, invocationMode); - } else { - handler = getRemoteProviderHandler(route.modelSpec, invocationMode); - } + const handler = getHandler(route.modelSpec, invocationMode); if (handler) { candidates.push({ name: route.displayName, handler }); } @@ -416,37 +330,17 @@ export async function createProxyServer( } } - // 3. Check for Poe Model (poe: prefix) - if (isPoeModel(target)) { - const poeHandler = getPoeHandler(target, invocationMode); - if (poeHandler) { - log(`[Proxy] Routing to Poe: ${target}`); - return poeHandler; - } - } - - // 4. Check for Remote Provider (g/, gemini/, oai/, openai/, mmax/, mm/, kimi/, moonshot/, glm/, zhipu/) - const remoteHandler = getRemoteProviderHandler(target, invocationMode); - if (remoteHandler) return remoteHandler; - - // 5. Check for Local Provider (ollama/, lmstudio/, vllm/, or URL) - const localHandler = getLocalProviderHandler(target, invocationMode); - if (localHandler) return localHandler; + // 3. Try all providers via getHandler (Poe, remote, local, OpenRouter all go through here) + const handler = getHandler(target, invocationMode); + if (handler) return handler; - // 6. Native vs OpenRouter Decision - // Models with explicit provider prefix (@) should never fall to native Anthropic handler. - // They were explicitly routed to a provider - if the handler wasn't created above, - // it's because the API key is missing, not because it's a native model. + // 4. Native Anthropic (bare model name, no provider prefix) const hasExplicitProvider = target.includes("@"); const isNative = !target.includes("/") && !hasExplicitProvider; + if (isNative) return nativeHandler; - if (isNative) { - // If we mapped to a native string (unlikely) or passed through - return nativeHandler; - } - - // 7. OpenRouter Handler (default for any model with "/" or explicit provider not matched above) - return getOpenRouterHandler(target, invocationMode); + // 5. OpenRouter fallback (vendor/model or unmatched explicit provider) + return getHandler(`openrouter@${target}`, invocationMode) || nativeHandler; }; const app = new Hono(); diff --git a/packages/macos-bridge/src/routing-middleware.ts b/packages/macos-bridge/src/routing-middleware.ts index f7d2d3ff..c9336c30 100644 --- a/packages/macos-bridge/src/routing-middleware.ts +++ b/packages/macos-bridge/src/routing-middleware.ts @@ -18,8 +18,9 @@ import { LocalModelAdapter } from "../../cli/src/adapters/local-adapter.js"; import { OpenRouterProvider } from "../../cli/src/providers/transport/openrouter.js"; import { OpenRouterAPIFormat } from "../../cli/src/adapters/openrouter-api-format.js"; import { - getRegisteredRemoteProviders, -} from "../../cli/src/providers/remote-provider-registry.js"; + getAllProviders, + toRemoteProvider, +} from "../../cli/src/providers/provider-definitions.js"; import { resolveProvider, } from "../../cli/src/providers/provider-registry.js"; @@ -68,7 +69,9 @@ export class RoutingMiddleware { * Create handler for a model ID using ComposedHandler + Provider + Adapter. */ private createHandlerForModel(model: string): Handler { - const remoteProviders = getRegisteredRemoteProviders(); + const remoteProviders = getAllProviders() + .filter((def) => !def.isLocal && def.baseUrl !== "" && def.name !== "qwen" && def.name !== "native-anthropic") + .map(toRemoteProvider); // Gemini direct API: g/gemini-2.0-flash-exp, gemini/gemini-pro if (model.startsWith("g/") || model.startsWith("gemini/")) {