diff --git a/README.md b/README.md index 002d0c5..0bd342c 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,14 @@ Then set the env var yourself: export OPENAI_BASE_URL=http://localhost:8080/v1 ``` +Optional hardening timeouts (Go duration syntax): + +- `FRUGAL_READ_HEADER_TIMEOUT` (default `5s`) +- `FRUGAL_READ_TIMEOUT` (default `15s`) +- `FRUGAL_WRITE_TIMEOUT` (default `120s`) +- `FRUGAL_IDLE_TIMEOUT` (default `60s`) +- `FRUGAL_MAX_HEADER_BYTES` (default `1048576`) + ### Quality thresholds Control cost vs. quality per request: @@ -121,6 +129,7 @@ headers = {"X-Frugal-Fallback": "gpt-4o,claude-sonnet-4-20250514,gemini-2.5-flas ``` If the routed model errors, Frugal walks the chain. +To bound latency and cost, Frugal attempts at most the first 3 fallback models. ## Supported models diff --git a/cmd/frugal/main.go b/cmd/frugal/main.go index 51cf2d6..1ebbf65 100644 --- a/cmd/frugal/main.go +++ b/cmd/frugal/main.go @@ -4,6 +4,8 @@ import ( "log" "net/http" "os" + "strconv" + "time" "github.com/go-chi/chi/v5" @@ -80,6 +82,10 @@ func main() { // Build classifier and router cls := classifier.NewRuleBased() modelEntries, thresholds := router.BuildTaxonomy(cfg) + modelEntries = filterRegisteredModels(modelEntries, registry) + if len(modelEntries) == 0 { + log.Fatal("no routable models available for registered providers") + } rtr := router.New(modelEntries, thresholds) // Build HTTP handler @@ -104,12 +110,56 @@ func main() { addr = a } + server := newHTTPServer(addr, r) + log.Printf("frugal listening on %s", addr) - if err := http.ListenAndServe(addr, r); err != nil { + if err := server.ListenAndServe(); err != nil { log.Fatalf("server error: %v", err) } } +func newHTTPServer(addr string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: envDurationOrDefault("FRUGAL_READ_HEADER_TIMEOUT", 5*time.Second), + ReadTimeout: envDurationOrDefault("FRUGAL_READ_TIMEOUT", 15*time.Second), + WriteTimeout: envDurationOrDefault("FRUGAL_WRITE_TIMEOUT", 120*time.Second), + IdleTimeout: envDurationOrDefault("FRUGAL_IDLE_TIMEOUT", 60*time.Second), + MaxHeaderBytes: envIntOrDefault("FRUGAL_MAX_HEADER_BYTES", http.DefaultMaxHeaderBytes), + } +} + +func envDurationOrDefault(key string, fallback time.Duration) time.Duration { + value := os.Getenv(key) + if value == "" { + return fallback + } + + parsed, err := time.ParseDuration(value) + if err != nil || parsed <= 0 { + log.Printf("warning: invalid %s=%q, using default %s", key, value, fallback) + return fallback + } + + return parsed +} + +func envIntOrDefault(key string, fallback int) int { + value := os.Getenv(key) + if value == "" { + return fallback + } + + parsed, err := strconv.Atoi(value) + if err != nil || parsed <= 0 { + log.Printf("warning: invalid %s=%q, using default %d", key, value, fallback) + return fallback + } + + return parsed +} + func modelNames(pc config.ProviderConfig) []string { names := make([]string, 0, len(pc.Models)) for name := range pc.Models { @@ -117,3 +167,13 @@ func modelNames(pc config.ProviderConfig) []string { } return names } + +func filterRegisteredModels(entries []router.ModelEntry, registry *provider.Registry) []router.ModelEntry { + filtered := make([]router.ModelEntry, 0, len(entries)) + for _, entry := range entries { + if _, err := registry.Resolve(entry.Name); err == nil { + filtered = append(filtered, entry) + } + } + return filtered +} diff --git a/cmd/frugal/main_test.go b/cmd/frugal/main_test.go new file mode 100644 index 0000000..54b7261 --- /dev/null +++ b/cmd/frugal/main_test.go @@ -0,0 +1,138 @@ +package main + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/frugalsh/frugal/internal/provider" + "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/types" +) + +type testProvider struct { + name string + models []string +} + +func (p *testProvider) Name() string { return p.name } +func (p *testProvider) Models() []string { return p.models } + +func (p *testProvider) ChatCompletion(_ context.Context, _ string, _ *types.ChatCompletionRequest) (*types.ChatCompletionResponse, error) { + return &types.ChatCompletionResponse{}, nil +} + +func (p *testProvider) ChatCompletionStream(_ context.Context, _ string, _ *types.ChatCompletionRequest) (<-chan provider.StreamChunk, error) { + ch := make(chan provider.StreamChunk) + close(ch) + return ch, nil +} + +func TestFilterRegisteredModels(t *testing.T) { + reg := provider.NewRegistry() + reg.Register(&testProvider{name: "openai", models: []string{"gpt-4o-mini"}}) + + entries := []router.ModelEntry{ + {Name: "gpt-4o-mini", Provider: "openai"}, + {Name: "claude-sonnet-4-20250514", Provider: "anthropic"}, + } + + filtered := filterRegisteredModels(entries, reg) + if got := len(filtered); got != 1 { + t.Fatalf("expected 1 registered model, got %d", got) + } + if filtered[0].Name != "gpt-4o-mini" { + t.Fatalf("expected gpt-4o-mini to remain, got %s", filtered[0].Name) + } +} + +func TestNewHTTPServerDefaults(t *testing.T) { + t.Setenv("FRUGAL_READ_HEADER_TIMEOUT", "") + t.Setenv("FRUGAL_READ_TIMEOUT", "") + t.Setenv("FRUGAL_WRITE_TIMEOUT", "") + t.Setenv("FRUGAL_IDLE_TIMEOUT", "") + t.Setenv("FRUGAL_MAX_HEADER_BYTES", "") + + srv := newHTTPServer(":8080", http.NewServeMux()) + + if srv.ReadHeaderTimeout != 5*time.Second { + t.Fatalf("expected default read header timeout 5s, got %s", srv.ReadHeaderTimeout) + } + if srv.ReadTimeout != 15*time.Second { + t.Fatalf("expected default read timeout 15s, got %s", srv.ReadTimeout) + } + if srv.WriteTimeout != 120*time.Second { + t.Fatalf("expected default write timeout 120s, got %s", srv.WriteTimeout) + } + if srv.IdleTimeout != 60*time.Second { + t.Fatalf("expected default idle timeout 60s, got %s", srv.IdleTimeout) + } + if srv.MaxHeaderBytes != http.DefaultMaxHeaderBytes { + t.Fatalf("expected default max header bytes %d, got %d", http.DefaultMaxHeaderBytes, srv.MaxHeaderBytes) + } +} + +func TestNewHTTPServerEnvOverrides(t *testing.T) { + t.Setenv("FRUGAL_READ_HEADER_TIMEOUT", "6s") + t.Setenv("FRUGAL_READ_TIMEOUT", "20s") + t.Setenv("FRUGAL_WRITE_TIMEOUT", "150s") + t.Setenv("FRUGAL_IDLE_TIMEOUT", "75s") + t.Setenv("FRUGAL_MAX_HEADER_BYTES", "65536") + + srv := newHTTPServer(":8080", http.NewServeMux()) + + if srv.ReadHeaderTimeout != 6*time.Second { + t.Fatalf("expected read header timeout 6s, got %s", srv.ReadHeaderTimeout) + } + if srv.ReadTimeout != 20*time.Second { + t.Fatalf("expected read timeout 20s, got %s", srv.ReadTimeout) + } + if srv.WriteTimeout != 150*time.Second { + t.Fatalf("expected write timeout 150s, got %s", srv.WriteTimeout) + } + if srv.IdleTimeout != 75*time.Second { + t.Fatalf("expected idle timeout 75s, got %s", srv.IdleTimeout) + } + if srv.MaxHeaderBytes != 65536 { + t.Fatalf("expected max header bytes 65536, got %d", srv.MaxHeaderBytes) + } +} + +func TestEnvDurationOrDefaultInvalidValues(t *testing.T) { + const key = "FRUGAL_TIMEOUT_TEST" + + t.Setenv(key, "not-a-duration") + if got := envDurationOrDefault(key, 3*time.Second); got != 3*time.Second { + t.Fatalf("expected fallback for invalid duration, got %s", got) + } + + t.Setenv(key, "0s") + if got := envDurationOrDefault(key, 3*time.Second); got != 3*time.Second { + t.Fatalf("expected fallback for zero duration, got %s", got) + } + + t.Setenv(key, "-2s") + if got := envDurationOrDefault(key, 3*time.Second); got != 3*time.Second { + t.Fatalf("expected fallback for negative duration, got %s", got) + } +} + +func TestEnvIntOrDefaultInvalidValues(t *testing.T) { + const key = "FRUGAL_INT_TEST" + + t.Setenv(key, "not-an-int") + if got := envIntOrDefault(key, 1234); got != 1234 { + t.Fatalf("expected fallback for invalid int, got %d", got) + } + + t.Setenv(key, "0") + if got := envIntOrDefault(key, 1234); got != 1234 { + t.Fatalf("expected fallback for zero int, got %d", got) + } + + t.Setenv(key, "-10") + if got := envIntOrDefault(key, 1234); got != 1234 { + t.Fatalf("expected fallback for negative int, got %d", got) + } +} diff --git a/cmd/frugal/wrap.go b/cmd/frugal/wrap.go index d2c55ec..7c3b305 100644 --- a/cmd/frugal/wrap.go +++ b/cmd/frugal/wrap.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io" "log" "net" "net/http" @@ -47,6 +48,11 @@ func runWrap(configPath string, args []string) int { cls := classifier.NewRuleBased() modelEntries, thresholds := router.BuildTaxonomy(cfg) + modelEntries = filterRegisteredModels(modelEntries, registry) + if len(modelEntries) == 0 { + fmt.Fprintln(os.Stderr, "frugal: no routable models available for registered providers") + return 1 + } rtr := router.New(modelEntries, thresholds) h := proxy.NewHandler(cls, rtr, registry) @@ -77,7 +83,11 @@ func runWrap(configPath string, args []string) int { }() // Wait for proxy to be ready - waitForReady(fmt.Sprintf("http://127.0.0.1:%d/health", port)) + if err := waitForReady(fmt.Sprintf("http://127.0.0.1:%d/health", port), 2*time.Second); err != nil { + fmt.Fprintf(os.Stderr, "frugal: proxy failed health check: %v\n", err) + server.Close() + return 1 + } fmt.Fprintf(os.Stderr, "frugal: proxy running on :%d → routing across %d models\n", port, len(registry.AllModels())) @@ -144,13 +154,21 @@ func injectEnv(environ []string, baseURL string) []string { return out } -func waitForReady(url string) { - for i := 0; i < 50; i++ { - resp, err := http.Get(url) +func waitForReady(url string, timeout time.Duration) error { + client := &http.Client{Timeout: 200 * time.Millisecond} + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + resp, err := client.Get(url) if err == nil { + io.Copy(io.Discard, resp.Body) resp.Body.Close() - return + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + return nil + } } - time.Sleep(10 * time.Millisecond) + time.Sleep(25 * time.Millisecond) } + + return fmt.Errorf("timed out waiting for %s after %s", url, timeout) } diff --git a/cmd/frugal/wrap_test.go b/cmd/frugal/wrap_test.go new file mode 100644 index 0000000..f564e9d --- /dev/null +++ b/cmd/frugal/wrap_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestWaitForReadyReturnsAfterHealthyStatus(t *testing.T) { + var attempts int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if atomic.AddInt32(&attempts, 1) < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("not ready")) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer ts.Close() + + if err := waitForReady(ts.URL, 500*time.Millisecond); err != nil { + t.Fatalf("expected readiness check to succeed, got error: %v", err) + } +} + +func TestWaitForReadyTimesOutWhenUnreachable(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to reserve local port: %v", err) + } + addr := ln.Addr().String() + _ = ln.Close() + + url := fmt.Sprintf("http://%s/health", addr) + if err := waitForReady(url, 150*time.Millisecond); err == nil { + t.Fatalf("expected timeout error for unreachable endpoint") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 146008a..8839617 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "bytes" "fmt" "os" @@ -53,7 +54,9 @@ func Load(path string) (*Config, error) { } var cfg Config - if err := yaml.Unmarshal(data, &cfg); err != nil { + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + if err := dec.Decode(&cfg); err != nil { return nil, fmt.Errorf("parsing config: %w", err) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 59f261f..f045ca9 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "strings" "testing" ) @@ -77,3 +78,44 @@ func TestLoad_MissingFile(t *testing.T) { t.Error("expected error for missing file") } } + +func TestLoad_RejectsUnknownFields(t *testing.T) { + content := ` +providers: + openai: + api_key_env: OPENAI_API_KEY + base_url: https://api.openai.com/v1 + models: + gpt-4o: + cost_per_1k_input: 0.0025 + cost_per_1k_output: 0.01 + capabilities: + reasoning: 0.95 + coding: 0.92 + creative: 0.90 + instruction_following: 0.95 + tool_use: true + json_mode: true + max_context: 128000 + typo_field: true +quality_thresholds: + balanced: + min_reasoning: 0.70 + min_coding: 0.68 + min_creative: 0.65 + min_instruction_following: 0.72 +` + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for unknown field") + } + if !strings.Contains(err.Error(), "typo_field") { + t.Fatalf("expected unknown field error to mention typo_field, got: %v", err) + } +} diff --git a/internal/provider/anthropic/anthropic.go b/internal/provider/anthropic/anthropic.go index e71ee34..a2d7c18 100644 --- a/internal/provider/anthropic/anthropic.go +++ b/internal/provider/anthropic/anthropic.go @@ -1,7 +1,6 @@ package anthropic import ( - "bufio" "bytes" "context" "encoding/json" @@ -259,7 +258,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * defer resp.Body.Close() chunkID := fmt.Sprintf("chatcmpl-%s", ar.Model) - scanner := bufio.NewScanner(resp.Body) + scanner := provider.NewSSEScanner(resp.Body) for scanner.Scan() { line := scanner.Text() diff --git a/internal/provider/google/google.go b/internal/provider/google/google.go index 887c1b2..417d859 100644 --- a/internal/provider/google/google.go +++ b/internal/provider/google/google.go @@ -1,7 +1,6 @@ package google import ( - "bufio" "bytes" "context" "encoding/json" @@ -244,7 +243,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * defer close(ch) defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) + scanner := provider.NewSSEScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { diff --git a/internal/provider/openai/openai.go b/internal/provider/openai/openai.go index 86e207b..0c286c9 100644 --- a/internal/provider/openai/openai.go +++ b/internal/provider/openai/openai.go @@ -1,7 +1,6 @@ package openai import ( - "bufio" "bytes" "context" "encoding/json" @@ -103,7 +102,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * defer close(ch) defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) + scanner := provider.NewSSEScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { diff --git a/internal/provider/openai/openai_stream_test.go b/internal/provider/openai/openai_stream_test.go new file mode 100644 index 0000000..d8da943 --- /dev/null +++ b/internal/provider/openai/openai_stream_test.go @@ -0,0 +1,52 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/frugalsh/frugal/internal/types" +) + +func TestChatCompletionStream_AllowsLargeSSELines(t *testing.T) { + large := strings.Repeat("x", 70*1024) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"%s\"}}]}\n\n", large) + fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer ts.Close() + + p := New("test-key", ts.URL, []string{"gpt-4o-mini"}) + p.client = ts.Client() + + ch, err := p.ChatCompletionStream(context.Background(), "gpt-4o-mini", &types.ChatCompletionRequest{}) + if err != nil { + t.Fatalf("ChatCompletionStream returned error: %v", err) + } + + var gotData, gotDone bool + for chunk := range ch { + if chunk.Err != nil { + t.Fatalf("stream chunk error: %v", chunk.Err) + } + if chunk.Data != nil { + gotData = true + } + if chunk.Done { + gotDone = true + } + } + + if !gotData { + t.Fatal("expected at least one data chunk") + } + if !gotDone { + t.Fatal("expected done chunk") + } +} + diff --git a/internal/provider/sse.go b/internal/provider/sse.go new file mode 100644 index 0000000..5df7803 --- /dev/null +++ b/internal/provider/sse.go @@ -0,0 +1,18 @@ +package provider + +import ( + "bufio" + "io" +) + +const maxSSELineBytes = 1024 * 1024 // 1 MiB + +// NewSSEScanner returns a scanner configured for larger-than-default SSE lines. +// Provider APIs can emit large JSON chunks that exceed bufio.Scanner's 64 KiB default. +func NewSSEScanner(r io.Reader) *bufio.Scanner { + s := bufio.NewScanner(r) + buf := make([]byte, 64*1024) + s.Buffer(buf, maxSSELineBytes) + return s +} + diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 488df0a..b01b147 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net/http" + "strings" "sync" "time" @@ -14,6 +15,8 @@ import ( "github.com/frugalsh/frugal/internal/types" ) +const maxFallbackAttempts = 3 + // Handler serves the OpenAI-compatible API endpoints. type Handler struct { classifier classifier.Classifier @@ -107,7 +110,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, prov p resp, err := prov.ChatCompletion(r.Context(), decision.SelectedModel, req) if err != nil { // Try fallback chain - for _, fb := range fallbacks { + for _, fb := range boundedFallbacks(fallbacks) { fbProv, fbErr := h.registry.Resolve(fb) if fbErr != nil { continue @@ -132,7 +135,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, prov prov ch, err := prov.ChatCompletionStream(r.Context(), decision.SelectedModel, req) if err != nil { // Try fallback chain - for _, fb := range fallbacks { + for _, fb := range boundedFallbacks(fallbacks) { fbProv, fbErr := h.registry.Resolve(fb) if fbErr != nil { continue @@ -154,6 +157,28 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, prov prov } } +func boundedFallbacks(fallbacks []string) []string { + if len(fallbacks) == 0 { + return nil + } + + bounded := make([]string, 0, maxFallbackAttempts) + for _, fb := range fallbacks { + if len(bounded) >= maxFallbackAttempts { + break + } + + trimmed := strings.TrimSpace(fb) + if trimmed == "" { + continue + } + + bounded = append(bounded, trimmed) + } + + return bounded +} + // ListModels handles GET /v1/models func (h *Handler) ListModels(w http.ResponseWriter, r *http.Request) { models := h.registry.AllModels() diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index b6719af..2832385 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -4,10 +4,12 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -22,13 +24,29 @@ type mockProvider struct { name string models []string response *types.ChatCompletionResponse + chatErr error streamErr error + + mu sync.Mutex + chatCalls int + streamCalls int + lastChatModel string + lastStreamModel string } func (m *mockProvider) Name() string { return m.name } func (m *mockProvider) Models() []string { return m.models } func (m *mockProvider) ChatCompletion(ctx context.Context, model string, req *types.ChatCompletionRequest) (*types.ChatCompletionResponse, error) { + m.mu.Lock() + m.chatCalls++ + m.lastChatModel = model + m.mu.Unlock() + + if m.chatErr != nil { + return nil, m.chatErr + } + if m.response != nil { return m.response, nil } @@ -51,6 +69,11 @@ func (m *mockProvider) ChatCompletion(ctx context.Context, model string, req *ty } func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, req *types.ChatCompletionRequest) (<-chan provider.StreamChunk, error) { + m.mu.Lock() + m.streamCalls++ + m.lastStreamModel = model + m.mu.Unlock() + if m.streamErr != nil { return nil, m.streamErr } @@ -89,6 +112,18 @@ func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, r return ch, nil } +func (m *mockProvider) ChatCallCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.chatCalls +} + +func (m *mockProvider) LastChatModel() string { + m.mu.Lock() + defer m.mu.Unlock() + return m.lastChatModel +} + func setupHandler() (*Handler, *httptest.Server) { reg := provider.NewRegistry() mock := &mockProvider{ @@ -332,3 +367,55 @@ func mustMarshalJSON(v any) json.RawMessage { b, _ := json.Marshal(v) return b } + +func TestChatCompletions_FallbackAttemptsAreBounded(t *testing.T) { + reg := provider.NewRegistry() + + failing := &mockProvider{name: "failing", models: []string{"primary", "fb1", "fb2", "fb3", "fb4"}, chatErr: errors.New("boom")} + reg.Register(failing) + + models := []router.ModelEntry{{ + Name: "primary", Provider: "failing", + CostPer1KInput: 0.0001, CostPer1KOutput: 0.0002, + Reasoning: 0.8, Coding: 0.8, Creative: 0.8, InstructFollowing: 0.8, + ToolUse: true, JSONMode: true, MaxContext: 128000, + }} + thresholds := map[string]router.Threshold{ + "balanced": {MinReasoning: 0.1, MinCoding: 0.1, MinCreative: 0.1, MinInstructFollowing: 0.1}, + } + + h := NewHandler(classifier.NewRuleBased(), router.New(models, thresholds), reg) + mux := http.NewServeMux() + mux.HandleFunc("POST /v1/chat/completions", h.ChatCompletions) + ts := httptest.NewServer(HeaderExtractionMiddleware(mux)) + defer ts.Close() + + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("Hello")}}, + }) + req, _ := http.NewRequest("POST", ts.URL+"/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Frugal-Fallback", "fb1,fb2,fb3,fb4") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 502, got %d: %s", resp.StatusCode, string(b)) + } + + // 1 primary attempt + maxFallbackAttempts fallbacks. + wantCalls := 1 + maxFallbackAttempts + if got := failing.ChatCallCount(); got != wantCalls { + t.Fatalf("expected %d total attempts, got %d", wantCalls, got) + } + + if got := failing.LastChatModel(); got != "fb3" { + t.Fatalf("expected last attempted fallback model fb3, got %s", got) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 7d8f18a..880ddfc 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -19,7 +19,7 @@ func New(models []ModelEntry, thresholds map[string]Threshold) *Router { // Route selects a model based on query features and quality threshold. func (r *Router) Route(features types.QueryFeatures, quality types.QualityThreshold, fallbacks []string) types.RoutingDecision { - threshold := r.thresholds[string(quality)] + threshold := r.thresholdForQuality(quality) // Filter candidates var candidates []ModelEntry @@ -32,8 +32,8 @@ func (r *Router) Route(features types.QueryFeatures, quality types.QualityThresh if len(candidates) == 0 { // Fallback: relax threshold to balanced, then cost - for _, fallbackQuality := range []string{"balanced", "cost"} { - ft := r.thresholds[fallbackQuality] + for _, fallbackQuality := range []types.QualityThreshold{types.QualityBalanced, types.QualityCost} { + ft := r.thresholdForQuality(fallbackQuality) for _, m := range r.models { if r.meetsRequirements(m, features, ft) { candidates = append(candidates, m) @@ -76,6 +76,16 @@ func (r *Router) Route(features types.QueryFeatures, quality types.QualityThresh } } +func (r *Router) thresholdForQuality(quality types.QualityThreshold) Threshold { + if t, ok := r.thresholds[string(quality)]; ok { + return t + } + if t, ok := r.thresholds[string(types.QualityBalanced)]; ok { + return t + } + return Threshold{} +} + func (r *Router) meetsRequirements(m ModelEntry, f types.QueryFeatures, t Threshold) bool { // Hard requirements if f.RequiresToolUse && !m.ToolUse { diff --git a/internal/router/router_test.go b/internal/router/router_test.go index 99615b5..f746007 100644 --- a/internal/router/router_test.go +++ b/internal/router/router_test.go @@ -162,3 +162,35 @@ func TestRoute_NoModels_ReturnsEmpty(t *testing.T) { t.Errorf("expected empty model when no models available, got %s", d.SelectedModel) } } + +func TestRoute_UnknownQualityDefaultsToBalancedThreshold(t *testing.T) { + r := New(testModels(), testThresholds()) + + features := types.QueryFeatures{ + EstimatedInputTokens: 100, + EstimatedOutputTokens: 100, + } + + d := r.Route(features, types.QualityThreshold("unknown"), nil) + + if d.SelectedModel != "mid-model" { + t.Errorf("expected balanced fallback model mid-model for unknown quality, got %s", d.SelectedModel) + } +} + +func TestRoute_MissingRequestedThresholdFallsBackToBalanced(t *testing.T) { + thresholds := testThresholds() + delete(thresholds, "high") + r := New(testModels(), thresholds) + + features := types.QueryFeatures{ + EstimatedInputTokens: 100, + EstimatedOutputTokens: 100, + } + + d := r.Route(features, types.QualityHigh, nil) + + if d.SelectedModel != "mid-model" { + t.Errorf("expected balanced fallback model mid-model when high threshold missing, got %s", d.SelectedModel) + } +} diff --git a/internal/types/routing.go b/internal/types/routing.go index bb815c9..956479e 100644 --- a/internal/types/routing.go +++ b/internal/types/routing.go @@ -1,5 +1,7 @@ package types +import "strings" + // QualityThreshold controls how aggressively Frugal routes to cheaper models. type QualityThreshold string @@ -11,7 +13,7 @@ const ( // ParseQualityThreshold parses a string into a QualityThreshold, defaulting to balanced. func ParseQualityThreshold(s string) QualityThreshold { - switch s { + switch strings.ToLower(strings.TrimSpace(s)) { case "high": return QualityHigh case "cost": diff --git a/internal/types/routing_test.go b/internal/types/routing_test.go new file mode 100644 index 0000000..d434808 --- /dev/null +++ b/internal/types/routing_test.go @@ -0,0 +1,27 @@ +package types + +import "testing" + +func TestParseQualityThreshold_NormalizedInputs(t *testing.T) { + tests := []struct { + name string + in string + want QualityThreshold + }{ + {name: "exact high", in: "high", want: QualityHigh}, + {name: "uppercase high", in: "HIGH", want: QualityHigh}, + {name: "spaced high", in: " high ", want: QualityHigh}, + {name: "mixed-case cost", in: "CoSt", want: QualityCost}, + {name: "spaced balanced", in: " balanced ", want: QualityBalanced}, + {name: "unknown defaults to balanced", in: "fast", want: QualityBalanced}, + {name: "empty defaults to balanced", in: "", want: QualityBalanced}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ParseQualityThreshold(tt.in); got != tt.want { + t.Fatalf("ParseQualityThreshold(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +}