diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..d2e3b13 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,53 @@ +name: Bug report +description: Something is broken, crashing, or returning the wrong result +labels: ["bug"] +body: + - type: textarea + id: what-happened + attributes: + label: What happened + description: What did you do, what did you expect, what did you get instead? + placeholder: | + 1. I ran `frugal serve` + 2. I sent a chat completion with `X-Frugal-Use-Case: research-synthesis` + 3. I expected routing to Claude Sonnet 4, got GPT-4o-mini + validations: + required: true + + - type: input + id: version + attributes: + label: frugal version + description: Output of `frugal --version` + placeholder: "v0.0.1" + validations: + required: true + + - type: dropdown + id: os + attributes: + label: OS + options: + - macOS (Apple silicon) + - macOS (Intel) + - Linux (amd64) + - Linux (arm64) + - Other (describe below) + validations: + required: true + + - type: textarea + id: logs + attributes: + label: Relevant logs or explain output + description: | + Run with `FRUGAL_LOG_LEVEL=debug` and/or hit `/v1/routing/explain` after + the failing request. Paste relevant lines here. Redact any API keys. + render: text + + - type: textarea + id: config + attributes: + label: Config (optional) + description: Any non-default env vars or edits to `~/.frugal/config/models.yaml` + render: text diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..4a0b0f9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Security vulnerability + url: https://github.com/brainsparker/frugal/security/advisories/new + about: Report security issues privately via GitHub Security Advisories. See SECURITY.md for the full disclosure policy. 
+ - name: Question or discussion + url: https://github.com/brainsparker/frugal/discussions + about: For open-ended questions, design discussions, or "how do I…" — use Discussions, not Issues. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..7d6e7ef --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,35 @@ +name: Feature request +description: Suggest a new capability, use case, provider, or quality-of-life improvement +labels: ["enhancement"] +body: + - type: textarea + id: problem + attributes: + label: The problem + description: What can't you do today, or what's harder than it should be? + validations: + required: true + + - type: textarea + id: proposal + attributes: + label: Proposed direction + description: | + Rough shape of the solution. Bullet points are fine. Concrete is better + than abstract — API shape, CLI flag, config key, etc. + validations: + required: true + + - type: textarea + id: alternatives + attributes: + label: Alternatives considered + description: What workarounds exist today? Why are they insufficient? + + - type: textarea + id: context + attributes: + label: Context + description: | + Anything else worth knowing — use case, prior art in other tools, + whether you'd be open to contributing a PR. 
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..ae17162 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,28 @@ + + +## What + + + +## Why + + + +## How to verify + + + +## Notes for reviewers + + diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..67cfab3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,51 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: "1.25" + cache: true + + - name: Build + run: go build ./... + + - name: Vet + run: go vet ./... + + - name: Test + run: go test -race ./... + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: "1.25" + cache: true + + - name: Install staticcheck + run: go install honnef.co/go/tools/cmd/staticcheck@latest + + - name: Staticcheck + run: staticcheck ./... + + - name: Install govulncheck + run: go install golang.org/x/vuln/cmd/govulncheck@latest + + - name: Govulncheck + run: govulncheck ./... 
diff --git a/.github/workflows/dco.yml b/.github/workflows/dco.yml new file mode 100644 index 0000000..a3a1034 --- /dev/null +++ b/.github/workflows/dco.yml @@ -0,0 +1,16 @@ +name: DCO + +on: + pull_request: + types: [opened, synchronize, reopened] + +permissions: + contents: read + pull-requests: write + +jobs: + dco: + runs-on: ubuntu-latest + steps: + - name: Check Developer Certificate of Origin sign-off + uses: tim-actions/dco@v1.1.0 diff --git a/.github/workflows/install-smoke.yml b/.github/workflows/install-smoke.yml new file mode 100644 index 0000000..e6c8df0 --- /dev/null +++ b/.github/workflows/install-smoke.yml @@ -0,0 +1,110 @@ +name: install-smoke + +# Runs the real, public installer end-to-end on fresh runners after every +# release. Finds the "works on my Mac" class of regressions before a user does. +# +# Also runs on workflow_dispatch so you can re-test without cutting a release. +on: + release: + types: [published] + workflow_dispatch: + +permissions: + contents: read + +jobs: + curl-pipe-sh: + name: ${{ matrix.os }} / curl | bash + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} + steps: + - name: Install via public URL + env: + FRUGAL_YES: "1" + # Authenticate the releases/latest API call so shared runner IPs + # don't 403 against the 60/hr anonymous rate limit. 
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + curl -fsSL https://frugal.sh/install | bash + + - name: Binary runs and reports a version + run: | + set -euo pipefail + "$HOME/.frugal/bin/frugal" --version + + - name: Shell rc block was written + run: | + set -euo pipefail + found=0 + for rc in "$HOME/.zshrc" "$HOME/.bashrc" "$HOME/.bash_profile"; do + if [ -f "$rc" ] && grep -qxF "# >>> frugal.sh >>>" "$rc"; then + echo "marker block present in $rc" + found=1 + break + fi + done + if [ "$found" -ne 1 ]; then + echo "no shell rc got the frugal.sh marker block" >&2 + exit 1 + fi + + - name: Uninstall is clean + run: | + set -euo pipefail + curl -fsSL https://frugal.sh/install | bash -s uninstall + if [ -d "$HOME/.frugal" ]; then + echo "$HOME/.frugal still exists after uninstall" >&2 + exit 1 + fi + for rc in "$HOME/.zshrc" "$HOME/.bashrc" "$HOME/.bash_profile"; do + if [ -f "$rc" ] && grep -qxF "# >>> frugal.sh >>>" "$rc"; then + echo "marker block still in $rc after uninstall" >&2 + exit 1 + fi + done + + pinned-version: + name: ${{ matrix.os }} / FRUGAL_VERSION pin + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + runs-on: ${{ matrix.os }} + steps: + - name: Derive tag + id: tag + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set -euo pipefail + if [ "${{ github.event_name }}" = "release" ]; then + echo "value=${{ github.event.release.tag_name }}" >> "$GITHUB_OUTPUT" + else + # Manual re-runs pin to whatever is currently latest. 
+ latest="$(curl -fsSL -H "Authorization: Bearer $GITHUB_TOKEN" \ + https://api.github.com/repos/${{ github.repository }}/releases/latest \ + | grep '"tag_name"' | head -n1 | sed -E 's/.*"tag_name"[[:space:]]*:[[:space:]]*"([^"]+)".*/\1/')" + echo "value=$latest" >> "$GITHUB_OUTPUT" + fi + + - name: Install pinned version + env: + FRUGAL_YES: "1" + FRUGAL_VERSION: ${{ steps.tag.outputs.value }} + run: | + curl -fsSL https://frugal.sh/install | bash + + - name: Binary reports the pinned version + run: | + set -euo pipefail + got="$("$HOME/.frugal/bin/frugal" --version)" + want="${{ steps.tag.outputs.value }}" + echo "got: $got" + echo "want: $want" + case "$got" in + *"$want"*) echo "version matches" ;; + *) echo "version mismatch" >&2; exit 1 ;; + esac diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c9433e5..687324f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,6 +7,7 @@ on: permissions: contents: write + id-token: write # required for keyless cosign via GitHub OIDC jobs: release: @@ -16,7 +17,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version: "1.25" - name: Run tests run: go test ./... @@ -24,6 +25,24 @@ jobs: - name: Build release binaries run: make release + - name: Install cosign + uses: sigstore/cosign-installer@v3 + + - name: Sign binaries and checksums (keyless) + run: | + set -euo pipefail + cd dist + for f in frugal-* SHA256SUMS; do + cosign sign-blob --yes --bundle "${f}.sig" "${f}" + done + + - name: Generate SBOMs (CycloneDX) + uses: anchore/sbom-action@v0 + with: + path: . 
+ format: cyclonedx-json + output-file: dist/frugal.cdx.json + - name: Create GitHub release uses: softprops/action-gh-release@v2 with: diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3d25379 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +# Binaries +/bin/ +/dist/ + +# Go build outputs and local envs +/frugal +*.test + +# Wrangler (Cloudflare Workers Assets) local state. +# The committed wrangler.jsonc is the source of truth; everything below is +# generated per-machine or per-CI-run. +.wrangler/ +.dev.vars +wrangler-*.log +node_modules/ diff --git a/ADOPTERS.md b/ADOPTERS.md new file mode 100644 index 0000000..03b1833 --- /dev/null +++ b/ADOPTERS.md @@ -0,0 +1,9 @@ +# Adopters + +Teams and individuals running Frugal in production. If that's you and you're willing to be listed, open a PR adding yourself. + +| Organization / Handle | Use case | +|---|---| +| _(be the first)_ | | + +Interested in being a design partner — running Frugal on a real workload with hands-on support in exchange for feedback — open an issue titled `design partner: <your org / handle>`. diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..e838dd8 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,89 @@ ++++ +version = "2.1" +aliases = ["/version/2/1"] +reportingPlaceholder = "brian@you.com" ++++ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at brian@you.com. All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. + +### 4. 
Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..5cd55fe --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,35 @@ +# Contributing + +Thanks for your interest. Frugal is small and focused — contributions that sharpen the existing wedge are more welcome than ones that expand scope. + +## Before you open a PR + +- Open an issue first for anything larger than a bug fix or docs tweak. +- Keep changes minimal and scoped to the stated problem. +- Add a test for new behavior. Run `make test` before pushing. +- Match recent commit style (`fix:`, `feat:`, `chore:`, `test:`, `docs:`). +- Sign off every commit with the Developer Certificate of Origin: append a + `Signed-off-by: Your Name <you@example.com>` trailer (or use `git commit -s`). + CI blocks PRs missing a sign-off on any commit. 
See + [developercertificate.org](https://developercertificate.org/) for the text. + +## License + +Frugal is licensed under [BUSL 1.1](./LICENSE). By contributing, you agree your contributions are licensed under the same terms. Self-hosted and internal commercial use is permitted; offering Frugal as a competing hosted routing service is not. See the [BUSL FAQ](./LICENSE-BUSL-FAQ.md) for plain-English details. + +## What we're looking for + +- Bug fixes with regression tests +- Provider integrations (new models, new endpoints) that fit the existing config schema +- Benchmark reproducers — real workloads worth adding to the eval harness +- Documentation improvements that tighten claims (remove hand-waves, add measurements) + +## What to skip for now + +- Hosted control plane / multi-tenancy features +- ML-based intent classification (the heuristic classifier is sufficient while we validate) +- New policy knobs or routing tiers without a documented use case + +## Reporting security issues + +Please open a [private security advisory](https://github.com/brainsparker/frugal/security/advisories/new) rather than a public issue. 
diff --git a/Dockerfile b/Dockerfile index d59fd27..25f8523 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.23-alpine AS builder +FROM golang:1.25-alpine AS builder WORKDIR /app COPY go.mod go.sum ./ diff --git a/LICENSE b/LICENSE index 421db89..46c8d02 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,90 @@ -MIT License - -Copyright (c) 2026 sparker - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +License text copyright (c) 2020 MariaDB Corporation Ab, All Rights Reserved. +"Business Source License" is a trademark of MariaDB Corporation Ab. + +Parameters + +Licensor: Brian Sparker +Licensed Work: Frugal. The Licensed Work is (c) 2026 Brian Sparker. +Additional Use Grant: You may make production use of the Licensed Work, provided + Your use does not include offering the Licensed Work to third + parties on a hosted or embedded basis in order to compete with + Licensor's paid version(s) of the Licensed Work. 
For purposes + of this license: + + A "competitive offering" is a Product that is offered to third + parties on a paid basis, including through paid support + arrangements, that significantly overlaps with the capabilities + of Licensor's paid version(s) of the Licensed Work. If Your + Product is not a competitive offering when You first make it + generally available, it will not become a competitive offering + later due to Licensor releasing a new version of the Licensed + Work with additional capabilities. In addition, Products that + are not provided on a paid basis are not competitive. + + "Product" means software that is offered to end users to manage + in their own environments or offered as a service on a hosted + basis. + + "Embedded" means including the source code or executable code + from the Licensed Work in a competitive offering. "Embedded" + also means packaging the competitive offering in such a way + that the Licensed Work must be accessed or downloaded for the + competitive offering to operate. + + Hosting or using the Licensed Work for internal purposes within + an organization is not considered a competitive offering. + Licensor considers your organization to include all of your + affiliates under common control. + + For binding interpretive guidance on using the Licensed Work + under the Business Source License, please see LICENSE-BUSL-FAQ.md. +Change Date: Four years from the date the Licensed Work is published. +Change License: Apache License, Version 2.0 + +For information about alternative licensing arrangements for the Licensed Work, +please contact licensing@frugal.sh. + +Notice + +Business Source License 1.1 + +Terms + +The Licensor hereby grants you the right to copy, modify, create derivative +works, redistribute, and make non-production use of the Licensed Work. The +Licensor may make an Additional Use Grant, above, permitting limited production use. 
+ +Effective on the Change Date, or the fourth anniversary of the first publicly +available distribution of a specific version of the Licensed Work under this +License, whichever comes first, the Licensor hereby grants you rights under +the terms of the Change License, and the rights granted in the paragraph +above terminate. + +If your use of the Licensed Work does not comply with the requirements +currently in effect as described in this License, you must purchase a +commercial license from the Licensor, its affiliated entities, or authorized +resellers, or you must refrain from using the Licensed Work. + +All copies of the original and modified Licensed Work, and derivative works +of the Licensed Work, are subject to this License. This License applies +separately for each version of the Licensed Work and the Change Date may vary +for each version of the Licensed Work released by Licensor. + +You must conspicuously display this License on each original or modified copy +of the Licensed Work. If you receive the Licensed Work in original or +modified form from a third party, the terms and conditions set forth in this +License apply to your use of that work. + +Any use of the Licensed Work in violation of this License will automatically +terminate your rights under this License for the current and all other +versions of the Licensed Work. + +This License does not grant you any right in any trademark or logo of +Licensor or its affiliates (provided that you may use a trademark or logo of +Licensor as expressly required by this License). + +TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON +AN "AS IS" BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, +EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND +TITLE. 
diff --git a/LICENSE-BUSL-FAQ.md b/LICENSE-BUSL-FAQ.md new file mode 100644 index 0000000..f761d19 --- /dev/null +++ b/LICENSE-BUSL-FAQ.md @@ -0,0 +1,76 @@ +# BUSL 1.1 FAQ + +Frugal is distributed under the [Business Source License 1.1](./LICENSE). +This FAQ is a plain-English explanation of the practical effect. It is not +legal advice. The [LICENSE](./LICENSE) is the authoritative text. + +## TL;DR + +- Run Frugal yourself: **allowed** (production, internal, commercial). +- Modify Frugal and run it yourself: **allowed**. +- Build a product that uses Frugal internally: **allowed**. +- Sell Frugal (or a fork) as a hosted LLM-routing service that competes with + frugal.sh: **not allowed** without a commercial license. +- Four years after a given version is published, that version becomes + Apache 2.0. + +## What is the Business Source License? + +BUSL 1.1 is a source-available license. The source is public, you can read it, +modify it, and run it yourself, but one specific commercial use is carved +out for a limited time. + +That carve-out for Frugal is: **you can't take Frugal and resell it as a hosted +routing service that competes with Frugal's hosted product**. Everything else +is fair game. + +## What can I do today? + +Yes, you can: + +- Download, build, and run Frugal on your own machines for any purpose, + including production and commercial use inside your own company. +- Use Frugal as a client-side tool (`frugal python app.py`) for routing your + team's LLM traffic. +- Fork the repo, modify it, and deploy your fork internally. +- Use Frugal's output (completions) in commercial products. +- Include Frugal as a dependency of a larger product, so long as your product + is not itself a competing hosted routing service. +- Run Frugal as an internal shared service within your organization, including + affiliates under common control. 
+ +Not without a commercial license: + +- Operate a paid, third-party-facing service that offers LLM request routing + substantially similar to frugal.sh. (Running a free service is fine. + Running a paid service that happens to use Frugal is fine if LLM routing + is not the thing you're selling. What's restricted is reselling the + routing itself.) +- Embed Frugal's code or binaries inside a competing product in a way that + requires Frugal to operate. + +## When does it become fully open source? + +Each version of Frugal converts to the Apache License 2.0 four years after +that specific version is first published. Older versions therefore convert +first; the latest version is always under BUSL until its own four-year clock +runs out. + +## Why not MIT or Apache 2.0 up front? + +Frugal's taxonomy and cost router are the product. An MIT license would let +a larger cloud or AI company take the whole thing, run it as a managed +service, and undercut the ability of Frugal's authors to sustain the project. +BUSL 1.1 keeps the code open to individual developers and companies while +protecting the commercial path. The four-year conversion guarantees that the +code does eventually become fully OSS-licensed. + +## What if I need a commercial license? + +Email `licensing@frugal.sh`. + +## Where can I read more about BUSL 1.1? + +- [Official text at mariadb.com/bsl11](https://mariadb.com/bsl11/) +- [HashiCorp's BUSL FAQ](https://www.hashicorp.com/license-faq) — the + phrasing in Frugal's Additional Use Grant is modeled on theirs. 
diff --git a/Makefile b/Makefile index 7cf939d..0cb9b85 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,10 @@ .PHONY: build run test clean release VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +LDFLAGS := -ldflags "-X main.buildVersion=$(VERSION)" build: - go build -o bin/frugal ./cmd/frugal + go build $(LDFLAGS) -o bin/frugal ./cmd/frugal run: build ./bin/frugal @@ -16,8 +17,9 @@ clean: release: clean mkdir -p dist - GOOS=darwin GOARCH=arm64 go build -o dist/frugal-darwin-arm64 ./cmd/frugal - GOOS=darwin GOARCH=amd64 go build -o dist/frugal-darwin-amd64 ./cmd/frugal - GOOS=linux GOARCH=arm64 go build -o dist/frugal-linux-arm64 ./cmd/frugal - GOOS=linux GOARCH=amd64 go build -o dist/frugal-linux-amd64 ./cmd/frugal + GOOS=darwin GOARCH=arm64 go build $(LDFLAGS) -o dist/frugal-darwin-arm64 ./cmd/frugal + GOOS=darwin GOARCH=amd64 go build $(LDFLAGS) -o dist/frugal-darwin-amd64 ./cmd/frugal + GOOS=linux GOARCH=arm64 go build $(LDFLAGS) -o dist/frugal-linux-arm64 ./cmd/frugal + GOOS=linux GOARCH=amd64 go build $(LDFLAGS) -o dist/frugal-linux-amd64 ./cmd/frugal + cd dist && shasum -a 256 frugal-* > SHA256SUMS @echo "built $(VERSION) binaries in dist/" diff --git a/README.md b/README.md index 002d0c5..6cb9848 100644 --- a/README.md +++ b/README.md @@ -1,61 +1,164 @@ # frugal -**Open-source LLM proxy that routes every request to the cheapest model that won't compromise quality.** +**Open-source AI toolchain cost optimizer. Route every prompt to the cheapest model + toolchain bundle that won't compromise quality.** + +[frugal.sh](https://frugal.sh) · [GitHub](https://github.com/brainsparker/frugal) No account. No code changes. One command. ```bash -curl -fsSL https://frugal.sh/install | sh +curl -fsSL https://frugal.sh/install | bash ``` ```bash frugal python my_app.py ``` -That's it. Frugal starts a local proxy, injects `OPENAI_BASE_URL`, runs your command, and shuts down when it exits. Your app doesn't change. Your API keys stay local. 
Your bill drops 40-70%. +Frugal wraps any command with a local OpenAI-compatible proxy, classifies each +request by use case, and routes to the cheapest (model + toolchain) bundle that +clears your quality bar. Your app doesn't change. Your API keys stay local. + +--- + +## Why Frugal isn't just another model router + +A model alone isn't the product — the use case is. Legal research wants a strong +reasoner *and* good web search. Code work wants a code-aware model *and* targeted +retrieval. Structured extraction wants the smallest model that passes the schema +and nothing else. + +Frugal's wedge: classify each prompt's use case, then route to the cheapest +**bundle** (model × toolchain) that clears the quality bar for that use case. +Every bundle is grounded in the eval harness — routing isn't opinion, it's what +the data says wins for your workload. + +| Concept | What it is | +|---|---| +| **Capability** | A primitive: chat, web search, reranking, content extraction, browser. | +| **Use case** | A named class of work (`research-synthesis`, `code-dev`, `factual-qa`, `structured-extraction`). Ships with a labeled benchmark workload. | +| **Bundle** | The recommended (capability → provider) map for a use case at a quality tier. | + +## How much does it actually save? + +Run the benchmark on your own keys: + +```bash +frugal bench --quality balanced --out bench.md +``` + +Frugal runs every problem in `config/workloads/starter.yaml` twice — once +through the router (cheapest model that clears your quality bar) and once +pinned to the baseline model. Each output is scored against a deterministic +scorer (exact match, substring, JSON schema, or numeric tolerance): + +``` +# starter (quality=balanced, baseline=gpt-4o) + +Problems: 20 · Frugal pass: 90.0% · Baseline pass: 95.0% · Δ: -5.0pp +Cost: frugal $0.0043 · baseline $0.0118 · savings 63.6% +``` + +No judge LLMs, no simulated numbers — these are the bytes your provider billed +you for. 
See [`config/workloads/starter.yaml`](./config/workloads/starter.yaml) +for the problem set and [`config/CAPABILITIES.md`](./config/CAPABILITIES.md) +for the methodology behind the capability scores the router uses. + +## Use cases routed today + +The starter catalog ships four use cases in [`config/use_cases/`](./config/use_cases/). +Set the `X-Frugal-Use-Case` header and the request routes to that bundle. + +| Use case | What it matches | Balanced tier chat model | +|---|---|---| +| `research-synthesis` | Long-form multi-source research | Claude Sonnet 4 | +| `code-dev` | Code generation, debugging, review | GPT-4.1 mini | +| `factual-qa` | Short factual lookups, trivia | GPT-4.1 nano | +| `structured-extraction` | Free text → JSON | Gemini 2.5 Flash | + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "X-Frugal-Use-Case: research-synthesis" \ + -H "X-Frugal-Quality: balanced" \ + -H "Content-Type: application/json" \ + -d '{"model":"auto","messages":[{"role":"user","content":"..."}]}' +``` + +Inspect the bundles directly: + +```bash +curl http://localhost:8080/v1/bundles # every use case +curl http://localhost:8080/v1/bundles/research-synthesis # default balanced tier +curl http://localhost:8080/v1/bundles/research-synthesis?quality=high +``` + +Bundles today are **curated** (`source: curated` in the YAML) with `as_of` dates +tracking when each was last refreshed. When `X-Frugal-Use-Case` is absent, the +classifier/router path runs as before — use-case routing is opt-in. + +## Toolchain capabilities + +| Capability | Status | +|---|---| +| Chat (model routing) | **Shipping** (Ring 1a) | +| Web search | Next (Ring 1b) | +| Reranking | After search (Ring 1c) | +| Content extraction | Roadmap | +| Browser use | Roadmap | + +Today's bundles populate the `chat` slot only; `search`, `rerank`, `extract`, +and `browser` slots exist in the YAML schema so the API shape is stable as +capabilities land. 
Each capability ships only when the eval harness has data +saying a bundle built around it clears the quality bar. --- ## How it works -Frugal wraps any command. It spins up a lightweight local proxy, sets `OPENAI_BASE_URL` to point at it, and routes every LLM request to the cheapest model that won't degrade quality. +Frugal wraps any command, spins up a lightweight local proxy, sets +`OPENAI_BASE_URL` to point at it, and routes every request through the router. ``` frugal python app.py │ ├─ starts proxy on a free port ├─ injects OPENAI_BASE_URL into your command's environment - ├─ classifies each request (complexity, domain, capabilities) - ├─ routes to cheapest model that clears the quality bar + ├─ classifies each request (use case + required capabilities) + ├─ picks the cheapest (model + toolchain) bundle that clears the bar + ├─ calls bundled tools — web search, rerank, extract, browse — as needed └─ shuts down proxy when your command exits ``` -A creative brainstorm doesn't need `o3`. A simple extraction doesn't need `claude-opus`. You're paying for capability you don't use on 60-80% of your LLM calls. +You're paying for capability you don't use on 60–80% of your AI calls, and when +you *do* need capability, a bare chat model is rarely the answer. ### What the classifier detects | Signal | How | -|--------|-----| +|---|---| | Code | Regex for code blocks, `function`/`def`/`class` keywords | | Math | LaTeX patterns, equation keywords | | Reasoning depth | System prompt complexity, conversation length | | Output format | JSON mode, tool/function calling | | Domain | Keyword detection (coding, creative, analysis, math) | +| Use case | Explicit header (`X-Frugal-Use-Case`) or inferred from signals above | -These signals combine into a complexity score. The router picks the cheapest model that exceeds the quality threshold for that score. +Signals combine into a complexity score and a capability set. 
The router picks +the cheapest bundle that clears every threshold. ## Install ```bash -curl -fsSL https://frugal.sh/install | sh +curl -fsSL https://frugal.sh/install | bash ``` -Downloads a single binary (~10MB), detects your API keys, adds `frugal` to your PATH. +Downloads a single signed binary (~10MB), verifies it with `cosign` if present, +detects your API keys, adds `frugal` to your `PATH`. Pin a version with +`FRUGAL_VERSION=v0.2.1 curl -fsSL … | bash`. ### From source ```bash -git clone https://github.com/frugalsh/frugal.git +git clone https://github.com/brainsparker/frugal.git cd frugal make build ``` @@ -72,24 +175,46 @@ frugal pytest tests/ frugal bash -c 'curl https://api.openai.com/v1/...' ``` -Frugal picks a free port, starts the proxy, sets `OPENAI_BASE_URL` in your command's environment, and cleans up on exit. Works with any OpenAI-compatible SDK — Python, Node, Go, Rust, curl. +Frugal picks a free port, starts the proxy, sets `OPENAI_BASE_URL` in your +command's environment, and cleans up on exit. Works with any OpenAI-compatible +SDK — Python, Node, Go, Rust, curl. ### Run as a server -If you want a persistent proxy (e.g., shared across terminals or in Docker): - ```bash frugal serve # or just: frugal (with no arguments) -``` -Then set the env var yourself: - -```bash export OPENAI_BASE_URL=http://localhost:8080/v1 ``` -### Quality thresholds +Optional hardening timeouts (Go duration syntax): + +- `FRUGAL_READ_HEADER_TIMEOUT` (default `5s`) +- `FRUGAL_READ_TIMEOUT` (default `15s`) +- `FRUGAL_WRITE_TIMEOUT` (default `120s`) +- `FRUGAL_IDLE_TIMEOUT` (default `60s`) +- `FRUGAL_MAX_HEADER_BYTES` (default `1048576`) + +### Auth, rate limits, logging (serve mode) + +| Env var | Default | Purpose | +|---|---|---| +| `FRUGAL_ADDR` | `127.0.0.1:8080` | Listen address. Non-loopback binds require `FRUGAL_AUTH_TOKEN` or `FRUGAL_ALLOW_UNAUTH=1`. | +| `FRUGAL_AUTH_TOKEN` | *(unset)* | Shared bearer token. 
When set, every `/v1/*` call must send `Authorization: Bearer $FRUGAL_AUTH_TOKEN`. | +| `FRUGAL_ALLOW_UNAUTH` | `0` | Escape hatch: setting to `1` allows unauthenticated binds on non-loopback. | +| `FRUGAL_RPS` | `30` | Global token-bucket rate in requests/sec. `0` disables. | +| `FRUGAL_BURST` | `60` | Token-bucket burst capacity. Clamped to `>= FRUGAL_RPS`. | +| `FRUGAL_MAX_COST_PER_REQUEST_USD` | `1.00` | Reject requests whose routing-time estimate exceeds this cap. Pinned requests bypass. `0` disables. | +| `FRUGAL_LOG_LEVEL` | `info` | `debug` / `info` / `warn` / `error`. | +| `FRUGAL_LOG_FORMAT` | `text` | `text` for human-readable, `json` for structured ingestion. | +| `FRUGAL_DECISION_BUFFER` | `1000` | Capacity of the async routing-decision ring buffer. | + +Prometheus metrics are served at `/metrics` behind the same auth as `/v1/*`. +All responses carry `X-Request-ID`; generate one client-side if you want to +correlate logs across your app and the proxy. + +### Quality tiers Control cost vs. quality per request: @@ -97,11 +222,11 @@ Control cost vs. quality per request: headers = {"X-Frugal-Quality": "cost"} # high | balanced | cost ``` -| Threshold | Behavior | -|-----------|----------| -| `high` | Top-tier models only. | -| `balanced` | Default. Best cost-quality tradeoff. | -| `cost` | Cheapest viable model. Maximum savings. | +| Tier | Behavior | +|---|---| +| `high` | Top-tier bundles only. Planners, complex reasoning, novel code. | +| `balanced` | Default. Best cost-quality tradeoff; right ~80% of the time. | +| `cost` | Cheapest viable bundle. Classification, extraction, simple summaries. | ### Model pinning @@ -120,40 +245,48 @@ response = client.chat.completions.create( headers = {"X-Frugal-Fallback": "gpt-4o,claude-sonnet-4-20250514,gemini-2.5-flash"} ``` -If the routed model errors, Frugal walks the chain. +If the routed model errors, Frugal walks the chain. At most the first 3 +fallback models are attempted, to bound latency and cost. 
Duplicate entries +and duplicates of the routed model are skipped. ## Supported models Pricing synced from [models.dev](https://models.dev) on every startup. | Provider | Models | -|----------|--------| +|---|---| | OpenAI | GPT-4o, GPT-4o-mini, GPT-4.1, GPT-4.1-mini, GPT-4.1-nano | | Anthropic | Claude Opus 4, Claude Sonnet 4, Claude Haiku 3.5 | | Google | Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 2.0 Flash | -Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, and/or `GOOGLE_API_KEY`. Frugal registers whichever providers have keys. +Set `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, and/or `GOOGLE_API_KEY`. Frugal +registers whichever providers have keys. Add models by editing `~/.frugal/config/models.yaml`. ## Commands | Command | What it does | -|---------|-------------| +|---|---| | `frugal ` | Wrap a command with the routing proxy | | `frugal serve` | Run the proxy as a persistent server | | `frugal sync` | Update model pricing from models.dev | +| `frugal bench` | Run the benchmark workload and print a cost/quality report | ## API -When running as a server, Frugal exposes an OpenAI-compatible API: +When running as a server, Frugal exposes an OpenAI-compatible API plus a few +Frugal-specific endpoints: | Endpoint | Description | -|----------|-------------| +|---|---| | `POST /v1/chat/completions` | Routed chat (streaming + non-streaming) | | `GET /v1/models` | List available models | -| `GET /v1/routing/explain` | Last routing decision | +| `GET /v1/bundles` | List every use case and its bundle | +| `GET /v1/bundles/{use-case}` | Bundle for a use case at a given quality tier | +| `GET /v1/routing/explain` | Last routing decision — model, toolchain, why | | `GET /health` | Health check | +| `GET /metrics` | Prometheus metrics (same auth as `/v1/*`) | ## Development @@ -164,6 +297,22 @@ make run # build + run server make release # cross-compile for all platforms ``` +## Security + +Release artifacts are cosign-signed (keyless, GitHub OIDC). 
The installer
+verifies `SHA256SUMS` and, when `cosign` is present, the signature before
+moving the binary into place. See [`SECURITY.md`](./SECURITY.md) for the
+threat model and disclosure process.
+
 ## License
 
-MIT
+[Business Source License 1.1](./LICENSE) — self-hosting and internal
+production use are permitted; offering Frugal as a competing hosted routing
+service is not. Each version converts to Apache 2.0 four years after release.
+See the [BUSL FAQ](./LICENSE-BUSL-FAQ.md) for a plain-English summary.
+
+## Contributing
+
+Issues, bug reports, and PRs are welcome. See [`CONTRIBUTING.md`](./CONTRIBUTING.md)
+for how the project is structured, the testing expectations, and the commit
+style.
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000..662299f
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,60 @@
+# Security Policy
+
+Frugal is a local-first LLM proxy that handles provider API keys on behalf of
+the operator. We take security reports seriously and aim to respond within
+72 hours.
+
+## Reporting a vulnerability
+
+**Please do not open a public GitHub issue for security-sensitive reports.**
+
+Use GitHub's private vulnerability reporting to open a draft advisory against
+this repository:
+
+https://github.com/brainsparker/frugal/security/advisories/new
+
+If you cannot use GitHub advisories, email **security@frugal.sh** with:
+
+- A description of the vulnerability and its impact.
+- Reproduction steps or a proof-of-concept.
+- The affected version(s) and environment.
+
+We will acknowledge receipt within 72 hours and coordinate a fix, disclosure
+window, and CVE if applicable.
+
+## Supported versions
+
+Frugal is pre-1.0. Security fixes land on `main` and are cut into the next
+tagged release. Only the latest release receives security updates; earlier
+versions must upgrade.
+
+## Scope
+
+In-scope:
+
+- The `frugal` binary (`cmd/frugal`) and all packages under `internal/`.
+- The installer script `install.sh` (supply-chain integrity).
+- Default configuration shipped in `config/models.yaml`. +- Docker image `Dockerfile` and Fly deployment `fly.toml`. + +Out of scope: + +- Third-party provider APIs (OpenAI, Anthropic, Google) — report to those + vendors directly. +- Vulnerabilities in direct dependencies — report upstream, but we will + respond by bumping the dependency once patched. + +## Hardening posture + +Operational expectations for deployers: + +- Set `FRUGAL_AUTH_TOKEN` before exposing the proxy beyond `127.0.0.1`. +- Keep `FRUGAL_MAX_COST_PER_REQUEST_USD` at a sane ceiling for your account. +- Verify release artifacts with `cosign verify-blob` against the published + `SHA256SUMS` file. The installer does this automatically when `cosign` + is present. + +## Disclosure + +Once a fix is released, we will publish a GitHub Security Advisory with +the affected versions, the fix version, and — if applicable — a CVE ID. diff --git a/bin/frugal b/bin/frugal index 5576576..dd31551 100755 Binary files a/bin/frugal and b/bin/frugal differ diff --git a/cmd/frugal/bench.go b/cmd/frugal/bench.go new file mode 100644 index 0000000..995f115 --- /dev/null +++ b/cmd/frugal/bench.go @@ -0,0 +1,155 @@ +package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/frugalsh/frugal/internal/classifier" + "github.com/frugalsh/frugal/internal/config" + "github.com/frugalsh/frugal/internal/eval" + "github.com/frugalsh/frugal/internal/provider" + "github.com/frugalsh/frugal/internal/provider/anthropic" + "github.com/frugalsh/frugal/internal/provider/google" + "github.com/frugalsh/frugal/internal/provider/openai" + "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/types" +) + +const defaultBenchWorkload = "config/workloads/starter.yaml" + +// runBench executes `frugal bench [-workload PATH] [-quality TIER] [-out FILE] +// [-timeout DUR]`. Returns the process exit code. 
+// +// Example: `frugal bench -quality balanced -out bench.md`. +func runBench(configPath string, args []string) int { + fs := flag.NewFlagSet("bench", flag.ContinueOnError) + fs.SetOutput(os.Stderr) + workloadPath := fs.String("workload", defaultBenchWorkload, "path to workload YAML") + qualityStr := fs.String("quality", "balanced", "quality tier (high | balanced | cost)") + outPath := fs.String("out", "", "write markdown report to this path in addition to stdout") + timeout := fs.Duration("timeout", 10*time.Minute, "overall bench timeout") + fs.Usage = func() { + fmt.Fprintln(os.Stderr, "Usage: frugal bench [flags]") + fmt.Fprintln(os.Stderr, "Measures cost and quality for every problem in the workload, calling") + fmt.Fprintln(os.Stderr, "real provider APIs. Requires whichever provider keys the workload's") + fmt.Fprintln(os.Stderr, "models + baseline need (OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY).") + fmt.Fprintln(os.Stderr) + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return 2 + } + + quality, ok := types.ParseQualityThreshold(*qualityStr) + if !ok { + fmt.Fprintf(os.Stderr, "frugal bench: unknown quality %q (want high | balanced | cost)\n", *qualityStr) + return 2 + } + + resolvedWorkload := *workloadPath + if !filepath.IsAbs(resolvedWorkload) { + if _, err := os.Stat(resolvedWorkload); err != nil { + // Fall back to the installed config dir when running from elsewhere. 
+ if p := os.Getenv("FRUGAL_CONFIG_DIR"); p != "" { + alt := filepath.Join(p, filepath.Base(resolvedWorkload)) + if _, err := os.Stat(alt); err == nil { + resolvedWorkload = alt + } + } + } + } + + w, err := eval.LoadLiveWorkload(resolvedWorkload) + if err != nil { + fmt.Fprintf(os.Stderr, "frugal bench: %v\n", err) + return 1 + } + + cfg, err := config.Load(configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "frugal bench: load config: %v\n", err) + return 1 + } + + reg := provider.NewRegistry() + registerBenchProviders(cfg, reg) + if len(reg.AllModels()) == 0 { + fmt.Fprintln(os.Stderr, "frugal bench: no API keys found. Set OPENAI_API_KEY, ANTHROPIC_API_KEY, or GOOGLE_API_KEY.") + return 1 + } + + if _, err := reg.Resolve(w.Baseline); err != nil { + fmt.Fprintf(os.Stderr, "frugal bench: baseline model %q not registered (no key for its provider?): %v\n", + w.Baseline, err) + return 1 + } + + modelEntries, thresholds := router.BuildTaxonomy(cfg) + modelEntries = filterRegisteredModels(modelEntries, reg) + if len(modelEntries) == 0 { + fmt.Fprintln(os.Stderr, "frugal bench: no routable models available for registered providers") + return 1 + } + rtr := router.New(modelEntries, thresholds) + cls := classifier.NewRuleBased() + runner := eval.NewLiveRunner(cfg, cls, rtr, reg) + + ctx, cancel := context.WithTimeout(context.Background(), *timeout) + defer cancel() + + fmt.Fprintf(os.Stderr, "running %d problems from %q against baseline %s @ quality=%s...\n", + len(w.Problems), w.Name, w.Baseline, quality) + summary, err := runner.Run(ctx, w, quality) + if err != nil { + fmt.Fprintf(os.Stderr, "frugal bench: %v\n", err) + return 1 + } + + // Always write to stdout. 
+ var buf bytes.Buffer + if err := eval.WriteLiveMarkdown(&buf, summary); err != nil { + fmt.Fprintf(os.Stderr, "frugal bench: render report: %v\n", err) + return 1 + } + fmt.Print(buf.String()) + + if *outPath != "" { + if err := os.WriteFile(*outPath, buf.Bytes(), 0o644); err != nil { + fmt.Fprintf(os.Stderr, "frugal bench: write %s: %v\n", *outPath, err) + return 1 + } + fmt.Fprintf(os.Stderr, "wrote %s\n", *outPath) + } + + // Print a one-line summary on stderr for CI log scraping. + fmt.Fprintf(os.Stderr, + "bench done: %d problems · frugal %.1f%% pass · baseline %.1f%% pass · savings %.1f%%\n", + summary.ProblemCount, summary.FrugalPassRate, summary.BaselinePassRate, summary.SavingsPct) + return 0 +} + +// registerBenchProviders mirrors runWrap's provider registration without the +// retry wrapper: benchmarks should see raw upstream behavior so retries don't +// mask quality/rate-limit issues. +func registerBenchProviders(cfg *config.Config, reg *provider.Registry) { + if pc, ok := cfg.Providers["openai"]; ok { + if key := os.Getenv(pc.APIKeyEnv); key != "" { + reg.Register(openai.New(key, pc.BaseURL, modelNames(pc))) + } + } + if pc, ok := cfg.Providers["anthropic"]; ok { + if key := os.Getenv(pc.APIKeyEnv); key != "" { + reg.Register(anthropic.New(key, pc.BaseURL, modelNames(pc))) + } + } + if pc, ok := cfg.Providers["google"]; ok { + if key := os.Getenv(pc.APIKeyEnv); key != "" { + reg.Register(google.New(key, pc.BaseURL, modelNames(pc))) + } + } +} diff --git a/cmd/frugal/main.go b/cmd/frugal/main.go index 51cf2d6..752ec3f 100644 --- a/cmd/frugal/main.go +++ b/cmd/frugal/main.go @@ -1,23 +1,40 @@ package main import ( + "context" + "encoding/json" + "errors" + "fmt" "log" + "log/slog" + "net" "net/http" "os" + "os/signal" + "runtime/debug" + "strconv" + "syscall" + "time" "github.com/go-chi/chi/v5" "github.com/frugalsh/frugal/internal/classifier" "github.com/frugalsh/frugal/internal/config" + "github.com/frugalsh/frugal/internal/metrics" + 
"github.com/frugalsh/frugal/internal/obs" "github.com/frugalsh/frugal/internal/provider" "github.com/frugalsh/frugal/internal/provider/anthropic" "github.com/frugalsh/frugal/internal/provider/google" "github.com/frugalsh/frugal/internal/provider/openai" "github.com/frugalsh/frugal/internal/proxy" "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/usecase" ) func main() { + obs.InitLogger() + metrics.Register() + configPath := "config/models.yaml" if p := os.Getenv("FRUGAL_CONFIG"); p != "" { configPath = p @@ -26,11 +43,19 @@ func main() { // Handle subcommands if len(os.Args) > 1 { switch os.Args[1] { + case "-h", "--help", "help": + printHelp() + return + case "-v", "--version", "version": + fmt.Println(version()) + return case "sync": if err := runSync(configPath); err != nil { log.Fatalf("sync failed: %v", err) } return + case "bench": + os.Exit(runBench(configPath, os.Args[2:])) case "serve": // fall through to server startup default: @@ -42,7 +67,7 @@ func main() { // Sync pricing from models.dev on startup (non-fatal if it fails) if err := runSync(configPath); err != nil { - log.Printf("warning: pricing sync failed (using cached config): %v", err) + slog.Warn("pricing sync failed; using cached config", "err", err) } cfg, err := config.Load(configPath) @@ -56,60 +81,296 @@ func main() { if pc, ok := cfg.Providers["openai"]; ok { if key := os.Getenv(pc.APIKeyEnv); key != "" { models := modelNames(pc) - registry.Register(openai.New(key, pc.BaseURL, models)) - log.Printf("registered openai provider with %d models", len(models)) + registry.Register(provider.WithRetry(openai.New(key, pc.BaseURL, models))) + slog.Info("registered provider", "provider", "openai", "models", len(models)) } } if pc, ok := cfg.Providers["anthropic"]; ok { if key := os.Getenv(pc.APIKeyEnv); key != "" { models := modelNames(pc) - registry.Register(anthropic.New(key, pc.BaseURL, models)) - log.Printf("registered anthropic provider with %d models", 
len(models)) + registry.Register(provider.WithRetry(anthropic.New(key, pc.BaseURL, models))) + slog.Info("registered provider", "provider", "anthropic", "models", len(models)) } } if pc, ok := cfg.Providers["google"]; ok { if key := os.Getenv(pc.APIKeyEnv); key != "" { models := modelNames(pc) - registry.Register(google.New(key, pc.BaseURL, models)) - log.Printf("registered google provider with %d models", len(models)) + registry.Register(provider.WithRetry(google.New(key, pc.BaseURL, models))) + slog.Info("registered provider", "provider", "google", "models", len(models)) } } // Build classifier and router cls := classifier.NewRuleBased() modelEntries, thresholds := router.BuildTaxonomy(cfg) + modelEntries = filterRegisteredModels(modelEntries, registry) + if len(modelEntries) == 0 { + log.Fatal("no routable models available for registered providers") + } rtr := router.New(modelEntries, thresholds) + // Load use-case registry. Dir resolves relative to the binary's working + // directory; override via FRUGAL_USE_CASES_DIR. An absent or empty dir + // disables use-case routing without failing startup. + useCaseDir := os.Getenv("FRUGAL_USE_CASES_DIR") + if useCaseDir == "" { + useCaseDir = "config/use_cases" + } + useCases, err := usecase.Load(useCaseDir) + if err != nil { + log.Fatalf("failed to load use cases from %s: %v", useCaseDir, err) + } + slog.Info("use cases loaded", "dir", useCaseDir, "count", useCases.Len()) + // Build HTTP handler - h := proxy.NewHandler(cls, rtr, registry) + h := proxy.NewHandlerWithUseCases(cls, rtr, registry, useCases) + + addr := "127.0.0.1:8080" + if a := os.Getenv("FRUGAL_ADDR"); a != "" { + addr = a + } + + authToken := os.Getenv("FRUGAL_AUTH_TOKEN") + if err := guardUnauthenticatedBind(addr, authToken); err != nil { + log.Fatalf("startup rejected: %v", err) + } + + rps := envIntOrDefault("FRUGAL_RPS", 30) + burst := envIntOrDefault("FRUGAL_BURST", 60) - // Wire routes + // Wire routes. 
Middleware ordering matters: RequestID first so + // Recover/Logging carry the ID; Auth before any handler touches + // registry; HeaderExtraction last so per-request controls land on the + // authenticated ctx. r := chi.NewRouter() + r.Use(proxy.RequestIDMiddleware) + r.Use(proxy.RecoverMiddleware) r.Use(proxy.LoggingMiddleware) + r.Use(proxy.RateLimitMiddleware(rps, burst)) + r.Use(proxy.AuthMiddleware(authToken)) r.Use(proxy.HeaderExtractionMiddleware) r.Post("/v1/chat/completions", h.ChatCompletions) r.Get("/v1/models", h.ListModels) r.Get("/v1/routing/explain", h.RoutingExplain) + r.Get("/v1/bundles", h.ListBundles) + r.Get("/v1/bundles/{useCase}", h.GetBundle) - // Health check - r.Get("/health", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("ok")) - }) + // Health check — always unauthenticated so deployment probes keep working. + // Reports provider list + model count so operators can distinguish "server + // up" from "server up with valid routing". Returns 503 when no models are + // routable so load balancers take the instance out of rotation. + r.Get("/health", healthHandler(registry)) - addr := ":8080" - if a := os.Getenv("FRUGAL_ADDR"); a != "" { - addr = a - } + // Prometheus metrics. Sits behind the same auth middleware as /v1, so + // anyone with scrape creds also has chat creds — which is the usual ops + // posture for a small internal proxy. If operators want separate scrape + // access they can run a second listener with its own token later. 
+ r.Handle("/metrics", metrics.Handler()) - log.Printf("frugal listening on %s", addr) - if err := http.ListenAndServe(addr, r); err != nil { + server := newHTTPServer(addr, r) + + slog.Info("frugal listening", "addr", addr, "auth", authToken != "") + if err := runServer(server); err != nil { log.Fatalf("server error: %v", err) } } +// guardUnauthenticatedBind refuses to start an unauthenticated proxy on a +// non-loopback interface unless the operator has explicitly opted in via +// FRUGAL_ALLOW_UNAUTH=1. The check keeps "no API keys in config? just run it" +// working on localhost while preventing the Fly/Docker footgun where :8080 +// binds to 0.0.0.0 and any network traffic can drain the operator's keys. +func guardUnauthenticatedBind(addr, token string) error { + if token != "" { + return nil + } + if os.Getenv("FRUGAL_ALLOW_UNAUTH") == "1" { + log.Printf("warning: FRUGAL_ALLOW_UNAUTH=1 set — running without auth on %s", addr) + return nil + } + if isLoopbackBind(addr) { + return nil + } + return &startupError{msg: "refusing to bind " + addr + " without FRUGAL_AUTH_TOKEN; set a token or FRUGAL_ALLOW_UNAUTH=1 to override"} +} + +// isLoopbackBind reports whether addr binds only to the loopback interface. +// Accepts forms like "127.0.0.1:8080", "[::1]:8080", "localhost:8080". +func isLoopbackBind(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + switch host { + case "127.0.0.1", "localhost", "::1": + return true + } + if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() { + return true + } + return false +} + +type startupError struct{ msg string } + +func (e *startupError) Error() string { return e.msg } + +// printHelp prints a one-page summary of commands and flags. +func printHelp() { + fmt.Println(`frugal — open-source LLM cost-optimizing proxy + +Usage: + frugal [args...] 
Wrap any command with the routing proxy + frugal serve Run the proxy as a persistent server + frugal sync Refresh model pricing from models.dev + frugal bench [flags] Run cost + quality benchmark against providers + frugal -v | --version Print the build version + frugal -h | --help Show this help + +Common environment: + FRUGAL_ADDR Listen address (serve; default 127.0.0.1:8080) + FRUGAL_AUTH_TOKEN Shared bearer token required on non-loopback + FRUGAL_LOG_LEVEL debug | info | warn | error + FRUGAL_LOG_FORMAT text | json + FRUGAL_MAX_COST_PER_REQUEST_USD Per-request spend cap (default 1.00) + OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY Provider credentials + +See README.md for the full list.`) +} + +// buildVersion is injected at release time via +// +// go build -ldflags "-X main.buildVersion=$VERSION" +// +// (see Makefile). It takes precedence over debug.ReadBuildInfo so release +// binaries built with `go build` — not `go install` — still report a real +// tag. Left empty for local `go run` / `go build` without ldflags. +var buildVersion string + +// version reports a human-readable build identifier. +func version() string { + if buildVersion != "" { + return buildVersion + } + if info, ok := debug.ReadBuildInfo(); ok { + if info.Main.Version != "" && info.Main.Version != "(devel)" { + return info.Main.Version + } + } + return "dev" +} + +// healthHandler reports liveness + a shallow inventory of routable models. +// Operators (and Fly/K8s) distinguish "HTTP is up" from "routing is actually +// healthy" — the latter requires at least one registered model. 
+func healthHandler(registry *provider.Registry) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + models := registry.AllModels() + providers := map[string]bool{} + for _, m := range models { + if p, err := registry.Resolve(m); err == nil { + providers[p.Name()] = true + } + } + names := make([]string, 0, len(providers)) + for n := range providers { + names = append(names, n) + } + + status := "ok" + code := http.StatusOK + if len(models) == 0 { + status = "degraded" + code = http.StatusServiceUnavailable + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status": status, + "providers": names, + "models": len(models), + }) + } +} + +// runServer starts the HTTP server and waits for SIGINT/SIGTERM to trigger a +// graceful shutdown. In-flight requests finish (bounded to shutdownTimeout) +// and the listener is closed. Returns nil on clean shutdown. +func runServer(server *http.Server) error { + const shutdownTimeout = 30 * time.Second + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + errCh := make(chan error, 1) + go func() { + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + return + } + errCh <- nil + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + slog.Info("shutdown signal received; draining in-flight requests") + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer shutdownCancel() + if err := server.Shutdown(shutdownCtx); err != nil { + slog.Warn("server shutdown returned error", "err", err) + } + return nil + } +} + +func newHTTPServer(addr string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: envDurationOrDefault("FRUGAL_READ_HEADER_TIMEOUT", 5*time.Second), + ReadTimeout: 
envDurationOrDefault("FRUGAL_READ_TIMEOUT", 15*time.Second), + WriteTimeout: envDurationOrDefault("FRUGAL_WRITE_TIMEOUT", 120*time.Second), + IdleTimeout: envDurationOrDefault("FRUGAL_IDLE_TIMEOUT", 60*time.Second), + MaxHeaderBytes: envIntOrDefault("FRUGAL_MAX_HEADER_BYTES", http.DefaultMaxHeaderBytes), + } +} + +func envDurationOrDefault(key string, fallback time.Duration) time.Duration { + value := os.Getenv(key) + if value == "" { + return fallback + } + + parsed, err := time.ParseDuration(value) + if err != nil || parsed <= 0 { + slog.Warn("invalid env duration; using default", "key", key, "value", value, "default", fallback.String()) + return fallback + } + + return parsed +} + +func envIntOrDefault(key string, fallback int) int { + value := os.Getenv(key) + if value == "" { + return fallback + } + + parsed, err := strconv.Atoi(value) + if err != nil || parsed <= 0 { + slog.Warn("invalid env int; using default", "key", key, "value", value, "default", fallback) + return fallback + } + + return parsed +} + func modelNames(pc config.ProviderConfig) []string { names := make([]string, 0, len(pc.Models)) for name := range pc.Models { @@ -117,3 +378,13 @@ func modelNames(pc config.ProviderConfig) []string { } return names } + +func filterRegisteredModels(entries []router.ModelEntry, registry *provider.Registry) []router.ModelEntry { + filtered := make([]router.ModelEntry, 0, len(entries)) + for _, entry := range entries { + if _, err := registry.Resolve(entry.Name); err == nil { + filtered = append(filtered, entry) + } + } + return filtered +} diff --git a/cmd/frugal/main_test.go b/cmd/frugal/main_test.go new file mode 100644 index 0000000..54b7261 --- /dev/null +++ b/cmd/frugal/main_test.go @@ -0,0 +1,138 @@ +package main + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/frugalsh/frugal/internal/provider" + "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/types" +) + +type testProvider struct { + name 
string + models []string +} + +func (p *testProvider) Name() string { return p.name } +func (p *testProvider) Models() []string { return p.models } + +func (p *testProvider) ChatCompletion(_ context.Context, _ string, _ *types.ChatCompletionRequest) (*types.ChatCompletionResponse, error) { + return &types.ChatCompletionResponse{}, nil +} + +func (p *testProvider) ChatCompletionStream(_ context.Context, _ string, _ *types.ChatCompletionRequest) (<-chan provider.StreamChunk, error) { + ch := make(chan provider.StreamChunk) + close(ch) + return ch, nil +} + +func TestFilterRegisteredModels(t *testing.T) { + reg := provider.NewRegistry() + reg.Register(&testProvider{name: "openai", models: []string{"gpt-4o-mini"}}) + + entries := []router.ModelEntry{ + {Name: "gpt-4o-mini", Provider: "openai"}, + {Name: "claude-sonnet-4-20250514", Provider: "anthropic"}, + } + + filtered := filterRegisteredModels(entries, reg) + if got := len(filtered); got != 1 { + t.Fatalf("expected 1 registered model, got %d", got) + } + if filtered[0].Name != "gpt-4o-mini" { + t.Fatalf("expected gpt-4o-mini to remain, got %s", filtered[0].Name) + } +} + +func TestNewHTTPServerDefaults(t *testing.T) { + t.Setenv("FRUGAL_READ_HEADER_TIMEOUT", "") + t.Setenv("FRUGAL_READ_TIMEOUT", "") + t.Setenv("FRUGAL_WRITE_TIMEOUT", "") + t.Setenv("FRUGAL_IDLE_TIMEOUT", "") + t.Setenv("FRUGAL_MAX_HEADER_BYTES", "") + + srv := newHTTPServer(":8080", http.NewServeMux()) + + if srv.ReadHeaderTimeout != 5*time.Second { + t.Fatalf("expected default read header timeout 5s, got %s", srv.ReadHeaderTimeout) + } + if srv.ReadTimeout != 15*time.Second { + t.Fatalf("expected default read timeout 15s, got %s", srv.ReadTimeout) + } + if srv.WriteTimeout != 120*time.Second { + t.Fatalf("expected default write timeout 120s, got %s", srv.WriteTimeout) + } + if srv.IdleTimeout != 60*time.Second { + t.Fatalf("expected default idle timeout 60s, got %s", srv.IdleTimeout) + } + if srv.MaxHeaderBytes != http.DefaultMaxHeaderBytes { + 
t.Fatalf("expected default max header bytes %d, got %d", http.DefaultMaxHeaderBytes, srv.MaxHeaderBytes) + } +} + +func TestNewHTTPServerEnvOverrides(t *testing.T) { + t.Setenv("FRUGAL_READ_HEADER_TIMEOUT", "6s") + t.Setenv("FRUGAL_READ_TIMEOUT", "20s") + t.Setenv("FRUGAL_WRITE_TIMEOUT", "150s") + t.Setenv("FRUGAL_IDLE_TIMEOUT", "75s") + t.Setenv("FRUGAL_MAX_HEADER_BYTES", "65536") + + srv := newHTTPServer(":8080", http.NewServeMux()) + + if srv.ReadHeaderTimeout != 6*time.Second { + t.Fatalf("expected read header timeout 6s, got %s", srv.ReadHeaderTimeout) + } + if srv.ReadTimeout != 20*time.Second { + t.Fatalf("expected read timeout 20s, got %s", srv.ReadTimeout) + } + if srv.WriteTimeout != 150*time.Second { + t.Fatalf("expected write timeout 150s, got %s", srv.WriteTimeout) + } + if srv.IdleTimeout != 75*time.Second { + t.Fatalf("expected idle timeout 75s, got %s", srv.IdleTimeout) + } + if srv.MaxHeaderBytes != 65536 { + t.Fatalf("expected max header bytes 65536, got %d", srv.MaxHeaderBytes) + } +} + +func TestEnvDurationOrDefaultInvalidValues(t *testing.T) { + const key = "FRUGAL_TIMEOUT_TEST" + + t.Setenv(key, "not-a-duration") + if got := envDurationOrDefault(key, 3*time.Second); got != 3*time.Second { + t.Fatalf("expected fallback for invalid duration, got %s", got) + } + + t.Setenv(key, "0s") + if got := envDurationOrDefault(key, 3*time.Second); got != 3*time.Second { + t.Fatalf("expected fallback for zero duration, got %s", got) + } + + t.Setenv(key, "-2s") + if got := envDurationOrDefault(key, 3*time.Second); got != 3*time.Second { + t.Fatalf("expected fallback for negative duration, got %s", got) + } +} + +func TestEnvIntOrDefaultInvalidValues(t *testing.T) { + const key = "FRUGAL_INT_TEST" + + t.Setenv(key, "not-an-int") + if got := envIntOrDefault(key, 1234); got != 1234 { + t.Fatalf("expected fallback for invalid int, got %d", got) + } + + t.Setenv(key, "0") + if got := envIntOrDefault(key, 1234); got != 1234 { + t.Fatalf("expected fallback for zero 
int, got %d", got) + } + + t.Setenv(key, "-10") + if got := envIntOrDefault(key, 1234); got != 1234 { + t.Fatalf("expected fallback for negative int, got %d", got) + } +} diff --git a/cmd/frugal/sync.go b/cmd/frugal/sync.go index 1f4d02b..01e66bb 100644 --- a/cmd/frugal/sync.go +++ b/cmd/frugal/sync.go @@ -1,10 +1,11 @@ package main import ( + "context" "fmt" - "log" + "log/slog" "os" - "strings" + "path/filepath" "github.com/frugalsh/frugal/internal/config" msync "github.com/frugalsh/frugal/internal/sync" @@ -22,13 +23,13 @@ var modelAliases = map[string][]string{ } func runSync(configPath string) error { - log.Println("fetching model pricing from models.dev...") + slog.Info("fetching model pricing from models.dev") - catalog, err := msync.FetchModels() + catalog, err := msync.FetchModels(context.Background()) if err != nil { return fmt.Errorf("fetch failed: %w", err) } - log.Printf("fetched %d model entries from models.dev", len(catalog)) + slog.Info("fetched models.dev catalog", "entries", len(catalog)) cfg, err := config.Load(configPath) if err != nil { @@ -42,21 +43,22 @@ func runSync(configPath string) error { for modelName, mc := range pc.Models { entry, found := lookupModel(catalog, providerName, modelName) if !found { - log.Printf(" [skip] %s/%s — not found in models.dev", providerName, modelName) + slog.Info("sync skipped", "provider", providerName, "model", modelName, "reason", "not_in_catalog") notFound++ continue } changed := false + logger := slog.With("provider", providerName, "model", modelName) if entry.Cost != nil { newInput := msync.CostPer1K(entry.Cost.Input) newOutput := msync.CostPer1K(entry.Cost.Output) if newInput != mc.CostPer1KInput || newOutput != mc.CostPer1KOutput { - log.Printf(" [update] %s/%s: input $%.6f→$%.6f, output $%.6f→$%.6f per 1K tokens", - providerName, modelName, - mc.CostPer1KInput, newInput, - mc.CostPer1KOutput, newOutput) + logger.Info("cost updated", + "input_from", mc.CostPer1KInput, "input_to", newInput, + 
"output_from", mc.CostPer1KOutput, "output_to", newOutput, + ) mc.CostPer1KInput = newInput mc.CostPer1KOutput = newOutput changed = true @@ -64,19 +66,18 @@ func runSync(configPath string) error { } if entry.Limit != nil && entry.Limit.Context > 0 && entry.Limit.Context != mc.Capabilities.MaxContext { - log.Printf(" [update] %s/%s: context %d→%d", - providerName, modelName, mc.Capabilities.MaxContext, entry.Limit.Context) + logger.Info("context updated", "from", mc.Capabilities.MaxContext, "to", entry.Limit.Context) mc.Capabilities.MaxContext = entry.Limit.Context changed = true } if entry.ToolCall != mc.Capabilities.ToolUse { - log.Printf(" [update] %s/%s: tool_use %v→%v", providerName, modelName, mc.Capabilities.ToolUse, entry.ToolCall) + logger.Info("tool_use updated", "from", mc.Capabilities.ToolUse, "to", entry.ToolCall) mc.Capabilities.ToolUse = entry.ToolCall changed = true } if entry.StructuredOutput != mc.Capabilities.JSONMode { - log.Printf(" [update] %s/%s: json_mode %v→%v", providerName, modelName, mc.Capabilities.JSONMode, entry.StructuredOutput) + logger.Info("json_mode updated", "from", mc.Capabilities.JSONMode, "to", entry.StructuredOutput) mc.Capabilities.JSONMode = entry.StructuredOutput changed = true } @@ -85,22 +86,26 @@ func runSync(configPath string) error { pc.Models[modelName] = mc updated++ } else { - log.Printf(" [ok] %s/%s — up to date", providerName, modelName) + logger.Debug("model up to date") } } cfg.Providers[providerName] = pc } - log.Printf("updated %d models, %d not found in catalog", updated, notFound) + slog.Info("sync complete", "updated", updated, "not_found", notFound) if updated > 0 { return writeConfig(configPath, cfg) } - log.Println("no changes needed") + slog.Info("sync: no changes") return nil } +// lookupModel resolves a configured model to a models.dev catalog entry by +// exact key or explicit alias. Fuzzy (strings.Contains) matching is +// intentionally absent: it silently cross-bound prices (e.g. 
gpt-4 → gpt-4o) +// because the map iteration is unordered. func lookupModel(catalog map[string]msync.ModelsDevEntry, providerName, modelName string) (msync.ModelsDevEntry, bool) { // 1. Try "provider/model" (e.g., "openai/gpt-4o") if entry, ok := catalog[providerName+"/"+modelName]; ok { @@ -118,37 +123,55 @@ func lookupModel(catalog map[string]msync.ModelsDevEntry, providerName, modelNam if entry, ok := catalog[alias]; ok { return entry, true } - // Also try with provider prefix if entry, ok := catalog[providerName+"/"+alias]; ok { return entry, true } } } - // 4. Fuzzy: find catalog entry containing the model name or vice versa - for id, entry := range catalog { - bare := id - if idx := strings.LastIndex(id, "/"); idx >= 0 { - bare = id[idx+1:] - } - if strings.Contains(bare, modelName) || strings.Contains(modelName, bare) { - return entry, true - } - } - return msync.ModelsDevEntry{}, false } +// writeConfig atomically replaces the config file: write to a sibling +// tempfile, fsync, then rename. An interrupted sync never leaves the user's +// models.yaml truncated or partially written. 
func writeConfig(path string, cfg *config.Config) error { data, err := yaml.Marshal(cfg) if err != nil { return fmt.Errorf("marshal config: %w", err) } - if err := os.WriteFile(path, data, 0644); err != nil { - return fmt.Errorf("write config: %w", err) + dir := filepath.Dir(path) + tmp, err := os.CreateTemp(dir, filepath.Base(path)+".tmp-*") + if err != nil { + return fmt.Errorf("create tempfile: %w", err) + } + tmpPath := tmp.Name() + cleanup := func() { _ = os.Remove(tmpPath) } + + if _, err := tmp.Write(data); err != nil { + tmp.Close() + cleanup() + return fmt.Errorf("write tempfile: %w", err) + } + if err := tmp.Sync(); err != nil { + tmp.Close() + cleanup() + return fmt.Errorf("fsync tempfile: %w", err) + } + if err := tmp.Close(); err != nil { + cleanup() + return fmt.Errorf("close tempfile: %w", err) + } + if err := os.Chmod(tmpPath, 0644); err != nil { + cleanup() + return fmt.Errorf("chmod tempfile: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { + cleanup() + return fmt.Errorf("rename tempfile: %w", err) } - log.Printf("wrote updated config to %s", path) + slog.Info("wrote config", "path", path) return nil } diff --git a/cmd/frugal/sync_test.go b/cmd/frugal/sync_test.go new file mode 100644 index 0000000..e08f11b --- /dev/null +++ b/cmd/frugal/sync_test.go @@ -0,0 +1,46 @@ +package main + +import ( + "testing" + + msync "github.com/frugalsh/frugal/internal/sync" +) + +// TestLookupModel_NoFuzzyCrossBinding guards the sync path against the +// regression that shipped in an earlier version: a strings.Contains fallback +// matched gpt-4 → gpt-4o, silently overwriting local pricing with a different +// model's cost. Ensure the lookup only succeeds on exact keys or explicit +// aliases. +func TestLookupModel_NoFuzzyCrossBinding(t *testing.T) { + catalog := map[string]msync.ModelsDevEntry{ + "openai/gpt-4o": {ID: "gpt-4o"}, + "gpt-4o": {ID: "gpt-4o"}, + } + + // Exact match should succeed. 
+ if _, ok := lookupModel(catalog, "openai", "gpt-4o"); !ok { + t.Fatalf("expected exact match for openai/gpt-4o") + } + + // A different model name must NOT pick up gpt-4o's entry via substring. + if entry, ok := lookupModel(catalog, "openai", "gpt-4"); ok { + t.Fatalf("unexpected fuzzy match: gpt-4 resolved to %+v", entry) + } + if entry, ok := lookupModel(catalog, "openai", "gpt-4o-mini"); ok { + t.Fatalf("unexpected fuzzy match: gpt-4o-mini resolved to %+v", entry) + } +} + +func TestLookupModel_AliasesStillResolve(t *testing.T) { + catalog := map[string]msync.ModelsDevEntry{ + "claude-3-5-haiku": {ID: "claude-3-5-haiku"}, + } + + entry, ok := lookupModel(catalog, "anthropic", "claude-haiku-3.5") + if !ok { + t.Fatalf("expected alias claude-haiku-3.5 → claude-3-5-haiku to resolve") + } + if entry.ID != "claude-3-5-haiku" { + t.Fatalf("alias resolved to wrong entry: %+v", entry) + } +} diff --git a/cmd/frugal/wrap.go b/cmd/frugal/wrap.go index d2c55ec..736bcb5 100644 --- a/cmd/frugal/wrap.go +++ b/cmd/frugal/wrap.go @@ -1,8 +1,13 @@ package main import ( + "context" + "crypto/rand" + "encoding/base32" + "errors" "fmt" "log" + "log/slog" "net" "net/http" "os" @@ -21,6 +26,7 @@ import ( provopenai "github.com/frugalsh/frugal/internal/provider/openai" "github.com/frugalsh/frugal/internal/proxy" "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/usecase" ) // runWrap starts the proxy on a free port, runs the given command with @@ -28,7 +34,7 @@ import ( func runWrap(configPath string, args []string) int { // Sync pricing if err := runSync(configPath); err != nil { - log.Printf("warning: pricing sync failed: %v", err) + slog.Warn("pricing sync failed", "err", err) } cfg, err := config.Load(configPath) @@ -47,14 +53,46 @@ func runWrap(configPath string, args []string) int { cls := classifier.NewRuleBased() modelEntries, thresholds := router.BuildTaxonomy(cfg) + modelEntries = filterRegisteredModels(modelEntries, registry) + if 
len(modelEntries) == 0 { + fmt.Fprintln(os.Stderr, "frugal: no routable models available for registered providers") + return 1 + } rtr := router.New(modelEntries, thresholds) - h := proxy.NewHandler(cls, rtr, registry) + // Load use cases (same path resolution as serve mode; empty dir is + // allowed and silently disables use-case routing). + useCaseDir := os.Getenv("FRUGAL_USE_CASES_DIR") + if useCaseDir == "" { + useCaseDir = "config/use_cases" + } + useCases, err := usecase.Load(useCaseDir) + if err != nil { + fmt.Fprintf(os.Stderr, "frugal: failed to load use cases from %s: %v\n", useCaseDir, err) + return 1 + } + + h := proxy.NewHandlerWithUseCases(cls, rtr, registry, useCases) + + // Wrap mode always binds loopback, but shared machines still expose the + // port to any local user. Generate a one-shot bearer token per wrap, seal + // the proxy behind it, and hand the same token to the child via + // OPENAI_API_KEY so the SDK authenticates transparently. The user's real + // upstream key stays in Frugal's environment and never touches the child. 
+ authToken := os.Getenv("FRUGAL_AUTH_TOKEN") + if authToken == "" { + authToken = newSessionToken() + } r := chi.NewRouter() + r.Use(proxy.RequestIDMiddleware) + r.Use(proxy.RecoverMiddleware) + r.Use(proxy.AuthMiddleware(authToken)) r.Use(proxy.HeaderExtractionMiddleware) r.Post("/v1/chat/completions", h.ChatCompletions) r.Get("/v1/models", h.ListModels) r.Get("/v1/routing/explain", h.RoutingExplain) + r.Get("/v1/bundles", h.ListBundles) + r.Get("/v1/bundles/{useCase}", h.GetBundle) r.Get("/health", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) }) @@ -71,86 +109,138 @@ func runWrap(configPath string, args []string) int { // Start proxy in background server := &http.Server{Handler: r} go func() { - if err := server.Serve(listener); err != http.ErrServerClosed { - log.Printf("proxy error: %v", err) + if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + slog.Warn("proxy serve error", "err", err) } }() - // Wait for proxy to be ready - waitForReady(fmt.Sprintf("http://127.0.0.1:%d/health", port)) + // Wait for proxy to be ready (bounded; fail hard if unreachable). + if err := waitForReady(fmt.Sprintf("http://127.0.0.1:%d/health", port)); err != nil { + fmt.Fprintf(os.Stderr, "frugal: proxy did not become ready: %v\n", err) + _ = server.Close() + return 1 + } fmt.Fprintf(os.Stderr, "frugal: proxy running on :%d → routing across %d models\n", port, len(registry.AllModels())) - // Run the user's command with OPENAI_BASE_URL set + // Run the user's command with OPENAI_BASE_URL set. OPENAI_API_KEY is + // overwritten with the proxy's session token — the real upstream key + // stays in Frugal's environment only. cmd := exec.Command(args[0], args[1:]...) 
cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - cmd.Env = injectEnv(os.Environ(), baseURL) + cmd.Env = injectEnv(os.Environ(), baseURL, authToken) + + if err := cmd.Start(); err != nil { + fmt.Fprintf(os.Stderr, "frugal: failed to start command: %v\n", err) + _ = server.Close() + return 1 + } - // Forward signals to the child process + // Forward signals to the child. signal.Notify installed AFTER Start so + // signals never race a half-spawned process. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { for sig := range sigCh { - if cmd.Process != nil { - cmd.Process.Signal(sig) + if p := cmd.Process; p != nil { + _ = p.Signal(sig) } } }() exitCode := 0 - if err := cmd.Run(); err != nil { + if err := cmd.Wait(); err != nil { if exitErr, ok := err.(*exec.ExitError); ok { exitCode = exitErr.ExitCode() } else { - fmt.Fprintf(os.Stderr, "frugal: failed to run command: %v\n", err) + fmt.Fprintf(os.Stderr, "frugal: command exited with error: %v\n", err) exitCode = 1 } } - - // Shut down proxy - server.Close() + signal.Stop(sigCh) + close(sigCh) + + // Drain in-flight proxy requests before the wrapped command's exit code + // propagates back to the caller. 
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + slog.Warn("proxy shutdown returned error", "err", err) + } return exitCode } func registerProviders(cfg *config.Config, registry *provider.Registry) { if pc, ok := cfg.Providers["openai"]; ok { if key := os.Getenv(pc.APIKeyEnv); key != "" { - registry.Register(provopenai.New(key, pc.BaseURL, modelNames(pc))) + registry.Register(provider.WithRetry(provopenai.New(key, pc.BaseURL, modelNames(pc)))) } } if pc, ok := cfg.Providers["anthropic"]; ok { if key := os.Getenv(pc.APIKeyEnv); key != "" { - registry.Register(provanthropic.New(key, pc.BaseURL, modelNames(pc))) + registry.Register(provider.WithRetry(provanthropic.New(key, pc.BaseURL, modelNames(pc)))) } } if pc, ok := cfg.Providers["google"]; ok { if key := os.Getenv(pc.APIKeyEnv); key != "" { - registry.Register(provgoogle.New(key, pc.BaseURL, modelNames(pc))) + registry.Register(provider.WithRetry(provgoogle.New(key, pc.BaseURL, modelNames(pc)))) } } } -func injectEnv(environ []string, baseURL string) []string { - out := make([]string, 0, len(environ)+2) - for _, e := range environ { - // Don't override if user already set these - out = append(out, e) +func injectEnv(environ []string, baseURL, authToken string) []string { + out := make([]string, 0, len(environ)+3) + out = append(out, environ...) + out = upsertEnv(out, "OPENAI_BASE_URL", baseURL) + out = upsertEnv(out, "OPENAI_API_BASE", baseURL) // older Python SDK + if authToken != "" { + out = upsertEnv(out, "OPENAI_API_KEY", authToken) } - // Append — last value wins in most runtimes - out = append(out, "OPENAI_BASE_URL="+baseURL) - out = append(out, "OPENAI_API_BASE="+baseURL) // older Python SDK return out } -func waitForReady(url string) { - for i := 0; i < 50; i++ { - resp, err := http.Get(url) +// newSessionToken produces a random 128-bit bearer token in unpadded base32 +// for per-wrap auth. 
It is only ever passed in-process and to the child via +// environment, never logged. +func newSessionToken() string { + var buf [16]byte + if _, err := rand.Read(buf[:]); err != nil { + // crypto/rand failure is unrecoverable; log.Fatalf still works here + // because obs is initialised before runWrap. + log.Fatalf("frugal: failed to generate session token: %v", err) + } + return "frugal-" + base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(buf[:]) +} + +func upsertEnv(environ []string, key, value string) []string { + prefix := key + "=" + for i, e := range environ { + if len(e) >= len(prefix) && e[:len(prefix)] == prefix { + environ[i] = prefix + value + return environ + } + } + return append(environ, prefix+value) +} + +// waitForReady polls the proxy /health endpoint with a tight per-call timeout. +// A misconfigured HTTP proxy env var or bad DNS would otherwise stall on the +// default client's 30s-plus timeout; cap the overall wait at ~2s so we fail +// fast and the wrapped command never inherits a half-started proxy. 
+func waitForReady(url string) error { + client := &http.Client{Timeout: 50 * time.Millisecond} + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + resp, err := client.Get(url) if err == nil { - resp.Body.Close() - return + _ = resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } } - time.Sleep(10 * time.Millisecond) + time.Sleep(20 * time.Millisecond) } + return fmt.Errorf("proxy did not become ready within 2s") } diff --git a/cmd/frugal/wrap_test.go b/cmd/frugal/wrap_test.go new file mode 100644 index 0000000..d62fe9a --- /dev/null +++ b/cmd/frugal/wrap_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "strings" + "testing" +) + +func TestInjectEnv_OverridesExistingBaseURLsWithoutDuplicates(t *testing.T) { + env := []string{ + "PATH=/usr/bin", + "OPENAI_BASE_URL=http://old.local/v1", + "OPENAI_API_BASE=http://old.local/v1", + "OPENAI_API_KEY=sk-user-real-key", + } + + got := injectEnv(env, "http://127.0.0.1:8080/v1", "frugal-session-token") + + if countEnvKey(got, "OPENAI_BASE_URL") != 1 { + t.Fatalf("expected exactly one OPENAI_BASE_URL entry, got %d", countEnvKey(got, "OPENAI_BASE_URL")) + } + if countEnvKey(got, "OPENAI_API_BASE") != 1 { + t.Fatalf("expected exactly one OPENAI_API_BASE entry, got %d", countEnvKey(got, "OPENAI_API_BASE")) + } + if countEnvKey(got, "OPENAI_API_KEY") != 1 { + t.Fatalf("expected exactly one OPENAI_API_KEY entry, got %d", countEnvKey(got, "OPENAI_API_KEY")) + } + + if valueForEnvKey(got, "OPENAI_BASE_URL") != "http://127.0.0.1:8080/v1" { + t.Fatalf("OPENAI_BASE_URL not updated, got %q", valueForEnvKey(got, "OPENAI_BASE_URL")) + } + if valueForEnvKey(got, "OPENAI_API_BASE") != "http://127.0.0.1:8080/v1" { + t.Fatalf("OPENAI_API_BASE not updated, got %q", valueForEnvKey(got, "OPENAI_API_BASE")) + } + if valueForEnvKey(got, "OPENAI_API_KEY") != "frugal-session-token" { + t.Fatalf("OPENAI_API_KEY not replaced with session token, got %q", valueForEnvKey(got, "OPENAI_API_KEY")) + } 
+} + +func TestInjectEnv_AddsBaseURLsWhenMissing(t *testing.T) { + env := []string{"PATH=/usr/bin"} + got := injectEnv(env, "http://127.0.0.1:9090/v1", "") + + if valueForEnvKey(got, "OPENAI_BASE_URL") != "http://127.0.0.1:9090/v1" { + t.Fatalf("OPENAI_BASE_URL missing or incorrect") + } + if valueForEnvKey(got, "OPENAI_API_BASE") != "http://127.0.0.1:9090/v1" { + t.Fatalf("OPENAI_API_BASE missing or incorrect") + } + // Empty auth token means we must not inject a phantom OPENAI_API_KEY + // when the caller supplied one; callers in production always supply one. + if valueForEnvKey(got, "OPENAI_API_KEY") != "" { + t.Fatalf("OPENAI_API_KEY should not be set when authToken is empty") + } +} + +func countEnvKey(env []string, key string) int { + prefix := key + "=" + count := 0 + for _, e := range env { + if strings.HasPrefix(e, prefix) { + count++ + } + } + return count +} + +func valueForEnvKey(env []string, key string) string { + prefix := key + "=" + for _, e := range env { + if strings.HasPrefix(e, prefix) { + return strings.TrimPrefix(e, prefix) + } + } + return "" +} diff --git a/config/CAPABILITIES.md b/config/CAPABILITIES.md new file mode 100644 index 0000000..011e982 --- /dev/null +++ b/config/CAPABILITIES.md @@ -0,0 +1,57 @@ +# Capability scores + +The `reasoning`, `coding`, `creative`, and `instruction_following` fields in +`config/models.yaml` drive routing. This document explains how those numbers +are grounded and how to refresh them. + +## Sources + +| Axis | Source | Notes | +|---|---|---| +| `reasoning` | [LiveBench](https://livebench.ai/) `reasoning` + `mathematics` average | Refreshed monthly. Normalize each provider's raw score into `[0.0, 1.0]` by dividing by LiveBench's reported top-of-leaderboard score at the same cutoff. | +| `coding` | [Aider polyglot benchmark](https://aider.chat/docs/leaderboards/) | Use the `pass_rate` column. Normalize by dividing by the highest pass_rate in the snapshot. 
| +| `creative` | Aider side-benchmark qualitative tiering | Top-tier flagship models get 0.90 – 0.95; mid-tier 0.70 – 0.85; nano / haiku-class 0.55 – 0.70. Tier boundaries are judgment calls; re-review when a new frontier drops. | +| `instruction_following` | Aider `instruction_adherence` scaled to `[0.0, 1.0]` | Tracks how reliably the model follows structured prompts without going off-script. | + +Every `ModelConfig.capabilities` block should include both: + +```yaml +source: livebench+aider +as_of: 2026-04-15 +``` + +## Refresh protocol + +Run this on a reliable cadence (monthly, or when a new frontier model ships): + +1. Pull the latest LiveBench leaderboard JSON for the month. +2. Pull Aider's `leaderboard.yml` for the same cutoff. +3. For each model in `config/models.yaml`: + - Compute `reasoning` and `coding` from the normalized benchmark scores. + - Assign the `creative` and `instruction_following` tier by eye, using the tiering rubric above. +4. Update `source:` / `as_of:` for every model touched. +5. Open a PR titled `chore(capabilities): refresh scores (livebench YYYY-MM, aider YYYY-MM)` and include a table of deltas in the description. + +`frugal sync` **does not** mutate capability scores. It only refreshes pricing, +context length, tool-use, and JSON-mode flags from `models.dev`. The scores are +the editorial product of this project and should move through code review. + +## Why not LMArena Elo? + +LMArena produces a single composite preference score per model. Frugal routes +on four axes because "best model for coding" and "best model for creative" are +often different picks; collapsing into one number loses the signal that lets +Frugal drop Opus-tier cost for mini-tier cost on routine tasks. + +## Why not MMLU / GPQA? + +MMLU and GPQA are saturated at the frontier — modern flagship models all score +within a few points of each other, so the axis loses discrimination power. 
+LiveBench re-rolls its problems monthly, and Aider measures real edit-loop +coding accuracy rather than multiple-choice recall. + +## Changing sources + +If you switch to a new benchmark suite, do it for every model in one PR, bump +the `source:` field for each, and document the mapping in this file. Mixed +sources across models make routing decisions hard to justify in review. diff --git a/config/models.yaml b/config/models.yaml index f2dcc7b..da10f9f 100644 --- a/config/models.yaml +++ b/config/models.yaml @@ -1,3 +1,5 @@ +# Capability scores are refreshed editorially — not by `frugal sync`. +# Methodology and refresh protocol: config/CAPABILITIES.md. providers: anthropic: api_key_env: ANTHROPIC_API_KEY @@ -13,7 +15,10 @@ providers: instruction_following: 0.78 tool_use: true json_mode: false + vision: true max_context: 189096 + source: approximate-tier-assignment + as_of: "2026-04-21" claude-opus-4-20250918: cost_per_1k_input: 0.005 cost_per_1k_output: 0.025 @@ -24,7 +29,10 @@ providers: instruction_following: 0.97 tool_use: true json_mode: false + vision: true max_context: 200000 + source: approximate-tier-assignment + as_of: "2026-04-21" claude-sonnet-4-20250514: cost_per_1k_input: 0.003 cost_per_1k_output: 0.015 @@ -35,7 +43,10 @@ providers: instruction_following: 0.94 tool_use: true json_mode: false + vision: true max_context: 200000 + source: approximate-tier-assignment + as_of: "2026-04-21" google: api_key_env: GOOGLE_API_KEY base_url: https://generativelanguage.googleapis.com @@ -50,7 +61,10 @@ providers: instruction_following: 0.72 tool_use: true json_mode: true + vision: true max_context: 1048576 + source: approximate-tier-assignment + as_of: "2026-04-21" gemini-2.5-flash: cost_per_1k_input: 0.0003 cost_per_1k_output: 0.0025 @@ -61,7 +75,10 @@ providers: instruction_following: 0.74 tool_use: true json_mode: true + vision: true max_context: 1048576 + source: approximate-tier-assignment + as_of: "2026-04-21" gemini-2.5-pro: cost_per_1k_input: 0.00125 
cost_per_1k_output: 0.01 @@ -72,7 +89,10 @@ providers: instruction_following: 0.88 tool_use: true json_mode: true + vision: true max_context: 1048576 + source: approximate-tier-assignment + as_of: "2026-04-21" openai: api_key_env: OPENAI_API_KEY base_url: https://api.openai.com/v1 @@ -87,7 +107,10 @@ providers: instruction_following: 0.95 tool_use: true json_mode: true + vision: true max_context: 128000 + source: approximate-tier-assignment + as_of: "2026-04-21" gpt-4o-mini: cost_per_1k_input: 0.00015 cost_per_1k_output: 0.0006 @@ -98,7 +121,10 @@ providers: instruction_following: 0.8 tool_use: true json_mode: true + vision: true max_context: 128000 + source: approximate-tier-assignment + as_of: "2026-04-21" gpt-4.1: cost_per_1k_input: 0.002 cost_per_1k_output: 0.008 @@ -109,7 +135,10 @@ providers: instruction_following: 0.96 tool_use: true json_mode: true + vision: true max_context: 1047576 + source: approximate-tier-assignment + as_of: "2026-04-21" gpt-4.1-mini: cost_per_1k_input: 0.0004 cost_per_1k_output: 0.0016 @@ -120,7 +149,10 @@ providers: instruction_following: 0.82 tool_use: true json_mode: true + vision: true max_context: 1047576 + source: approximate-tier-assignment + as_of: "2026-04-21" gpt-4.1-nano: cost_per_1k_input: 0.0001 cost_per_1k_output: 0.0004 @@ -131,7 +163,10 @@ providers: instruction_following: 0.65 tool_use: true json_mode: true + vision: true max_context: 1047576 + source: approximate-tier-assignment + as_of: "2026-04-21" quality_thresholds: balanced: min_reasoning: 0.7 diff --git a/config/use_cases/code-dev.yaml b/config/use_cases/code-dev.yaml new file mode 100644 index 0000000..cdb0f7d --- /dev/null +++ b/config/use_cases/code-dev.yaml @@ -0,0 +1,40 @@ +id: code-dev +description: > + Code generation, debugging, code review, code explanation, and test + authoring. Chat-heavy workload that benefits from targeted doc/Stack + Overflow lookups via search but rarely needs reranking. 
The model's + coding-benchmark score (Aider, HumanEval-plus) drives routing more + than general reasoning. + +source: curated +as_of: "2026-04-21" +confidence: medium + +bundles: + high: + chat: gpt-4.1 + search: null # Ring 1b will route to Serper (Stack Overflow / docs) + rerank: null + reason: > + GPT-4.1 leads Aider-polyglot and holds up on long-context code + reading. Worth the premium on non-trivial refactors and debugging. + + balanced: + chat: gpt-4.1-mini + search: null + rerank: null + reason: > + 4.1-mini is the pragmatic default for day-to-day coding — strong + coding scores at mini-tier cost. Claude Sonnet 4 is comparable on + quality but measurably pricier per token. + + cost: + chat: gpt-4o-mini + search: null + rerank: null + reason: > + Good enough for single-file edits, linting-style fixes, and + formatting. Breaks down on multi-file reasoning; flag that to + callers via bundle `reason`. + +workload: config/workloads/use-cases/code-dev.yaml diff --git a/config/use_cases/factual-qa.yaml b/config/use_cases/factual-qa.yaml new file mode 100644 index 0000000..1620fbc --- /dev/null +++ b/config/use_cases/factual-qa.yaml @@ -0,0 +1,38 @@ +id: factual-qa +description: > + Short factual lookups, trivia, definitional questions, and one-turn + Q&A. The model mostly recalls or does shallow reasoning over a small + context. Expensive frontier models are wasted here — the bundle + deliberately picks mini/nano tiers. + +source: curated +as_of: "2026-04-21" +confidence: high + +bundles: + high: + chat: gpt-4o-mini + search: null # Ring 1b will route to a cheap SERP for temporal queries + rerank: null + reason: > + Even at the "high" tier this use case doesn't need frontier models. + 4o-mini scores within a couple of points of GPT-4o on TriviaQA at + a fraction of the cost. + + balanced: + chat: gpt-4.1-nano + search: null + rerank: null + reason: > + Nano is genuinely good on one-turn factual Q&A. Default for this + use case. 
+ + cost: + chat: gemini-2.0-flash + search: null + rerank: null + reason: > + Flash is the absolute floor — low per-token cost, adequate on + factual recall. + +workload: config/workloads/use-cases/factual-qa.yaml diff --git a/config/use_cases/research-synthesis.yaml b/config/use_cases/research-synthesis.yaml new file mode 100644 index 0000000..b79fa91 --- /dev/null +++ b/config/use_cases/research-synthesis.yaml @@ -0,0 +1,45 @@ +id: research-synthesis +description: > + Long-form research questions requiring multi-source synthesis. Examples: + "summarize recent advances in X", "compare architectural approaches across + three systems", literature reviews, and analyst reports. Typically needs + long-context reasoning plus retrieval and reranking to keep prompt costs + under control. + +# Curation is author-judgment today. `frugal bench --use-case research-synthesis` +# produces the data that will eventually flip `source` to `benchmark-derived`. +source: curated +as_of: "2026-04-21" +confidence: medium + +# Each bundle declares the recommended (capability → model) map for a quality +# tier. Ring 1a populates only `chat`. `search` lights up in Ring 1b (web +# search capability) and `rerank` in Ring 1c. Null fields are intentionally +# left in place so the schema is stable as capabilities expand. +bundles: + high: + chat: claude-opus-4-20250918 + search: null + rerank: null + reason: > + Opus handles long context and multi-source reasoning best; the cost + premium is justified when the caller's work hinges on synthesis quality. + + balanced: + chat: claude-sonnet-4-20250514 + search: null + rerank: null + reason: > + Sonnet 4 is ~80% of Opus quality at 40% of the cost on long-form + research tasks. Default choice for the use case. + + cost: + chat: gpt-4o-mini + search: null + rerank: null + reason: > + Budget tier. Accepts reduced breadth and less nuanced synthesis for + ~90% lower per-request cost. Best paired with a stricter rerank + pass once Ring 1c lands. 
+ +workload: config/workloads/use-cases/research-synthesis.yaml diff --git a/config/use_cases/structured-extraction.yaml b/config/use_cases/structured-extraction.yaml new file mode 100644 index 0000000..7cf35d9 --- /dev/null +++ b/config/use_cases/structured-extraction.yaml @@ -0,0 +1,42 @@ +id: structured-extraction +description: > + Free-text to JSON extraction, entity/field extraction, schema filling, + and format conversion. No toolchain needed — the bundle deliberately + leaves `search` and `rerank` null at every tier. The routing value comes + from picking the cheapest JSON-mode-capable model that reliably emits + schema-conformant output. + +source: curated +as_of: "2026-04-21" +confidence: high + +bundles: + high: + chat: gpt-4o-mini + search: null + rerank: null + reason: > + Even at the "high" tier this use case doesn't benefit from frontier + models — JSON-mode reliability saturates at mini-tier. 4o-mini is + the pick for schemas with nested/discriminated unions where + schema-conformant output matters more than tone. + + balanced: + chat: gemini-2.5-flash + search: null + rerank: null + reason: > + Gemini 2.5 Flash is reliable on JSON mode and a bit cheaper than + 4o-mini per token for the typical extraction request. Default + choice for this use case. + + cost: + chat: gemini-2.0-flash + search: null + rerank: null + reason: > + The absolute floor. Accepts occasional schema drift for the lowest + per-request cost; recommend retrying with `balanced` on + validation-failure. + +workload: config/workloads/use-cases/structured-extraction.yaml diff --git a/config/workloads/starter.yaml b/config/workloads/starter.yaml new file mode 100644 index 0000000..d76b128 --- /dev/null +++ b/config/workloads/starter.yaml @@ -0,0 +1,150 @@ +name: starter +description: > + 20-problem bootstrap benchmark spanning structured extraction, arithmetic, + classification, factual recall, and code explanation. 
All problems are + original and license-free so the workload can live in this repo without + derivative-work concerns. Reproduces in under $0.50 with a gpt-4o baseline. +baseline: gpt-4o + +problems: + # ----- structured extraction (JSON output, schema check) --------------- + - id: extract-email-01 + prompt: | + Extract the name and email as JSON with keys "name" and "email". + Input: "Please forward the invoice to Jane Ortiz at jane.ortiz@acme.co." + json_mode: true + max_tokens: 80 + expected_keys: [name, email] + + - id: extract-phone-01 + prompt: | + Return JSON with keys "contact" and "phone" for: "Call Raj at 415-555-0199 before 4pm." + json_mode: true + max_tokens: 80 + expected_keys: [contact, phone] + + - id: extract-order-01 + prompt: | + Return JSON with keys "order_id", "quantity", and "sku" for this line item: + "Order #A-774: 12 × SKU-PANTS-Navy-34R shipped 2026-02-10." + json_mode: true + max_tokens: 100 + expected_keys: [order_id, quantity, sku] + + - id: extract-date-01 + prompt: | + Extract the date in ISO 8601 form. Reply with JSON {"iso_date": "..."} only. + Input: "Ship on the third of March, twenty twenty-six." + json_mode: true + max_tokens: 60 + expected_keys: [iso_date] + + # ----- arithmetic (numeric, zero or tiny tolerance) -------------------- + - id: math-mul-01 + prompt: "What is 42 × 17? Give only the number." + max_tokens: 20 + expected_number: 714 + tolerance: 0 + + - id: math-div-01 + prompt: "What is 1764 ÷ 42? Give only the number." + max_tokens: 20 + expected_number: 42 + tolerance: 0 + + - id: math-percent-01 + prompt: "What is 18% of 250? Give only the number." + max_tokens: 20 + expected_number: 45 + tolerance: 0 + + - id: math-compound-01 + prompt: | + A basket costs $80 after a 20% discount. What was the original price? + Give only the number, no units. 
+ max_tokens: 20 + expected_number: 100 + tolerance: 0 + + # ----- classification (exact label match) ------------------------------ + - id: classify-sentiment-pos + system: | + Classify the sentiment of the review as exactly one word: positive, negative, or neutral. + Respond with only that single word. + prompt: "The pasta was perfectly al dente and the service felt unhurried." + max_tokens: 5 + expected_equals: positive + + - id: classify-sentiment-neg + system: | + Classify the sentiment of the review as exactly one word: positive, negative, or neutral. + Respond with only that single word. + prompt: "The battery lasted 90 minutes and the packaging arrived crushed." + max_tokens: 5 + expected_equals: negative + + - id: classify-intent-refund + system: | + Classify the customer intent as exactly one of: refund, shipping, technical, other. + Respond with only that single word. + prompt: "My package showed up but the headphones are dead on arrival — can I send them back for my money?" + max_tokens: 5 + expected_equals: refund + + - id: classify-intent-shipping + system: | + Classify the customer intent as exactly one of: refund, shipping, technical, other. + Respond with only that single word. + prompt: "Where is my order? The tracking number hasn't updated in four days." + max_tokens: 5 + expected_equals: shipping + + # ----- factual recall (substring) -------------------------------------- + - id: fact-capital-france + prompt: "What is the capital of France? Answer in one word." + max_tokens: 10 + expected_contains: Paris + + - id: fact-tallest-mountain + prompt: "What is the name of the tallest mountain on Earth above sea level?" + max_tokens: 20 + expected_contains: Everest + + - id: fact-sql-keyword + prompt: "Which SQL keyword removes duplicate rows from a SELECT result? Answer in one word." + max_tokens: 10 + expected_contains: DISTINCT + case_fold: true + + - id: fact-http-status + prompt: "Which HTTP status code indicates that a resource was not found?" 
+ max_tokens: 20 + expected_contains: "404" + + # ----- code explanation (contains_all keywords, case-fold) ------------- + - id: explain-quicksort + prompt: | + In two to four sentences, explain how quicksort works. + max_tokens: 220 + expected_contains_all: [pivot, partition] + case_fold: true + + - id: explain-big-o-sort + prompt: "What is the average-case time complexity of mergesort? Just the notation." + max_tokens: 20 + expected_contains: "n log n" + case_fold: true + + - id: explain-rest-vs-rpc + prompt: | + In three to five sentences, contrast REST and RPC as API styles. + max_tokens: 300 + expected_contains_all: [resource, verb] + case_fold: true + + - id: explain-db-index + prompt: | + In two sentences, explain why a B-tree index speeds up equality lookups. + max_tokens: 200 + expected_contains_all: [tree, balanced] + case_fold: true diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..ef02374 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,85 @@ +# frugal.sh landing page + +Static single-page site for `https://frugal.sh/`. Ships with one HTML file, one +CSS file, one SVG favicon, a `_redirects` rule that rewrites `/install` to the +installer script in this repo, and an `_headers` file that sets CSP + cache +headers. Total payload is under 30 KB. + +> The HTML, `_redirects`, and sitemap all point at +> `github.com/brainsparker/frugal` — the current repo slug. If you rename the +> repo to `frugalsh/frugal` (to match `install.sh`'s `REPO="frugalsh/frugal"` +> and the README's other references), do a find/replace across this directory +> before the next deploy. `install.sh` itself also needs its `REPO=` line +> updated so the API call for the latest release targets the right repo. + +Designed for Cloudflare Pages (primary) and GitHub Pages (fallback). 
+ +## Deploy on Cloudflare Pages (recommended) + +Cloudflare Pages is the right choice for `frugal.sh` because `.sh` domains +default to Cloudflare DNS anyway, and Pages natively supports the `/install` +200-rewrite via `_redirects` — critical for `curl … | sh` to work without a +3xx hop. + +1. Push `docs/` to GitHub on the `main` branch. +2. Cloudflare dashboard → **Workers & Pages → Create → Connect to Git**. +3. Select this repo. Configure: + - **Production branch:** `main` + - **Build command:** *(leave empty — no build step)* + - **Build output directory:** `docs` + - **Root directory:** `/` (default) +4. Deploy. The first build publishes to `.pages.dev`. +5. Back in the project → **Custom domains → Set up a custom domain** → + `frugal.sh` and also `www.frugal.sh`. Cloudflare will add the CNAME + records automatically because the zone is already on Cloudflare. +6. After DNS propagates, verify: + - `curl -sS https://frugal.sh/ | head` returns the HTML + - `curl -sSL https://frugal.sh/install | head` returns `#!/usr/bin/env bash` + - `curl -I https://frugal.sh/install` shows + `content-type: text/x-shellscript` and a 200 status (not a 301/302) + +The `_headers` file pins a strict CSP, `HSTS` with preload, and short +cache on the install script so operators always get the latest checksum- +verified bytes. + +## Deploy on GitHub Pages (fallback) + +Use this if you'd rather not bring Cloudflare Pages into the mix. Downsides: +GH Pages doesn't support the `/install` rewrite cleanly, and the `_headers` +file is ignored. You'd need to either commit the install script into this +directory (as `install.sh`) or accept that users would pull it from a +redirect to `raw.githubusercontent.com`. + +1. Repo → **Settings → Pages**. +2. Source: **Deploy from a branch**. Branch: `main`, folder: `/docs`. +3. Custom domain: `frugal.sh`. Save; check **Enforce HTTPS** once the cert + provisions. +4. 
Copy `../install.sh` into `docs/install.sh` (one-off; tag-updated) so + `https://frugal.sh/install.sh` works. If you want the shorter + `/install` URL you'd need to add a meta-refresh HTML stub, which doesn't + pipe cleanly into `sh`. For that reason, Cloudflare Pages is preferred. + +## Local preview + +Any static server works. Built-in Python is fine: + +```bash +cd docs +python3 -m http.server 8000 +# open http://localhost:8000/ +``` + +## Editing + +- `index.html` — content. One file so the whole page is easy to review. +- `styles.css` — tokens at the top of the file (`:root`), dark-mode default + with a `prefers-color-scheme: light` override. Change the accent by + editing `--accent`. +- `favicon.svg` — rounded square with a monospace "f". Change the fill to + retheme. +- `_redirects` — add convenience redirects here (e.g. `/talk-2026 → …`). +- `_headers` — cache and CSP policy. + +Keep this page command-first. If you find yourself reaching for a third-party +font or an analytics script, push back — the page's value prop is that it +loads before the user has time to regret clicking. diff --git a/docs/_headers b/docs/_headers new file mode 100644 index 0000000..abcfde3 --- /dev/null +++ b/docs/_headers @@ -0,0 +1,29 @@ +# Security and performance headers for every response Cloudflare Pages serves. +# Tightened CSP because the site is static with one inline script block (the +# copy-to-clipboard helper) and no third-party assets. 
+ +/* + X-Content-Type-Options: nosniff + X-Frame-Options: DENY + Referrer-Policy: strict-origin-when-cross-origin + Permissions-Policy: accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=() + Strict-Transport-Security: max-age=63072000; includeSubDomains; preload + Content-Security-Policy: default-src 'self'; img-src 'self' data:; style-src 'self'; script-src 'self' 'unsafe-inline'; base-uri 'self'; form-action 'self'; frame-ancestors 'none'; object-src 'none' + +# The install script itself should never be cached stale — people expect a +# fresh version on curl-pipe-sh. +/install + Cache-Control: public, max-age=300, must-revalidate + Content-Type: text/x-shellscript; charset=utf-8 + +/install.sh + Cache-Control: public, max-age=300, must-revalidate + Content-Type: text/x-shellscript; charset=utf-8 + +# Long-cache the stylesheet + favicon since they're content-addressed by +# filename in index.html; bump filenames if you ever change contents. +/styles.css + Cache-Control: public, max-age=3600, stale-while-revalidate=86400 + +/favicon.svg + Cache-Control: public, max-age=604800 diff --git a/docs/_redirects b/docs/_redirects new file mode 100644 index 0000000..c8008b0 --- /dev/null +++ b/docs/_redirects @@ -0,0 +1,14 @@ +# Serve the install script at https://frugal.sh/install. The script lives at +# docs/install.sh so Pages serves it directly; we rewrite /install to +# /install.sh so `curl -fsSL https://frugal.sh/install | sh` gets the script +# bytes with a 200 (not a 302 that would surface an uglier URL on slow links). +# +# An external-URL rewrite can't substitute here — the repo is private, so +# raw.githubusercontent.com would 404 for anonymous curl. +/install /install.sh 200 + +# Convenience redirects so links from talks/blogs land on the canonical docs. 
+/github https://github.com/brainsparker/frugal 301 +/docs https://github.com/brainsparker/frugal#readme 301 +/security https://github.com/brainsparker/frugal/blob/main/SECURITY.md 301 +/license https://github.com/brainsparker/frugal/blob/main/LICENSE 301 diff --git a/docs/favicon.svg b/docs/favicon.svg new file mode 100644 index 0000000..6a39e63 --- /dev/null +++ b/docs/favicon.svg @@ -0,0 +1,6 @@ + diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..de8919b --- /dev/null +++ b/docs/index.html @@ -0,0 +1,382 @@ + + + + + + frugal — route every prompt to the cheapest model + toolchain bundle + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + frugal + + +
+ +
+
+ +

+ Route every prompt to the cheapest model + toolchain + bundle that won't compromise quality. +

+

+ You write a prompt. Frugal classifies it, picks the cheapest (model + toolchain) + bundle that clears the quality bar, and sends it there. Research, code, marketing, + extraction — whatever the job — every prompt gets routed through the same proxy. + No account. No code changes. One command. +

+ +
+ +
$ curl -fsSL https://frugal.sh/install | bash
+ +
+ +
+
$ frugal python my_app.py
+# frugal starts a proxy, sets OPENAI_BASE_URL, runs your command,
+# and shuts down when it exits. Your app doesn't change.
+ +
+ +
+ No account + No code changes + Keys stay local + Signed releases +
+ + +
+ +
+

The insight

+

+ You're paying for capability you don't use on 60–80% of your AI calls — + and when you do need capability, a bare chat model is rarely the answer. + Legal research needs a strong reasoner paired with good web search. A code-dev loop + needs Claude paired with code-aware retrieval. Structured extraction needs a cheap + mini-tier model with nothing else. Frugal classifies each request by use case and + routes to the cheapest (model + toolchain) bundle that clears the quality bar.

+
+
+
Use-case classifier
+
Detects the use case — research synthesis, code dev, factual Q&A, + structured extraction, and more — plus the capabilities required + (chat, search, rerank, extract, browse).
+
+
+
Bundle router
+
Per use case and quality tier, picks the cheapest + (model + toolchain) bundle that clears every threshold. + One proxy, one service, one routing decision.
+
+
+
Eval as source of truth
+
Use-case benchmarks decide which bundles win. Routing isn't opinion — + it's what the data says wins for your workload, rerun on every release.
+
+
+
+ +
+

How it works

+
frugal python app.py
+       │
+       ├─ starts proxy on a free port
+       ├─ injects OPENAI_BASE_URL into your command's environment
+       ├─ classifies each request (use case + required capabilities)
+       ├─ picks the cheapest (model + toolchain) bundle that clears the bar
+       ├─ calls bundled tools — web search, rerank, extract, browse — as needed
+       └─ shuts down proxy when your command exits
+
+ +
+

Works with any OpenAI-compatible SDK

+

+ Frugal speaks the OpenAI chat-completions API. Point your existing SDK at the + proxy and nothing else changes. Toolchain capabilities — search, rerank, + extract, browse — are exposed on the same proxy under their own endpoints, + so your code stays in control and Frugal routes each capability to the + cheapest provider that clears the bar for your use case. +

+
+
+
Python
+
# unchanged from before frugal
+from openai import OpenAI
+client = OpenAI()
+resp = client.chat.completions.create(
+    model="auto",
+    messages=[{"role": "user", "content": "hi"}],
+)
+
+
+
Node
+
// unchanged from before frugal
+import OpenAI from "openai";
+const client = new OpenAI();
+const resp = await client.chat.completions.create({
+  model: "auto",
+  messages: [{ role: "user", content: "hi" }],
+});
+
+
+
curl
+
curl "$OPENAI_BASE_URL/chat/completions" \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $OPENAI_API_KEY" \
+  -d '{"model":"auto","messages":[{"role":"user","content":"hi"}]}'
+
+
+
Go
+
cfg := openai.DefaultConfig(os.Getenv("OPENAI_API_KEY"))
+cfg.BaseURL = os.Getenv("OPENAI_BASE_URL")
+client := openai.NewClientWithConfig(cfg)
+resp, _ := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
+    Model: "auto",
+    Messages: []openai.ChatCompletionMessage{{
+        Role: openai.ChatMessageRoleUser, Content: "hi",
+    }},
+})
+
+
+
+ +
+

Three quality tiers — pick per request

+

Send X-Frugal-Quality: cost for a quick extraction. + Send high for your agent's planner step. Default balanced for everything else.

+
+
+

high

+

Top-tier models only. Use for planners, complex reasoning, novel code.

+ X-Frugal-Quality: high +
+ +
+

cost

+

Cheapest viable model. Classification, extraction, simple summaries.

+ X-Frugal-Quality: cost +
+
+
+ +
+

What routes today

+

Model pricing synced from models.dev on every startup. + Extend by editing ~/.frugal/config/models.yaml. Toolchain capabilities + roll out one at a time, each gated by the eval harness — web search and reranking + first, content extraction and browser next.

+ +

Model providers

+
+
+

OpenAI

+

GPT-4o · GPT-4o-mini · GPT-4.1 · GPT-4.1-mini · GPT-4.1-nano

+
+
+

Anthropic

+

Claude Opus 4 · Claude Sonnet 4 · Claude Haiku 3.5

+
+
+

Google

+

Gemini 2.5 Pro · Gemini 2.5 Flash · Gemini 2.0 Flash

+
+
+ +

Toolchain capabilities

+
+
+

Web search next

+

Route each search to the cheapest provider that clears recall & freshness for the use case.

+
+
+

Reranking next

+

Cross-encoder rerank of retrieved passages. Ships after web search.

+
+
+

Content extraction roadmap

+

Clean text & structured content from URLs. Bundled with search for factual Q&A.

+
+
+

Browser use roadmap

+

Headless browsing for interactive sites. Bundled for use cases search can't reach.

+
+
+
+ +
+

Routed per use case

+

+ A use case is a named class of work. Each one ships with a labeled benchmark workload + and a curated bundle — the (capability → provider) map the eval says wins. Call + /v1/bundles/{use-case} to see what's picked and why. +

+
+
+

research-synthesis

+

Chat + web search + rerank. Strong reasoner, broad recall, tight context.

+
+
+

code-dev

+

Chat with code-aware retrieval. Pairs a strong coder model with targeted search.

+
+
+

factual-qa

+

Chat + search + extraction. Cheapest model that holds grounding on the benchmark wins.

+
+
+

structured-extraction

+

Chat only, JSON-mode. No toolchain needed — the smallest model that passes the schema.

+
+
+
+ +
+

Install

+

One binary, ~10 MB. Detects your API keys. Adds frugal to + your PATH. Release artifacts are signed with cosign; the installer + verifies the checksum — and the signature, when cosign is installed — + before moving the binary into place.

+ +
+
$ curl -fsSL https://frugal.sh/install | bash
+ +
+ +

From source

+
+
$ git clone https://github.com/brainsparker/frugal.git && cd frugal && make build
+
+ +

Run as a server

+
+
$ export FRUGAL_AUTH_TOKEN=$(openssl rand -hex 16)
+$ frugal serve
+$ export OPENAI_BASE_URL=http://localhost:8080/v1
+
+ +

+ When binding outside 127.0.0.1, Frugal refuses to start without an + auth token (override with FRUGAL_ALLOW_UNAUTH=1). Prometheus metrics + are served at /metrics. Full env-var reference in the + README. +

+
+ +
+

Built to self-host

+
+
+
Your keys stay local
+
Provider API keys never leave your machine. Frugal reads them from the + environment and forwards requests upstream; there is no control plane.
+
+
+
Signed releases
+
Every binary is cosign-signed (keyless, GitHub OIDC). The installer + verifies the checksum and, when cosign is present, the signature + before executing anything.
+
+
+
Observable by default
+
Structured logs, X-Request-ID propagation, Prometheus + /metrics, and /v1/routing/explain to see exactly + which bundle was picked — model, toolchain, and why.
+
+
+
Source-available, BUSL 1.1
+
Self-hosted and internal commercial use is permitted. Each version + converts to Apache 2.0 four years after release.
+
+
+
+
+ + + + + + diff --git a/docs/install.sh b/docs/install.sh new file mode 100755 index 0000000..ac99d63 --- /dev/null +++ b/docs/install.sh @@ -0,0 +1,372 @@ +#!/usr/bin/env bash +# +# frugal.sh installer +# Usage: +# curl -fsSL https://frugal.sh/install | bash +# curl -fsSL https://frugal.sh/install | bash -s uninstall +# +# Pipe to `bash`, not `sh`. On Ubuntu/Debian /bin/sh is dash and doesn't +# support `set -o pipefail` or other bash-isms used below. The shebang is +# ignored when the script is streamed from stdin, so the interpreter is +# whatever you pipe to. +# +# Env vars: +# FRUGAL_VERSION Pin a specific release tag (e.g. v0.1.0). Default: latest. +# FRUGAL_INSTALL_DIR Install root. Default: $HOME/.frugal +# FRUGAL_YES Non-interactive. Skips the confirmation prompt. +# GITHUB_TOKEN Optional. When set, the releases/latest API call is +# authenticated (5000/hr cap instead of 60/hr). Useful +# in CI where runner IPs are shared. +# +# Exit codes: +# 0 success +# 2 unsupported platform +# 3 network / upstream failure +# 4 verification (checksum or signature) failed +# 5 local state / user-aborted + +set -euo pipefail + +readonly EXIT_UNSUPPORTED=2 +readonly EXIT_NETWORK=3 +readonly EXIT_VERIFY=4 +readonly EXIT_LOCAL=5 + +readonly REPO="brainsparker/frugal" +readonly PINNED_VERSION="${FRUGAL_VERSION:-}" +readonly INSTALL_DIR="${FRUGAL_INSTALL_DIR:-$HOME/.frugal}" +readonly BIN_DIR="$INSTALL_DIR/bin" +readonly CONFIG_DIR="$INSTALL_DIR/config" + +# Exact-match markers for the shell rc block. Uninstall deletes everything +# between (and including) these lines. Do not change these strings without +# considering existing users — the uninstall path depends on matching them. 
+readonly RC_BEGIN="# >>> frugal.sh >>>" +readonly RC_END="# <<< frugal.sh <<<" + +# ---- UI ---- + +info() { printf "\033[1;34m==>\033[0m %s\n" "$1"; } +ok() { printf "\033[1;32m ✓\033[0m %s\n" "$1"; } +warn() { printf "\033[1;33m !\033[0m %s\n" "$1"; } +fail() { printf "\033[1;31m ✗\033[0m %s\n" "$1" >&2; exit "${2:-1}"; } + +# ---- Platform detection ---- + +detect_platform() { + local os arch + os="$(uname -s | tr '[:upper:]' '[:lower:]')" + arch="$(uname -m)" + + case "$arch" in + x86_64|amd64) arch="amd64" ;; + arm64|aarch64) arch="arm64" ;; + *) fail "unsupported architecture: $arch" "$EXIT_UNSUPPORTED" ;; + esac + + case "$os" in + linux) echo "linux-${arch}" ;; + darwin) echo "darwin-${arch}" ;; + *) fail "unsupported OS: $os (supported: macOS, Linux)" "$EXIT_UNSUPPORTED" ;; + esac +} + +# ---- Network helpers ---- + +http_get() { + # Fetch URL to stdout. Loudly on any non-2xx or connection error. + local url="$1" + if command -v curl >/dev/null 2>&1; then + curl -fsSL "$url" || fail "failed to fetch $url" "$EXIT_NETWORK" + elif command -v wget >/dev/null 2>&1; then + wget -qO- "$url" || fail "failed to fetch $url" "$EXIT_NETWORK" + else + fail "curl or wget is required" "$EXIT_NETWORK" + fi +} + +http_download() { + local url="$1" dest="$2" + if command -v curl >/dev/null 2>&1; then + curl -fsSL "$url" -o "$dest" || fail "failed to download $url" "$EXIT_NETWORK" + elif command -v wget >/dev/null 2>&1; then + wget -qO "$dest" "$url" || fail "failed to download $url" "$EXIT_NETWORK" + else + fail "curl or wget is required" "$EXIT_NETWORK" + fi +} + +# ---- Version resolution ---- + +resolve_version() { + if [ -n "$PINNED_VERSION" ]; then + echo "$PINNED_VERSION" + return + fi + local api_url="https://api.github.com/repos/${REPO}/releases/latest" + local json tag + # Unauthenticated GitHub API requests are rate-limited to 60/hour per IP. + # Shared CI runner pools blow that cap easily. 
Honour GITHUB_TOKEN when + # present (bumps the cap to 5000/hour) — silent no-op for end users. + if [ -n "${GITHUB_TOKEN:-}" ] && command -v curl >/dev/null 2>&1; then + json="$(curl -fsSL -H "Authorization: Bearer $GITHUB_TOKEN" "$api_url")" \ + || fail "failed to fetch $api_url" "$EXIT_NETWORK" + else + json="$(http_get "$api_url")" + fi + if command -v jq >/dev/null 2>&1; then + tag="$(printf '%s' "$json" | jq -r '.tag_name // empty')" + else + # Strict anchored regex; fails loudly if the JSON shape shifts so we + # never silently install the wrong version. + tag="$(printf '%s' "$json" | sed -nE 's/.*"tag_name"[[:space:]]*:[[:space:]]*"([^"]+)".*/\1/p' | head -n1)" + fi + [ -n "$tag" ] || fail "could not resolve latest version (API response missing tag_name)" "$EXIT_NETWORK" + echo "$tag" +} + +# ---- Checksum ---- + +sha256_check() { + # Verify that a file matches its line in a SHA256SUMS file. + # Linux has sha256sum; macOS has shasum -a 256. Both accept -c on stdin. + local file="$1" sums="$2" base tool + base="$(basename "$file")" + + if command -v sha256sum >/dev/null 2>&1; then + tool=(sha256sum -c -) + elif command -v shasum >/dev/null 2>&1; then + tool=(shasum -a 256 -c -) + else + fail "no sha256 tool found (need sha256sum or shasum)" "$EXIT_VERIFY" + fi + + ( + cd "$(dirname "$file")" && + grep " ${base}\$" "$sums" | "${tool[@]}" >/dev/null + ) || fail "sha256 mismatch for $base" "$EXIT_VERIFY" +} + +# ---- Shell rc editing ---- + +detect_shell_rc() { + # Prefer the rc matching the current login shell. Falls back to the first + # existing rc file. Returns empty if none match (caller warns and skips). 
+ case "${SHELL:-}" in + */zsh) echo "$HOME/.zshrc"; return ;; + */bash) echo "$HOME/.bashrc"; return ;; + esac + for rc in "$HOME/.zshrc" "$HOME/.bashrc" "$HOME/.bash_profile"; do + [ -f "$rc" ] && { echo "$rc"; return; } + done +} + +remove_rc_block() { + local rc="$1" + [ -f "$rc" ] || return 0 + grep -qxF "$RC_BEGIN" "$rc" || return 0 + # Portable block-delete: awk handles BSD/GNU sed differences. + awk -v b="$RC_BEGIN" -v e="$RC_END" ' + $0 == b { skip = 1; next } + $0 == e { skip = 0; next } + !skip + ' "$rc" > "$rc.frugal.tmp" && mv "$rc.frugal.tmp" "$rc" +} + +write_rc_block() { + local rc="$1" + remove_rc_block "$rc" + # >> creates the file if it doesn't exist (fresh-Mac with no ~/.zshrc case). + { + echo "" + echo "$RC_BEGIN" + echo "# Added by frugal.sh installer. Remove this block to uninstall PATH." + echo "export PATH=\"$BIN_DIR:\$PATH\"" + echo "export FRUGAL_CONFIG=\"$CONFIG_DIR/models.yaml\"" + echo "$RC_END" + } >> "$rc" +} + +# ---- Uninstall ---- + +uninstall() { + info "uninstalling frugal.sh" + + if [ -d "$INSTALL_DIR" ]; then + # Guardrail: only rm -rf paths that look like a Frugal install dir. + # Belt-and-suspenders against a mis-set FRUGAL_INSTALL_DIR. + case "$INSTALL_DIR" in + "$HOME/.frugal"|*/.frugal|*/frugal) + rm -rf "$INSTALL_DIR" + ok "removed $INSTALL_DIR" + ;; + *) + warn "refusing to remove unexpected INSTALL_DIR: $INSTALL_DIR" + warn "remove it by hand if you meant to" + ;; + esac + fi + + for rc in "$HOME/.zshrc" "$HOME/.bashrc" "$HOME/.bash_profile"; do + if [ -f "$rc" ] && grep -qxF "$RC_BEGIN" "$rc"; then + remove_rc_block "$rc" + ok "cleaned $rc" + fi + done + + echo + echo "frugal.sh uninstalled." + exit 0 +} + +# ---- Install ---- + +main() { + if [ "${1:-}" = "uninstall" ]; then + uninstall + fi + + info "installing frugal.sh — the open-source AI toolchain cost optimizer" + echo + + local platform version + platform="$(detect_platform)" + ok "detected platform: $platform" + + info "resolving version..." 
+ version="$(resolve_version)" + ok "target version: $version" + + local shell_config + shell_config="$(detect_shell_rc || true)" + + # Show what's about to happen. Interactive sessions get a prompt; + # FRUGAL_YES=1 and non-TTY runs (e.g. CI, curl-pipe-sh) skip it. + echo + echo "This installer will:" + echo " * install frugal $version to $BIN_DIR/frugal" + echo " * write a marker block to ${shell_config:-} for PATH + FRUGAL_CONFIG" + echo " * leave default config at $CONFIG_DIR/models.yaml" + if [ -t 0 ] && [ "${FRUGAL_YES:-}" != "1" ]; then + printf "Proceed? [Y/n] " + local answer + read -r answer SHA256SUMS -> binary hash -> binary. + # Cosign is preferred; when it's not installed we keep installing (don't + # block first-time users on a new dependency) but say so loudly. + if command -v cosign >/dev/null 2>&1; then + http_download "${base}/SHA256SUMS.sig" "$tmpdir/SHA256SUMS.sig" + cosign verify-blob \ + --bundle "$tmpdir/SHA256SUMS.sig" \ + --certificate-identity-regexp "https://github.com/${REPO}/.github/workflows/release.yml@refs/tags/" \ + --certificate-oidc-issuer https://token.actions.githubusercontent.com \ + "$tmpdir/SHA256SUMS" >/dev/null \ + || fail "cosign signature verification failed for SHA256SUMS" "$EXIT_VERIFY" + ok "cosign signature verified" + else + warn "cosign not found — signature check skipped" + warn "install cosign to enable: https://docs.sigstore.dev/cosign/installation/" + fi + + sha256_check "$tmpdir/$artifact" "$tmpdir/SHA256SUMS" + ok "checksum verified" + + # Atomic promotion: one mv, not a copy + chmod dance. + chmod +x "$tmpdir/$artifact" + mv "$tmpdir/$artifact" "$BIN_DIR/frugal" + ok "installed frugal $version to $BIN_DIR/frugal" + + # Default config: fetch only if missing so re-runs don't clobber edits. + if [ ! -f "$CONFIG_DIR/models.yaml" ]; then + info "downloading default model config..." 
+ http_download "https://raw.githubusercontent.com/${REPO}/main/config/models.yaml" \ + "$CONFIG_DIR/models.yaml" + ok "default config saved to $CONFIG_DIR/models.yaml" + else + ok "config already present at $CONFIG_DIR/models.yaml (kept)" + fi + + # Shell rc wiring. + if [ -n "$shell_config" ]; then + write_rc_block "$shell_config" + ok "shell config updated: $shell_config" + else + warn "no shell rc file found; add this to your shell profile:" + echo " export PATH=\"$BIN_DIR:\$PATH\"" + echo " export FRUGAL_CONFIG=\"$CONFIG_DIR/models.yaml\"" + fi + + # Export for this process so the smoke test below finds the binary. + export PATH="$BIN_DIR:$PATH" + export FRUGAL_CONFIG="$CONFIG_DIR/models.yaml" + + # Post-install smoke test: if --version doesn't respond, something's off + # even if every prior step reported success (corrupt file on disk, wrong + # arch artifact, exec bit stripped by a weird umask, etc). + if "$BIN_DIR/frugal" --version >/dev/null 2>&1; then + ok "smoke test: frugal --version OK" + else + fail "smoke test failed: $BIN_DIR/frugal --version did not exit cleanly" "$EXIT_VERIFY" + fi + + # Key detection — informational only. + echo + info "detecting provider API keys..." + local keys=0 + [ -n "${OPENAI_API_KEY:-}" ] && { ok "OPENAI_API_KEY found"; keys=$((keys + 1)); } + [ -n "${ANTHROPIC_API_KEY:-}" ] && { ok "ANTHROPIC_API_KEY found"; keys=$((keys + 1)); } + [ -n "${GOOGLE_API_KEY:-}" ] && { ok "GOOGLE_API_KEY found"; keys=$((keys + 1)); } + if [ "$keys" -eq 0 ]; then + warn "no provider API keys in environment" + echo " Set at least one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY" + fi + + echo + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo + echo " frugal.sh installed" + echo + echo " Start the proxy:" + echo " frugal" + echo + echo " Point your app at it:" + echo " export OPENAI_BASE_URL=http://localhost:8080/v1" + echo + echo " Route by use case:" + echo " curl -H 'X-Frugal-Use-Case: research-synthesis' ..." 
+ echo + echo " Uninstall:" + echo " curl -fsSL https://frugal.sh/install | bash -s uninstall" + echo + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + echo +} + +main "$@" diff --git a/docs/robots.txt b/docs/robots.txt new file mode 100644 index 0000000..c6a6176 --- /dev/null +++ b/docs/robots.txt @@ -0,0 +1,4 @@ +User-agent: * +Allow: / + +Sitemap: https://frugal.sh/sitemap.xml diff --git a/docs/sitemap.xml b/docs/sitemap.xml new file mode 100644 index 0000000..96646b7 --- /dev/null +++ b/docs/sitemap.xml @@ -0,0 +1,8 @@ + + + + https://frugal.sh/ + weekly + 1.0 + + diff --git a/docs/styles.css b/docs/styles.css new file mode 100644 index 0000000..4bbe7ba --- /dev/null +++ b/docs/styles.css @@ -0,0 +1,502 @@ +/* Reset + tokens --------------------------------------------------------- */ + +*, *::before, *::after { box-sizing: border-box; } + +html { -webkit-text-size-adjust: 100%; } + +body { + margin: 0; + font-family: ui-sans-serif, system-ui, -apple-system, "Segoe UI", Roboto, + "Helvetica Neue", Arial, sans-serif; + font-size: 16px; + line-height: 1.6; + color: var(--fg); + background: var(--bg); + -webkit-font-smoothing: antialiased; + text-rendering: optimizeLegibility; +} + +:root { + --bg: #0a0a0b; + --surface: #141416; + --surface-2: #1a1a1d; + --border: #2a2a2d; + --fg: #e7e7e9; + --dim: #9ca3af; + --accent: #10b981; + --accent-soft: rgba(16, 185, 129, 0.12); + --danger: #f87171; + --yellow: #facc15; + --max: 68rem; + --radius: 10px; + --mono: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Monaco, Consolas, monospace; +} + +@media (prefers-color-scheme: light) { + :root { + --bg: #fafafa; + --surface: #ffffff; + --surface-2: #f4f4f5; + --border: #e4e4e7; + --fg: #18181b; + --dim: #52525b; + --accent: #059669; + --accent-soft: rgba(5, 150, 105, 0.08); + --danger: #dc2626; + --yellow: #eab308; + } +} + +::selection { background: var(--accent); color: #0a0a0b; } + +/* Skip link --------------------------------------------------------------- */ + 
+.skip { + position: absolute; left: 0; top: -40px; + background: var(--accent); color: #0a0a0b; + padding: 8px 12px; border-radius: 0 0 var(--radius) 0; + text-decoration: none; font-weight: 600; +} +.skip:focus { top: 0; outline: none; } + +/* Top nav ----------------------------------------------------------------- */ + +.top { + display: flex; align-items: center; justify-content: space-between; + padding: 20px 24px; + max-width: var(--max); margin: 0 auto; + gap: 24px; +} +.brand { + display: inline-flex; align-items: center; gap: 10px; + color: var(--fg); text-decoration: none; + font-weight: 600; letter-spacing: -0.01em; +} +.brand-mark { + display: inline-flex; align-items: center; justify-content: center; + width: 28px; height: 28px; border-radius: 7px; + background: var(--accent); color: #0a0a0b; + font-family: var(--mono); font-weight: 700; font-size: 16px; + line-height: 1; +} +.brand-mark.small { width: 22px; height: 22px; font-size: 13px; border-radius: 6px; } +.brand-name { font-family: var(--mono); font-size: 17px; } + +.top-nav { display: flex; gap: 20px; font-size: 14px; } +.top-nav a { + color: var(--dim); text-decoration: none; + transition: color 120ms ease; +} +.top-nav a:hover, .top-nav a:focus-visible { color: var(--fg); outline: none; } + +@media (max-width: 600px) { + .top-nav a:nth-child(1), + .top-nav a:nth-child(2), + .top-nav a:nth-child(3) { display: none; } +} + +/* Main layout ------------------------------------------------------------- */ + +main { + max-width: var(--max); + margin: 0 auto; + padding: 24px; +} + +section { padding: 56px 0; border-top: 1px solid var(--border); } +section:first-of-type { border-top: none; padding-top: 28px; } + +h1, h2, h3 { + letter-spacing: -0.02em; + line-height: 1.2; + font-weight: 700; + margin: 0 0 16px; +} + +h1 { font-size: clamp(2rem, 4.5vw, 3.25rem); } +h2 { font-size: clamp(1.4rem, 2.4vw, 1.8rem); margin-bottom: 12px; } +h3 { font-size: 1.05rem; } + +.prose { + max-width: 58ch; + color: 
var(--dim); + font-size: 1.0625rem; + margin: 0 0 24px; +} +.prose strong { color: var(--fg); font-weight: 600; } +.prose code, code { + font-family: var(--mono); + font-size: 0.92em; + padding: 1px 6px; + background: var(--surface-2); + border: 1px solid var(--border); + border-radius: 5px; + color: var(--fg); +} +.prose a { color: var(--accent); text-decoration: none; border-bottom: 1px solid transparent; } +.prose a:hover { border-bottom-color: var(--accent); } + +/* Hero -------------------------------------------------------------------- */ + +.hero { padding-top: 24px; } + +.eyebrow { + font-family: var(--mono); + font-size: 12.5px; + color: var(--dim); + letter-spacing: 0.08em; + text-transform: uppercase; + margin: 0 0 20px; +} + +.hero h1 { + max-width: 24ch; + margin-bottom: 20px; + font-size: clamp(1.1rem, 2vw, 1.5rem); + font-weight: 500; + color: var(--dim); + line-height: 1.35; + letter-spacing: -0.005em; +} +.hero h1 .accent { + display: block; + font-size: clamp(2.25rem, 5.5vw, 3.75rem); + font-weight: 700; + line-height: 1.05; + letter-spacing: -0.025em; + margin: 6px 0; +} +.accent { color: var(--accent); } + +.lede { + font-size: 1.1rem; + max-width: 60ch; + color: var(--dim); + margin: 0 0 32px; +} + +.pills { + display: flex; flex-wrap: wrap; gap: 8px; + margin: 24px 0 28px; +} +.pills span { + font-family: var(--mono); + font-size: 12.5px; + color: var(--dim); + padding: 5px 10px; + border: 1px solid var(--border); + border-radius: 999px; + background: var(--surface); +} +.cta { display: flex; flex-wrap: wrap; gap: 12px; } + +.btn { + display: inline-flex; align-items: center; justify-content: center; + padding: 10px 18px; + border-radius: 8px; + font-size: 14px; + font-weight: 600; + text-decoration: none; + border: 1px solid transparent; + transition: transform 80ms ease, background 120ms ease, border-color 120ms ease; +} +.btn:active { transform: translateY(1px); } +.btn-primary { background: var(--accent); color: #0a0a0b; } 
+.btn-primary:hover { background: #0fae79; } +.btn-ghost { background: transparent; color: var(--fg); border-color: var(--border); } +.btn-ghost:hover { border-color: var(--dim); } + +/* Terminal block ---------------------------------------------------------- */ + +.terminal { + position: relative; + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + overflow: hidden; + margin: 12px 0; + box-shadow: 0 1px 0 rgba(255, 255, 255, 0.02) inset; +} +.terminal.secondary { background: var(--surface-2); } + +.terminal-chrome { + display: flex; align-items: center; gap: 8px; + padding: 10px 14px; + background: var(--surface-2); + border-bottom: 1px solid var(--border); +} +.terminal-chrome .dot { + width: 10px; height: 10px; border-radius: 50%; + background: var(--border); +} +.dot-r { background: #ef4444cc; } +.dot-y { background: #f59e0bcc; } +.dot-g { background: #22c55ecc; } +.terminal-title { + margin-left: 8px; + font-family: var(--mono); + font-size: 12px; + color: var(--dim); +} + +.terminal-body { + margin: 0; + padding: 18px 20px; + font-family: var(--mono); + font-size: 14.5px; + line-height: 1.55; + color: var(--fg); + overflow-x: auto; + white-space: pre; +} +.terminal-body code { background: none; border: none; padding: 0; font-size: inherit; } +.terminal.compact .terminal-body { padding: 14px 18px; font-size: 13.5px; } + +.prompt { color: var(--accent); margin-right: 10px; user-select: none; } +.muted { color: var(--dim); } + +.copy { + position: absolute; top: 10px; right: 10px; + background: var(--surface-2); + color: var(--dim); + border: 1px solid var(--border); + border-radius: 6px; + padding: 5px 10px; + font-family: var(--mono); + font-size: 12px; + cursor: pointer; + transition: color 120ms ease, border-color 120ms ease; +} +.terminal-chrome + .terminal-body + .copy, +.terminal-chrome ~ .copy { top: 50px; } +.copy:hover { color: var(--fg); border-color: var(--dim); } +.copy.copied { color: var(--accent); 
border-color: var(--accent); } + +/* Why grid ---------------------------------------------------------------- */ + +.grid { + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 20px; + margin: 28px 0 0; +} +.grid > div { + padding: 20px; + border: 1px solid var(--border); + border-radius: var(--radius); + background: var(--surface); +} +.grid dt { font-weight: 600; margin-bottom: 6px; font-size: 15px; } +.grid dd { margin: 0; color: var(--dim); font-size: 14.5px; } + +@media (max-width: 780px) { + .grid { grid-template-columns: 1fr; } +} + +.grid-trust { grid-template-columns: repeat(2, minmax(0, 1fr)); } +@media (max-width: 780px) { .grid-trust { grid-template-columns: 1fr; } } + +/* How-it-works ASCII ------------------------------------------------------ */ + +.diagram { + font-family: var(--mono); + font-size: 14px; + line-height: 1.7; + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius); + padding: 22px 24px; + overflow-x: auto; + color: var(--fg); + white-space: pre; +} + +/* Code grid --------------------------------------------------------------- */ + +.code-grid { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 16px; + margin-top: 20px; +} +@media (max-width: 780px) { .code-grid { grid-template-columns: 1fr; } } + +.code-grid figure { + margin: 0; + border: 1px solid var(--border); + border-radius: var(--radius); + background: var(--surface); + overflow: hidden; +} +.code-grid figcaption { + padding: 8px 16px; + font-family: var(--mono); + font-size: 12px; + color: var(--dim); + background: var(--surface-2); + border-bottom: 1px solid var(--border); + letter-spacing: 0.04em; + text-transform: uppercase; +} +.code-grid pre { + margin: 0; + padding: 16px 18px; + font-family: var(--mono); + font-size: 13.5px; + line-height: 1.6; + overflow-x: auto; + white-space: pre; +} +.code-grid code { background: none; border: none; padding: 0; font-size: inherit; } +.cmt { 
color: var(--dim); } +.str { color: var(--accent); } + +/* Quality tiers ----------------------------------------------------------- */ + +.tiers { + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 14px; + margin-top: 20px; +} +@media (max-width: 780px) { .tiers { grid-template-columns: 1fr; } } + +.tier { + padding: 20px; + border: 1px solid var(--border); + border-radius: var(--radius); + background: var(--surface); +} +.tier h3 { + font-family: var(--mono); + font-size: 15px; + color: var(--fg); + margin-bottom: 8px; + display: inline-flex; align-items: center; gap: 8px; +} +.tier p { color: var(--dim); font-size: 14.5px; margin: 0 0 12px; } +.tier code { display: inline-block; font-size: 12.5px; } + +.tier-featured { + border-color: var(--accent); + background: var(--accent-soft); +} +.default-tag { + font-family: var(--mono); + font-size: 11px; + color: #0a0a0b; + background: var(--accent); + padding: 2px 8px; + border-radius: 999px; + letter-spacing: 0.03em; + text-transform: uppercase; +} +@media (prefers-color-scheme: light) { + .default-tag { color: #fff; } +} + +/* Providers --------------------------------------------------------------- */ + +.provider-grid { + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 14px; + margin-top: 20px; +} +@media (max-width: 780px) { .provider-grid { grid-template-columns: 1fr; } } + +.provider { + padding: 18px; + border: 1px solid var(--border); + border-radius: var(--radius); + background: var(--surface); +} +.provider h3 { + margin-bottom: 6px; + display: inline-flex; align-items: center; gap: 8px; flex-wrap: wrap; +} +.provider p { margin: 0; color: var(--dim); font-size: 14px; font-family: var(--mono); } + +.provider-subhead { + margin-top: 28px; + font-family: var(--mono); + font-size: 13px; + color: var(--dim); + text-transform: uppercase; + letter-spacing: 0.06em; + font-weight: 600; +} + +.status-tag { + font-family: var(--mono); + font-size: 10.5px; + color: 
#0a0a0b; + background: var(--accent); + padding: 2px 8px; + border-radius: 999px; + letter-spacing: 0.04em; + text-transform: uppercase; + font-weight: 700; +} +.status-tag.tag-next { + background: transparent; + color: var(--dim); + border: 1px solid var(--border); +} +@media (prefers-color-scheme: light) { + .status-tag { color: #fff; } + .status-tag.tag-next { color: var(--dim); } +} + +/* Install block ----------------------------------------------------------- */ + +.install-block h3 { + margin-top: 28px; + font-family: var(--mono); + font-size: 13px; + color: var(--dim); + text-transform: uppercase; + letter-spacing: 0.06em; +} +.fine { + margin-top: 18px; + color: var(--dim); + font-size: 13.5px; +} +.fine a { color: var(--accent); text-decoration: none; border-bottom: 1px solid transparent; } +.fine a:hover { border-bottom-color: var(--accent); } + +/* Footer ------------------------------------------------------------------ */ + +.bottom { + border-top: 1px solid var(--border); + margin-top: 40px; + padding: 28px 24px; +} +.bottom-inner { + max-width: var(--max); + margin: 0 auto; + display: flex; align-items: center; justify-content: space-between; gap: 20px; + flex-wrap: wrap; + font-size: 14px; + color: var(--dim); +} +.bottom-inner > div { display: inline-flex; align-items: center; gap: 10px; } +.bottom nav { display: flex; flex-wrap: wrap; gap: 18px; } +.bottom a { color: var(--dim); text-decoration: none; } +.bottom a:hover { color: var(--fg); } + +/* Scrollbar polish (WebKit only; graceful skip elsewhere) ----------------- */ + +::-webkit-scrollbar { height: 10px; width: 10px; } +::-webkit-scrollbar-thumb { background: var(--border); border-radius: 10px; } +::-webkit-scrollbar-track { background: transparent; } + +/* Reduced motion ---------------------------------------------------------- */ + +@media (prefers-reduced-motion: reduce) { + *, *::before, *::after { + animation-duration: 0.01ms !important; + transition-duration: 0.01ms !important; + } 
+} diff --git a/fly.toml b/fly.toml index d421e21..94c6a18 100644 --- a/fly.toml +++ b/fly.toml @@ -3,6 +3,12 @@ primary_region = "ord" [build] +# FRUGAL_AUTH_TOKEN must be set via `fly secrets set FRUGAL_AUTH_TOKEN=...` +# before the first deploy. The binary refuses to start on a non-loopback bind +# unless either the token or FRUGAL_ALLOW_UNAUTH=1 is present. +[env] + FRUGAL_ADDR = "0.0.0.0:8080" + [http_service] internal_port = 8080 force_https = true diff --git a/go.mod b/go.mod index b331a6f..dcda363 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,23 @@ module github.com/frugalsh/frugal -go 1.26.1 +go 1.24.0 require ( github.com/go-chi/chi/v5 v5.2.5 + golang.org/x/time v0.9.0 gopkg.in/yaml.v3 v3.0.1 ) + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_golang v1.23.2 + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/sys v0.35.0 // indirect + google.golang.org/protobuf v1.36.8 // indirect +) diff --git a/go.sum b/go.sum index 0c0854f..4d0c19f 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,50 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= 
+github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/install.sh b/install.sh deleted file mode 100755 index f9c4a38..0000000 --- 
a/install.sh +++ /dev/null @@ -1,169 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -# frugal.sh installer -# Usage: curl -fsSL https://frugal.sh/install | sh - -REPO="frugalsh/frugal" -INSTALL_DIR="${FRUGAL_INSTALL_DIR:-$HOME/.frugal}" -BIN_DIR="$INSTALL_DIR/bin" -CONFIG_DIR="$INSTALL_DIR/config" - -# ---- helpers ---- - -info() { printf "\033[1;34m==>\033[0m %s\n" "$1"; } -ok() { printf "\033[1;32m ✓\033[0m %s\n" "$1"; } -warn() { printf "\033[1;33m !\033[0m %s\n" "$1"; } -fail() { printf "\033[1;31m ✗\033[0m %s\n" "$1" >&2; exit 1; } - -detect_platform() { - local os arch - os="$(uname -s | tr '[:upper:]' '[:lower:]')" - arch="$(uname -m)" - - case "$arch" in - x86_64|amd64) arch="amd64" ;; - arm64|aarch64) arch="arm64" ;; - *) fail "unsupported architecture: $arch" ;; - esac - - case "$os" in - linux) echo "linux-${arch}" ;; - darwin) echo "darwin-${arch}" ;; - *) fail "unsupported OS: $os" ;; - esac -} - -latest_version() { - if command -v curl &>/dev/null; then - curl -fsSL "https://api.github.com/repos/${REPO}/releases/latest" | grep '"tag_name"' | cut -d'"' -f4 - elif command -v wget &>/dev/null; then - wget -qO- "https://api.github.com/repos/${REPO}/releases/latest" | grep '"tag_name"' | cut -d'"' -f4 - else - fail "curl or wget required" - fi -} - -download() { - local url="$1" dest="$2" - if command -v curl &>/dev/null; then - curl -fsSL "$url" -o "$dest" - else - wget -qO "$dest" "$url" - fi -} - -# ---- install ---- - -main() { - info "installing frugal.sh — the open-source LLM cost optimizer" - echo - - # Detect platform - local platform - platform="$(detect_platform)" - ok "detected platform: $platform" - - # Get latest version (or build from source if no releases yet) - local version - version="$(latest_version 2>/dev/null || echo "")" - - mkdir -p "$BIN_DIR" "$CONFIG_DIR" - - if [ -n "$version" ]; then - info "downloading frugal $version for $platform..." 
- local url="https://github.com/${REPO}/releases/download/${version}/frugal-${platform}" - download "$url" "$BIN_DIR/frugal" - chmod +x "$BIN_DIR/frugal" - ok "downloaded frugal $version" - else - # No releases yet — build from source - info "no release found, building from source..." - if ! command -v go &>/dev/null; then - fail "go is required to build from source (install: https://go.dev/dl/)" - fi - - local tmpdir - tmpdir="$(mktemp -d)" - trap "rm -rf $tmpdir" EXIT - - if command -v git &>/dev/null; then - git clone --depth 1 "https://github.com/${REPO}.git" "$tmpdir/frugal" 2>/dev/null - else - download "https://github.com/${REPO}/archive/refs/heads/main.tar.gz" "$tmpdir/frugal.tar.gz" - tar -xzf "$tmpdir/frugal.tar.gz" -C "$tmpdir" - mv "$tmpdir/frugal-main" "$tmpdir/frugal" - fi - - (cd "$tmpdir/frugal" && go build -o "$BIN_DIR/frugal" ./cmd/frugal) - cp "$tmpdir/frugal/config/models.yaml" "$CONFIG_DIR/models.yaml" - ok "built frugal from source" - fi - - # Download default config if not present - if [ ! -f "$CONFIG_DIR/models.yaml" ]; then - info "downloading model config..." - download "https://raw.githubusercontent.com/${REPO}/main/config/models.yaml" "$CONFIG_DIR/models.yaml" - ok "config saved to $CONFIG_DIR/models.yaml" - fi - - echo - info "detecting API keys..." 
- - local keys_found=0 - [ -n "${OPENAI_API_KEY:-}" ] && { ok "OPENAI_API_KEY found"; keys_found=$((keys_found + 1)); } - [ -n "${ANTHROPIC_API_KEY:-}" ] && { ok "ANTHROPIC_API_KEY found"; keys_found=$((keys_found + 1)); } - [ -n "${GOOGLE_API_KEY:-}" ] && { ok "GOOGLE_API_KEY found"; keys_found=$((keys_found + 1)); } - - if [ "$keys_found" -eq 0 ]; then - warn "no API keys found in environment" - echo " Set at least one of: OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY" - echo " Then run: frugal" - echo - fi - - # Add to PATH - local shell_config="" - local export_line="export PATH=\"$BIN_DIR:\$PATH\"" - local config_line="export FRUGAL_CONFIG=\"$CONFIG_DIR/models.yaml\"" - - if [ -f "$HOME/.zshrc" ]; then - shell_config="$HOME/.zshrc" - elif [ -f "$HOME/.bashrc" ]; then - shell_config="$HOME/.bashrc" - elif [ -f "$HOME/.bash_profile" ]; then - shell_config="$HOME/.bash_profile" - fi - - if [ -n "$shell_config" ]; then - if ! grep -q ".frugal/bin" "$shell_config" 2>/dev/null; then - echo "" >> "$shell_config" - echo "# frugal.sh" >> "$shell_config" - echo "$export_line" >> "$shell_config" - echo "$config_line" >> "$shell_config" - ok "added to PATH in $shell_config" - fi - fi - - # Also export for current session - export PATH="$BIN_DIR:$PATH" - export FRUGAL_CONFIG="$CONFIG_DIR/models.yaml" - - echo - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo - echo " frugal.sh installed!" - echo - echo " Start the proxy:" - echo " frugal" - echo - echo " Then point your app at it:" - echo " export OPENAI_BASE_URL=http://localhost:8080/v1" - echo - echo " That's it. Same code. Same SDK. Lower bill." 
- echo - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo -} - -main "$@" diff --git a/internal/classifier/classifier_test.go b/internal/classifier/classifier_test.go index e59d144..cd01aa0 100644 --- a/internal/classifier/classifier_test.go +++ b/internal/classifier/classifier_test.go @@ -41,6 +41,37 @@ func TestClassify_SimpleQuestion(t *testing.T) { } } +func TestClassify_MultimodalSetsRequiresVision(t *testing.T) { + c := NewRuleBased() + req := &types.ChatCompletionRequest{ + Messages: []types.Message{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"text","text":"describe"}, + {"type":"image_url","image_url":{"url":"data:image/png;base64,AAAA"}} + ]`)}, + }, + } + + f := c.Classify(req) + + if !f.RequiresVision { + t.Fatalf("expected RequiresVision=true for multimodal input") + } +} + +func TestClassify_PlainStringDoesNotRequireVision(t *testing.T) { + c := NewRuleBased() + req := &types.ChatCompletionRequest{ + Messages: []types.Message{msg("user", "hello")}, + } + + f := c.Classify(req) + + if f.RequiresVision { + t.Fatalf("expected RequiresVision=false for text-only input") + } +} + func TestClassify_CodeRequest(t *testing.T) { c := NewRuleBased() req := &types.ChatCompletionRequest{ @@ -77,6 +108,24 @@ func TestClassify_MathRequest(t *testing.T) { } } +func TestClassify_CaseInsensitiveKeywordDetection(t *testing.T) { + c := NewRuleBased() + req := &types.ChatCompletionRequest{ + Messages: []types.Message{ + msg("user", "Write a Function that solves this Equation"), + }, + } + + f := c.Classify(req) + + if !f.HasCode { + t.Error("expected HasCode=true for mixed-case coding keyword") + } + if !f.HasMath { + t.Error("expected HasMath=true for mixed-case math keyword") + } +} + func TestClassify_WithSystemPrompt(t *testing.T) { c := NewRuleBased() longSystem := "You are a helpful assistant. 
" // short diff --git a/internal/classifier/features.go b/internal/classifier/features.go index bf7be59..272b7a5 100644 --- a/internal/classifier/features.go +++ b/internal/classifier/features.go @@ -8,29 +8,45 @@ import ( ) var ( - codeBlockRe = regexp.MustCompile("(?s)```") - codeFuncRe = regexp.MustCompile(`\b(function|def|class|import|export|package|struct|interface|impl|fn|pub|const|let|var)\b`) + // Require a fenced block that spans at least 3 newlines before the + // terminator so backtick fragments in prose don't flag HasCode. + codeBlockRe = regexp.MustCompile("(?s)```[^`]{0,20}\n[^`]*\n[^`]*\n[^`]*```") + codeFuncRe = regexp.MustCompile(`(?i)\b(function|def|class|import|export|package|struct|interface|impl|fn|pub|const|let|var)\b`) mathLatexRe = regexp.MustCompile(`\$[^$]+\$|\\(begin|end)\{|\\frac|\\sum|\\int|\\sqrt`) - mathKeywordRe = regexp.MustCompile(`\b(equation|derivative|integral|matrix|eigenvalue|polynomial|theorem|proof|calculus)\b`) + mathKeywordRe = regexp.MustCompile(`(?i)\b(equation|derivative|integral|matrix|eigenvalue|polynomial|theorem|proof|calculus)\b`) ) +// perProviderCharsPerToken is a conservative divisor applied to the raw +// character count when estimating input tokens. OpenAI BPE averages ~4 for +// English text; Anthropic and Google tokenizers emit more fine-grained pieces, +// so we round down to 3.3 (1/0.3) to avoid under-estimating and selecting a +// model whose context window doesn't actually fit. +const anthropicCharsPerToken = 3.3 + func extractFeatures(req *types.ChatCompletionRequest) types.QueryFeatures { var f types.QueryFeatures allText := concatenateMessages(req.Messages) - // Token estimation (~4 chars per token) - f.EstimatedInputTokens = len(allText) / 4 - if f.EstimatedInputTokens < 1 { - f.EstimatedInputTokens = 1 + // Token estimation with a conservative lower bound. 
Classifier routes on + // the higher of two estimates so we never pick a model whose context + // window is tighter than the actual tokenized input would produce. + charsPerTokenUpper := 4.0 + est := int(float64(len(allText)) / anthropicCharsPerToken) + if v := len(allText) / int(charsPerTokenUpper); v > est { + est = v + } + if est < 1 { + est = 1 } + f.EstimatedInputTokens = est f.EstimatedOutputTokens = estimateOutputTokens(req) // System prompt analysis for _, msg := range req.Messages { if msg.Role == "system" { f.HasSystemPrompt = true - f.SystemPromptLength = len(msg.ContentString()) + f.SystemPromptLength = len(msg.ContentText()) break } } @@ -49,6 +65,16 @@ func extractFeatures(req *types.ChatCompletionRequest) types.QueryFeatures { // Output format requirements f.RequiresJSON = req.ResponseFormat != nil && req.ResponseFormat.Type == "json_object" f.RequiresToolUse = len(req.Tools) > 0 + f.RequiresMultipleCompletions = req.N != nil && *req.N > 1 + + // Vision: any message carrying non-text content (image_url, input_audio) + // forces the router to only consider vision-capable models. 
+ for _, msg := range req.Messages { + if msg.HasNonTextContent() { + f.RequiresVision = true + break + } + } // Domain hints f.DomainHints = detectDomains(allText) @@ -62,7 +88,7 @@ func extractFeatures(req *types.ChatCompletionRequest) types.QueryFeatures { func concatenateMessages(msgs []types.Message) string { var b strings.Builder for _, m := range msgs { - b.WriteString(m.ContentString()) + b.WriteString(m.ContentText()) b.WriteByte(' ') } return b.String() @@ -75,7 +101,7 @@ func estimateOutputTokens(req *types.ChatCompletionRequest) int { // Default estimate based on input size inputChars := 0 for _, m := range req.Messages { - inputChars += len(m.ContentString()) + inputChars += len(m.ContentText()) } est := inputChars / 4 // rough: output ~= input length if est < 100 { @@ -87,40 +113,33 @@ func estimateOutputTokens(req *types.ChatCompletionRequest) int { return est } +// Precompiled word-boundary regexes per domain. The previous implementation +// used substring matching which flagged "import" in "important", "api" in +// "therapist", and similar false positives. +var ( + codingDomainRe = regexp.MustCompile(`(?i)\b(code|function|bug|error|compile|debug|api|endpoint|database|sql|algorithm|programming)\b`) + creativeDomainRe = regexp.MustCompile(`(?i)\b(story|poem|creative|imagine|fiction|essay|blog)\b|write me\b`) + mathDomainRe = regexp.MustCompile(`(?i)\b(calculate|solve|equation|math|formula|compute|probability|statistics)\b`) +) + +// detectDomains flags the top-level topic of the query. The "analysis" domain +// was removed because its keyword set fired on nearly every request (summarize, +// explain, review, compare are boilerplate instructions), which made it +// useless as a routing signal. 
func detectDomains(text string) []string { - lower := strings.ToLower(text) var hints []string - - codingKeywords := []string{"code", "function", "bug", "error", "compile", "debug", "api", "endpoint", "database", "sql", "algorithm", "programming"} - creativeKeywords := []string{"story", "poem", "creative", "write me", "imagine", "fiction", "essay", "blog"} - analysisKeywords := []string{"analyze", "compare", "evaluate", "assess", "review", "summarize", "explain"} - mathKeywords := []string{"calculate", "solve", "equation", "math", "formula", "compute", "probability", "statistics"} - - if matchesAny(lower, codingKeywords) { + if codingDomainRe.MatchString(text) { hints = append(hints, "coding") } - if matchesAny(lower, creativeKeywords) { + if creativeDomainRe.MatchString(text) { hints = append(hints, "creative") } - if matchesAny(lower, analysisKeywords) { - hints = append(hints, "analysis") - } - if matchesAny(lower, mathKeywords) { + if mathDomainRe.MatchString(text) { hints = append(hints, "math") } - return hints } -func matchesAny(text string, keywords []string) bool { - for _, kw := range keywords { - if strings.Contains(text, kw) { - return true - } - } - return false -} - func computeComplexity(f types.QueryFeatures) float64 { score := 0.0 diff --git a/internal/config/config.go b/internal/config/config.go index 146008a..f39e49b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,9 @@ package config import ( + "bytes" "fmt" + "math" "os" "gopkg.in/yaml.v3" @@ -31,7 +33,14 @@ type CapabilityConfig struct { InstructionFollowing float64 `yaml:"instruction_following"` ToolUse bool `yaml:"tool_use"` JSONMode bool `yaml:"json_mode"` + Vision bool `yaml:"vision"` MaxContext int `yaml:"max_context"` + // Source names the benchmark suite these scores were derived from + // (e.g. "livebench+aider"). AsOf is an ISO-8601 date string so + // operators know when the scores were last refreshed. 
Routing + // decisions are only as defensible as these fields — keep them current. + Source string `yaml:"source,omitempty"` + AsOf string `yaml:"as_of,omitempty"` } type ThresholdConfig struct { @@ -53,9 +62,97 @@ func Load(path string) (*Config, error) { } var cfg Config - if err := yaml.Unmarshal(data, &cfg); err != nil { + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(true) + if err := dec.Decode(&cfg); err != nil { return nil, fmt.Errorf("parsing config: %w", err) } + if err := validate(&cfg); err != nil { + return nil, fmt.Errorf("validating config: %w", err) + } + return &cfg, nil } + +func validate(cfg *Config) error { + if len(cfg.Providers) == 0 { + return fmt.Errorf("providers must contain at least one provider") + } + + for providerName, provider := range cfg.Providers { + if provider.APIKeyEnv == "" { + return fmt.Errorf("providers.%s.api_key_env is required", providerName) + } + if len(provider.Models) == 0 { + return fmt.Errorf("providers.%s.models must contain at least one model", providerName) + } + + for modelName, model := range provider.Models { + if !isFiniteNonNegative(model.CostPer1KInput) { + return fmt.Errorf("providers.%s.models.%s.cost_per_1k_input must be a finite number >= 0", providerName, modelName) + } + if !isFiniteNonNegative(model.CostPer1KOutput) { + return fmt.Errorf("providers.%s.models.%s.cost_per_1k_output must be a finite number >= 0", providerName, modelName) + } + if err := validateCapabilityRange(providerName, modelName, "reasoning", model.Capabilities.Reasoning); err != nil { + return err + } + if err := validateCapabilityRange(providerName, modelName, "coding", model.Capabilities.Coding); err != nil { + return err + } + if err := validateCapabilityRange(providerName, modelName, "creative", model.Capabilities.Creative); err != nil { + return err + } + if err := validateCapabilityRange(providerName, modelName, "instruction_following", model.Capabilities.InstructionFollowing); err != nil { + return err + 
} + if model.Capabilities.MaxContext < 0 { + return fmt.Errorf("providers.%s.models.%s.capabilities.max_context must be >= 0", providerName, modelName) + } + } + } + + if len(cfg.QualityThresholds) == 0 { + return fmt.Errorf("quality_thresholds must contain at least one tier") + } + + for tier, threshold := range cfg.QualityThresholds { + if err := validateThresholdRange(tier, "min_reasoning", threshold.MinReasoning); err != nil { + return err + } + if err := validateThresholdRange(tier, "min_coding", threshold.MinCoding); err != nil { + return err + } + if err := validateThresholdRange(tier, "min_creative", threshold.MinCreative); err != nil { + return err + } + if err := validateThresholdRange(tier, "min_instruction_following", threshold.MinInstructionFollowing); err != nil { + return err + } + } + + return nil +} + +func isFiniteNonNegative(v float64) bool { + return !math.IsNaN(v) && !math.IsInf(v, 0) && v >= 0 +} + +func isFiniteProbability(v float64) bool { + return !math.IsNaN(v) && !math.IsInf(v, 0) && v >= 0 && v <= 1 +} + +func validateCapabilityRange(providerName, modelName, field string, value float64) error { + if !isFiniteProbability(value) { + return fmt.Errorf("providers.%s.models.%s.capabilities.%s must be between 0 and 1", providerName, modelName, field) + } + return nil +} + +func validateThresholdRange(tier, field string, value float64) error { + if !isFiniteProbability(value) { + return fmt.Errorf("quality_thresholds.%s.%s must be between 0 and 1", tier, field) + } + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 59f261f..54aca3b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,6 +3,7 @@ package config import ( "os" "path/filepath" + "strings" "testing" ) @@ -77,3 +78,148 @@ func TestLoad_MissingFile(t *testing.T) { t.Error("expected error for missing file") } } + +func TestLoad_RejectsUnknownFields(t *testing.T) { + content := ` +providers: + openai: + 
api_key_env: OPENAI_API_KEY + base_url: https://api.openai.com/v1 + models: + gpt-4o: + cost_per_1k_input: 0.0025 + cost_per_1k_output: 0.01 + capabilities: + reasoning: 0.95 + coding: 0.92 + creative: 0.90 + instruction_following: 0.95 + tool_use: true + json_mode: true + max_context: 128000 + typo_field: true +quality_thresholds: + balanced: + min_reasoning: 0.70 + min_coding: 0.68 + min_creative: 0.65 + min_instruction_following: 0.72 +` + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for unknown field") + } + if !strings.Contains(err.Error(), "typo_field") { + t.Fatalf("expected unknown field error to mention typo_field, got: %v", err) + } +} + +func TestLoad_RejectsMissingProviders(t *testing.T) { + content := ` +quality_thresholds: + balanced: + min_reasoning: 0.70 + min_coding: 0.68 + min_creative: 0.65 + min_instruction_following: 0.72 +` + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected validation error for missing providers") + } + if !strings.Contains(err.Error(), "providers must contain at least one provider") { + t.Fatalf("expected providers validation error, got: %v", err) + } +} + +func TestLoad_RejectsOutOfRangeCapability(t *testing.T) { + content := ` +providers: + openai: + api_key_env: OPENAI_API_KEY + base_url: https://api.openai.com/v1 + models: + gpt-4o: + cost_per_1k_input: 0.0025 + cost_per_1k_output: 0.01 + capabilities: + reasoning: 1.2 + coding: 0.92 + creative: 0.90 + instruction_following: 0.95 + tool_use: true + json_mode: true + max_context: 128000 +quality_thresholds: + balanced: + min_reasoning: 0.70 + min_coding: 0.68 + min_creative: 0.65 + min_instruction_following: 0.72 +` + dir := t.TempDir() + 
path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected validation error for capability range") + } + if !strings.Contains(err.Error(), "capabilities.reasoning") { + t.Fatalf("expected capability reasoning validation error, got: %v", err) + } +} + +func TestLoad_RejectsNegativeModelCost(t *testing.T) { + content := ` +providers: + openai: + api_key_env: OPENAI_API_KEY + base_url: https://api.openai.com/v1 + models: + gpt-4o: + cost_per_1k_input: -0.01 + cost_per_1k_output: 0.01 + capabilities: + reasoning: 0.95 + coding: 0.92 + creative: 0.90 + instruction_following: 0.95 + tool_use: true + json_mode: true + max_context: 128000 +quality_thresholds: + balanced: + min_reasoning: 0.70 + min_coding: 0.68 + min_creative: 0.65 + min_instruction_following: 0.72 +` + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected validation error for negative model cost") + } + if !strings.Contains(err.Error(), "cost_per_1k_input") { + t.Fatalf("expected cost_per_1k_input validation error, got: %v", err) + } +} diff --git a/internal/eval/eval.go b/internal/eval/eval.go new file mode 100644 index 0000000..4e23379 --- /dev/null +++ b/internal/eval/eval.go @@ -0,0 +1,79 @@ +// Package eval runs simulation-only evaluations of Frugal's routing decisions +// against a baseline model, so savings claims can be reproduced without spending +// real API budget. +package eval + +import ( + "github.com/frugalsh/frugal/internal/classifier" + "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/types" +) + +// Result is the outcome of routing one query through the eval harness. 
+type Result struct { + Query Query + Decision types.RoutingDecision + BaselineCost float64 + FrugalCost float64 + SavingsPct float64 +} + +// Summary aggregates results across one workload run. +type Summary struct { + Workload string + Quality types.QualityThreshold + BaselineModel string + QueryCount int + TotalBaseline float64 + TotalFrugal float64 + SavingsPct float64 + Results []Result +} + +// Runner evaluates workloads against a router + classifier, compared to a +// single baseline model (e.g. "gpt-4o") that represents "without Frugal" cost. +type Runner struct { + Router *router.Router + Classifier classifier.Classifier + BaselineModel string + BaselineCPKIn float64 + BaselineCPKOut float64 +} + +// Run evaluates every query in the workload at the given quality threshold +// and returns a populated Summary. +func (r *Runner) Run(w Workload, quality types.QualityThreshold) Summary { + s := Summary{ + Workload: w.Name, + Quality: quality, + BaselineModel: r.BaselineModel, + QueryCount: len(w.Queries), + } + for _, q := range w.Queries { + features := r.Classifier.Classify(q.Request) + decision := r.Router.Route(features, quality, nil) + frugalCost := decision.EstimatedCost + baselineCost := baselineCostFor(features, r.BaselineCPKIn, r.BaselineCPKOut) + var savings float64 + if baselineCost > 0 { + savings = (baselineCost - frugalCost) / baselineCost * 100 + } + s.Results = append(s.Results, Result{ + Query: q, + Decision: decision, + BaselineCost: baselineCost, + FrugalCost: frugalCost, + SavingsPct: savings, + }) + s.TotalBaseline += baselineCost + s.TotalFrugal += frugalCost + } + if s.TotalBaseline > 0 { + s.SavingsPct = (s.TotalBaseline - s.TotalFrugal) / s.TotalBaseline * 100 + } + return s +} + +func baselineCostFor(f types.QueryFeatures, cpkIn, cpkOut float64) float64 { + return float64(f.EstimatedInputTokens)/1000*cpkIn + float64(f.EstimatedOutputTokens)/1000*cpkOut +} diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go new file 
mode 100644 index 0000000..c72d897 --- /dev/null +++ b/internal/eval/eval_test.go @@ -0,0 +1,81 @@ +package eval + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/frugalsh/frugal/internal/classifier" + "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/types" +) + +func TestRunnerReportsSavingsAgainstExpensiveBaseline(t *testing.T) { + models := []router.ModelEntry{ + { + Name: "cheap", Provider: "x", + CostPer1KInput: 0.001, CostPer1KOutput: 0.002, + Reasoning: 0.5, Coding: 0.5, Creative: 0.5, InstructFollowing: 0.5, + MaxContext: 100000, + }, + { + Name: "expensive", Provider: "x", + CostPer1KInput: 0.01, CostPer1KOutput: 0.03, + Reasoning: 0.95, Coding: 0.95, Creative: 0.95, InstructFollowing: 0.95, + MaxContext: 100000, + }, + } + thresholds := map[string]router.Threshold{ + "cost": {MinReasoning: 0.3, MinCoding: 0.3, MinCreative: 0.3, MinInstructFollowing: 0.3}, + "balanced": {MinReasoning: 0.6, MinCoding: 0.6, MinCreative: 0.6, MinInstructFollowing: 0.6}, + } + + r := &Runner{ + Router: router.New(models, thresholds), + Classifier: classifier.NewRuleBased(), + BaselineModel: "expensive", + BaselineCPKIn: 0.01, + BaselineCPKOut: 0.03, + } + + w := Workload{ + Name: "smoke", + Queries: []Query{ + { + Label: "trivial", + Request: &types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: json.RawMessage(`"hi there"`)}}, + }, + }, + }, + } + + s := r.Run(w, types.QualityCost) + + if s.QueryCount != 1 || len(s.Results) != 1 { + t.Fatalf("expected 1 query/result, got count=%d results=%d", s.QueryCount, len(s.Results)) + } + if s.TotalBaseline <= 0 { + t.Fatalf("expected baseline > 0, got %f", s.TotalBaseline) + } + if s.TotalFrugal >= s.TotalBaseline { + t.Fatalf("expected Frugal < baseline on trivial query; frugal=%f baseline=%f", s.TotalFrugal, s.TotalBaseline) + } + if s.SavingsPct <= 0 { + t.Fatalf("expected positive savings, got %.2f%%", s.SavingsPct) + } + + 
var buf bytes.Buffer + if err := WriteMarkdown(&buf, s); err != nil { + t.Fatalf("WriteMarkdown: %v", err) + } + out := buf.String() + if !strings.Contains(out, "Workload: smoke") { + t.Fatalf("report missing workload header: %s", out) + } + if !strings.Contains(out, "savings") { + t.Fatalf("report missing aggregate line: %s", out) + } +} diff --git a/internal/eval/live.go b/internal/eval/live.go new file mode 100644 index 0000000..b162688 --- /dev/null +++ b/internal/eval/live.go @@ -0,0 +1,277 @@ +package eval + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "time" + + "github.com/frugalsh/frugal/internal/classifier" + "github.com/frugalsh/frugal/internal/config" + "github.com/frugalsh/frugal/internal/provider" + "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/types" +) + +// LiveRunner executes real ChatCompletion calls for each problem in a +// workload, scores the output with the problem's scorer, and tracks actual +// cost using per-token pricing pulled from the shipped config. It runs each +// problem twice — once through the Frugal router, once pinned to the +// workload's baseline model — so every problem has an apples-to-apples pair. +// +// Concurrency is deliberately left to the caller: the CLI layer runs +// sequentially today because provider rate limits dominate wall time. A +// goroutine pool can be layered on top without touching this struct. +type LiveRunner struct { + Router *router.Router + Classifier classifier.Classifier + Registry *provider.Registry + // ModelCosts maps model name → per-1k-token cost for input/output. + // Populated from config.Config so the runner can compute real cost from + // the usage numbers returned by providers. + ModelCosts map[string]ModelCost +} + +// ModelCost is the per-1k-token price a LiveRunner uses when billing a +// response. Mirrors the relevant subset of config.ModelConfig so the runner +// doesn't depend on the config package's full struct. 
+type ModelCost struct { + InputPer1K float64 + OutputPer1K float64 +} + +// NewLiveRunner builds a runner from a loaded config + registry, so the +// caller doesn't have to flatten the cost map by hand. +func NewLiveRunner(cfg *config.Config, cls classifier.Classifier, rtr *router.Router, reg *provider.Registry) *LiveRunner { + costs := make(map[string]ModelCost) + for _, pc := range cfg.Providers { + for name, mc := range pc.Models { + costs[name] = ModelCost{ + InputPer1K: mc.CostPer1KInput, + OutputPer1K: mc.CostPer1KOutput, + } + } + } + return &LiveRunner{Router: rtr, Classifier: cls, Registry: reg, ModelCosts: costs} +} + +// LiveProblemResult holds one pair of (frugal, baseline) outcomes for a +// single problem. Costs are real (from provider-reported usage) when the +// provider returns a Usage block; fall back to router estimates otherwise. +type LiveProblemResult struct { + ProblemID string + // Frugal leg. + FrugalModel string + FrugalProvider string + FrugalOutput string + FrugalPass bool + FrugalDetail string + FrugalCostUSD float64 + FrugalLatencyMS int64 + FrugalErr string + // Baseline leg. + BaselineModel string + BaselineOutput string + BaselinePass bool + BaselineDetail string + BaselineCostUSD float64 + BaselineLatencyMS int64 + BaselineErr string +} + +// LiveSummary aggregates results across a workload run. Pass-rate is the +// share of problems each leg scored correct; cost is the total USD spent +// across all problems for that leg. +type LiveSummary struct { + Workload string + Quality types.QualityThreshold + Baseline string + ProblemCount int + FrugalPassRate float64 + BaselinePassRate float64 + FrugalCostUSD float64 + BaselineCostUSD float64 + SavingsPct float64 + QualityDeltaPP float64 // baseline pass-rate minus frugal pass-rate, in percentage points + Results []LiveProblemResult + // ModelBreakdown: how often each model was selected by Frugal routing. 
+ ModelBreakdown map[string]int +} + +// Run executes every problem in w sequentially. ctx is threaded into every +// upstream call so cancellation works. Problem-level errors (network, +// invalid response) are captured in the result rather than aborting the run: +// a noisy upstream shouldn't erase the whole benchmark report. +func (r *LiveRunner) Run(ctx context.Context, w LiveWorkload, quality types.QualityThreshold) (LiveSummary, error) { + if _, err := r.Registry.Resolve(w.Baseline); err != nil { + return LiveSummary{}, fmt.Errorf("baseline model %q not registered: %w", w.Baseline, err) + } + + s := LiveSummary{ + Workload: w.Name, + Quality: quality, + Baseline: w.Baseline, + ProblemCount: len(w.Problems), + ModelBreakdown: map[string]int{}, + } + + var frugalPass, baselinePass int + + for _, p := range w.Problems { + scorer, err := p.Scorer() + if err != nil { + // Already validated at load; guard anyway. + return s, err + } + + req := buildRequest(p) + features := r.Classifier.Classify(req) + decision := r.Router.Route(features, quality, nil) + + // --- Frugal leg --- + pr := LiveProblemResult{ProblemID: p.ID} + pr.FrugalModel = decision.SelectedModel + pr.FrugalProvider = decision.SelectedProvider + s.ModelBreakdown[decision.SelectedModel]++ + + if decision.SelectedModel == "" { + pr.FrugalErr = "router returned no model" + } else if prov, err := r.Registry.Resolve(decision.SelectedModel); err != nil { + pr.FrugalErr = err.Error() + } else { + pr.FrugalOutput, pr.FrugalCostUSD, pr.FrugalLatencyMS, pr.FrugalErr = + r.callAndCost(ctx, prov, decision.SelectedModel, req) + res := scorer.Score(pr.FrugalOutput) + pr.FrugalPass = res.Pass + pr.FrugalDetail = res.Detail + if pr.FrugalPass { + frugalPass++ + } + } + + // --- Baseline leg --- + pr.BaselineModel = w.Baseline + if prov, err := r.Registry.Resolve(w.Baseline); err != nil { + pr.BaselineErr = err.Error() + } else { + pr.BaselineOutput, pr.BaselineCostUSD, pr.BaselineLatencyMS, pr.BaselineErr = + 
r.callAndCost(ctx, prov, w.Baseline, req) + res := scorer.Score(pr.BaselineOutput) + pr.BaselinePass = res.Pass + pr.BaselineDetail = res.Detail + if pr.BaselinePass { + baselinePass++ + } + } + + s.FrugalCostUSD += pr.FrugalCostUSD + s.BaselineCostUSD += pr.BaselineCostUSD + s.Results = append(s.Results, pr) + } + + if s.ProblemCount > 0 { + s.FrugalPassRate = float64(frugalPass) / float64(s.ProblemCount) * 100 + s.BaselinePassRate = float64(baselinePass) / float64(s.ProblemCount) * 100 + } + s.QualityDeltaPP = s.BaselinePassRate - s.FrugalPassRate + if s.BaselineCostUSD > 0 { + s.SavingsPct = (s.BaselineCostUSD - s.FrugalCostUSD) / s.BaselineCostUSD * 100 + } + + return s, nil +} + +// callAndCost runs one non-streaming ChatCompletion and extracts a string +// output + real cost from the Usage block. Real cost is preferred so reports +// reflect what the provider actually billed; if Usage is absent, fall back +// to zero rather than fabricating a number. +func (r *LiveRunner) callAndCost(ctx context.Context, prov provider.Provider, model string, req *types.ChatCompletionRequest) (output string, cost float64, latencyMS int64, errMsg string) { + start := time.Now() + resp, err := prov.ChatCompletion(ctx, model, req) + latencyMS = time.Since(start).Milliseconds() + if err != nil { + return "", 0, latencyMS, err.Error() + } + if len(resp.Choices) == 0 { + return "", 0, latencyMS, "empty choices" + } + output = extractText(resp.Choices[0].Message.Content) + + if resp.Usage != nil { + mc := r.ModelCosts[model] + cost = float64(resp.Usage.PromptTokens)/1000*mc.InputPer1K + + float64(resp.Usage.CompletionTokens)/1000*mc.OutputPer1K + } + return output, cost, latencyMS, "" +} + +func extractText(raw json.RawMessage) string { + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + return "" +} + +func buildRequest(p Problem) *types.ChatCompletionRequest { + var msgs []types.Message + if p.System != "" { + content, _ := json.Marshal(p.System) + 
msgs = append(msgs, types.Message{Role: "system", Content: content}) + } + content, _ := json.Marshal(p.Prompt) + msgs = append(msgs, types.Message{Role: "user", Content: content}) + + req := &types.ChatCompletionRequest{ + Model: "auto", + Messages: msgs, + } + if p.JSONMode { + req.ResponseFormat = &types.ResponseFormat{Type: "json_object"} + } + if p.MaxTokens > 0 { + mt := p.MaxTokens + req.MaxTokens = &mt + } + return req +} + +// WriteLiveMarkdown renders the summary as a markdown report suitable for +// pasting into BENCHMARKS.md or a PR description. +func WriteLiveMarkdown(w interface{ Write(p []byte) (int, error) }, s LiveSummary) error { + _, err := fmt.Fprintf(w, "# %s (quality=%s, baseline=%s)\n\n", s.Workload, s.Quality, s.Baseline) + if err != nil { + return err + } + fmt.Fprintf(w, "Problems: %d · Frugal pass: %.1f%% · Baseline pass: %.1f%% · Δ: %+.1fpp\n", + s.ProblemCount, s.FrugalPassRate, s.BaselinePassRate, -s.QualityDeltaPP) + fmt.Fprintf(w, "Cost: frugal $%.4f · baseline $%.4f · savings **%.1f%%**\n\n", + s.FrugalCostUSD, s.BaselineCostUSD, s.SavingsPct) + + fmt.Fprintln(w, "## Model selection") + var names []string + for m := range s.ModelBreakdown { + names = append(names, m) + } + sort.Slice(names, func(i, j int) bool { return s.ModelBreakdown[names[i]] > s.ModelBreakdown[names[j]] }) + for _, m := range names { + fmt.Fprintf(w, "- `%s` × %d\n", m, s.ModelBreakdown[m]) + } + + fmt.Fprintln(w, "\n## Per-problem results") + fmt.Fprintln(w, "| # | Problem | Frugal model | Frugal ✓ | Baseline ✓ |") + fmt.Fprintln(w, "|---|---|---|---|---|") + for i, r := range s.Results { + fmt.Fprintf(w, "| %d | `%s` | `%s` | %s | %s |\n", + i+1, r.ProblemID, r.FrugalModel, checkbox(r.FrugalPass), checkbox(r.BaselinePass)) + } + return nil +} + +func checkbox(b bool) string { + if b { + return "✓" + } + return "✗" +} diff --git a/internal/eval/live_test.go b/internal/eval/live_test.go new file mode 100644 index 0000000..69531ec --- /dev/null +++ 
b/internal/eval/live_test.go @@ -0,0 +1,128 @@
+package eval
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/frugalsh/frugal/internal/classifier"
+ "github.com/frugalsh/frugal/internal/provider"
+ "github.com/frugalsh/frugal/internal/router"
+ "github.com/frugalsh/frugal/internal/types"
+)
+
+// mockProv returns a canned response per model so we can exercise both the
+// Frugal and baseline legs deterministically without a network.
+type mockProv struct {
+ responses map[string]string
+}
+
+func (m *mockProv) Name() string { return "mock" }
+func (m *mockProv) Models() []string { return []string{"cheap", "expensive"} }
+func (m *mockProv) ChatCompletion(ctx context.Context, model string, req *types.ChatCompletionRequest) (*types.ChatCompletionResponse, error) {
+ text, ok := m.responses[model]
+ if !ok {
+ text = ""
+ }
+ content, _ := json.Marshal(text)
+ fr := "stop"
+ return &types.ChatCompletionResponse{
+ ID: "test-" + model,
+ Object: "chat.completion",
+ Model: model,
+ Choices: []types.Choice{{Index: 0, Message: types.Message{Role: "assistant", Content: content}, FinishReason: &fr}},
+ Usage: &types.Usage{PromptTokens: 20, CompletionTokens: 5, TotalTokens: 25},
+ }, nil
+}
+func (m *mockProv) ChatCompletionStream(ctx context.Context, model string, req *types.ChatCompletionRequest) (<-chan provider.StreamChunk, error) {
+ return nil, nil
+}
+
+func TestLiveRunner_ScoresBothLegsAndComputesDelta(t *testing.T) {
+ reg := provider.NewRegistry()
+ // cheap is right on classify, wrong on math; expensive is right on math, wrong on classify. 
+ reg.Register(&mockProv{responses: map[string]string{ + "cheap": "positive", // correct for classify; wrong for math (no number) + "expensive": "The answer is 42", // contains 42 for math; wrong word for classify + }}) + + models := []router.ModelEntry{ + { + Name: "cheap", Provider: "mock", + CostPer1KInput: 0.0001, CostPer1KOutput: 0.0004, + Reasoning: 0.5, Coding: 0.5, Creative: 0.5, InstructFollowing: 0.5, + MaxContext: 10000, + }, + { + Name: "expensive", Provider: "mock", + CostPer1KInput: 0.003, CostPer1KOutput: 0.015, + Reasoning: 0.95, Coding: 0.95, Creative: 0.95, InstructFollowing: 0.95, + MaxContext: 10000, + }, + } + thresholds := map[string]router.Threshold{ + "cost": {}, + "balanced": {MinReasoning: 0.3, MinCoding: 0.3, MinCreative: 0.3, MinInstructFollowing: 0.3}, + } + + runner := &LiveRunner{ + Router: router.New(models, thresholds), + Classifier: classifier.NewRuleBased(), + Registry: reg, + ModelCosts: map[string]ModelCost{ + "cheap": {InputPer1K: 0.0001, OutputPer1K: 0.0004}, + "expensive": {InputPer1K: 0.003, OutputPer1K: 0.015}, + }, + } + + workload := LiveWorkload{ + Name: "unit", + Baseline: "expensive", + Problems: []Problem{ + {ID: "classify-pos", Prompt: "is it nice?", ExpectedEquals: "positive"}, + {ID: "math-forty-two", Prompt: "What is 6*7?", ExpectedContains: "42"}, + }, + } + + s, err := runner.Run(context.Background(), workload, types.QualityCost) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if s.ProblemCount != 2 { + t.Fatalf("want 2 problems, got %d", s.ProblemCount) + } + // At quality=cost, Frugal should prefer "cheap". + if s.ModelBreakdown["cheap"] != 2 { + t.Errorf("expected Frugal to pick cheap twice, breakdown=%v", s.ModelBreakdown) + } + // Baseline (expensive) contains "42" — correct on math; wrong on classify. + // Cheap returns "positive" — right on classify, wrong on math. + // Each leg scores 1/2 = 50%. 
+ if s.FrugalPassRate != 50.0 { + t.Errorf("expected FrugalPassRate=50, got %.1f", s.FrugalPassRate) + } + if s.BaselinePassRate != 50.0 { + t.Errorf("expected BaselinePassRate=50, got %.1f", s.BaselinePassRate) + } + // Costs computed from Usage × per-token rates. + if s.BaselineCostUSD <= s.FrugalCostUSD { + t.Errorf("expected baseline cost to exceed frugal cost; frugal=%.6f baseline=%.6f", + s.FrugalCostUSD, s.BaselineCostUSD) + } + if s.SavingsPct <= 0 { + t.Errorf("expected positive savings, got %.2f%%", s.SavingsPct) + } +} + +func TestLiveRunner_RejectsUnregisteredBaseline(t *testing.T) { + runner := &LiveRunner{ + Registry: provider.NewRegistry(), + Classifier: classifier.NewRuleBased(), + Router: router.New(nil, nil), + } + _, err := runner.Run(context.Background(), LiveWorkload{Name: "x", Baseline: "nope", Problems: []Problem{{ID: "p1", Prompt: "hi", ExpectedContains: "hi"}}}, types.QualityCost) + if err == nil { + t.Fatalf("expected error when baseline is unregistered") + } +} diff --git a/internal/eval/report.go b/internal/eval/report.go new file mode 100644 index 0000000..aadcc42 --- /dev/null +++ b/internal/eval/report.go @@ -0,0 +1,32 @@ +package eval + +import ( + "fmt" + "io" +) + +// WriteMarkdown renders a Summary as a markdown report: per-query table plus +// an aggregate line. Suitable for pasting into BENCHMARKS.md. 
+func WriteMarkdown(w io.Writer, s Summary) error { + if _, err := fmt.Fprintf(w, "# Workload: %s (quality=%s, baseline=%s)\n\n", + s.Workload, s.Quality, s.BaselineModel); err != nil { + return err + } + if _, err := fmt.Fprintln(w, "| # | Query | Selected | Provider | Frugal $ | Baseline $ | Savings % |"); err != nil { + return err + } + if _, err := fmt.Fprintln(w, "|---|---|---|---|---|---|---|"); err != nil { + return err + } + for i, r := range s.Results { + if _, err := fmt.Fprintf(w, "| %d | %s | %s | %s | $%.6f | $%.6f | %.1f%% |\n", + i+1, r.Query.Label, r.Decision.SelectedModel, r.Decision.SelectedProvider, + r.FrugalCost, r.BaselineCost, r.SavingsPct); err != nil { + return err + } + } + _, err := fmt.Fprintf(w, + "\n**Total:** Frugal $%.4f vs baseline $%.4f — **%.1f%% savings** across %d queries.\n", + s.TotalFrugal, s.TotalBaseline, s.SavingsPct, s.QueryCount) + return err +} diff --git a/internal/eval/scorer.go b/internal/eval/scorer.go new file mode 100644 index 0000000..774b4ee --- /dev/null +++ b/internal/eval/scorer.go @@ -0,0 +1,160 @@ +package eval + +import ( + "encoding/json" + "fmt" + "regexp" + "strconv" + "strings" +) + +// Scorer judges whether a model response to a benchmark problem is correct. +// Implementations are cheap: exact/substring/numeric/JSON matches run locally +// without spending an LLM-judge call. An LLM-judge scorer can be added later +// as a drop-in implementation. +type Scorer interface { + Name() string + Score(output string) ScoreResult +} + +// ScoreResult is what a scorer returns. Detail is rendered in verbose reports +// so failures are debuggable without re-running the bench. +type ScoreResult struct { + Pass bool + Detail string +} + +// ExactTrimmed matches the output against Expected after trimming whitespace +// from both sides. Case-sensitive. Use for classification labels and +// one-answer math. 
+type ExactTrimmed struct{ Expected string } + +func (s ExactTrimmed) Name() string { return "exact_trimmed" } +func (s ExactTrimmed) Score(out string) ScoreResult { + got := strings.TrimSpace(out) + if got == s.Expected { + return ScoreResult{Pass: true} + } + return ScoreResult{Pass: false, Detail: fmt.Sprintf("want %q, got %q", s.Expected, got)} +} + +// Substring passes when Expected appears anywhere in the output. CaseFold +// lowercases both sides before comparing — use it for fact-recall where the +// model's answer may be rephrased but must contain the key term. +type Substring struct { + Expected string + CaseFold bool +} + +func (s Substring) Name() string { return "substring" } +func (s Substring) Score(out string) ScoreResult { + haystack, needle := out, s.Expected + if s.CaseFold { + haystack = strings.ToLower(haystack) + needle = strings.ToLower(needle) + } + if strings.Contains(haystack, needle) { + return ScoreResult{Pass: true} + } + return ScoreResult{Pass: false, Detail: fmt.Sprintf("missing %q", s.Expected)} +} + +// ContainsAll passes when every keyword appears in the output (CaseFold +// applies to all). Use for explanations that must hit specific technical +// terms — e.g. "quicksort" answer must mention "pivot" and "partition". 
+type ContainsAll struct { + Keywords []string + CaseFold bool +} + +func (s ContainsAll) Name() string { return "contains_all" } +func (s ContainsAll) Score(out string) ScoreResult { + haystack := out + if s.CaseFold { + haystack = strings.ToLower(haystack) + } + var missing []string + for _, kw := range s.Keywords { + needle := kw + if s.CaseFold { + needle = strings.ToLower(needle) + } + if !strings.Contains(haystack, needle) { + missing = append(missing, kw) + } + } + if len(missing) == 0 { + return ScoreResult{Pass: true} + } + return ScoreResult{Pass: false, Detail: "missing keywords: " + strings.Join(missing, ", ")} +} + +// JSONHasKeys passes when the output parses as a JSON object and contains +// every required key at the top level. Values aren't type-checked here — +// extend with a schema matcher when the benchmark set needs it. +type JSONHasKeys struct{ RequiredKeys []string } + +func (s JSONHasKeys) Name() string { return "json_has_keys" } +func (s JSONHasKeys) Score(out string) ScoreResult { + // Tolerate fenced markdown around the JSON — common LLM output pattern. + payload := stripJSONFence(out) + var obj map[string]any + if err := json.Unmarshal([]byte(payload), &obj); err != nil { + return ScoreResult{Pass: false, Detail: "not valid JSON: " + err.Error()} + } + var missing []string + for _, k := range s.RequiredKeys { + if _, ok := obj[k]; !ok { + missing = append(missing, k) + } + } + if len(missing) == 0 { + return ScoreResult{Pass: true} + } + return ScoreResult{Pass: false, Detail: "missing keys: " + strings.Join(missing, ", ")} +} + +// Numeric pulls the first number out of the output and compares it to +// Expected within ±Tolerance. Lets models preface "The answer is 42" or +// "≈ 3.14159" without failing on prose wrapping. 
+type Numeric struct { + Expected float64 + Tolerance float64 +} + +var numberRe = regexp.MustCompile(`-?\d+(?:\.\d+)?`) + +func (s Numeric) Name() string { return "numeric" } +func (s Numeric) Score(out string) ScoreResult { + m := numberRe.FindString(out) + if m == "" { + return ScoreResult{Pass: false, Detail: "no number in output"} + } + got, err := strconv.ParseFloat(m, 64) + if err != nil { + return ScoreResult{Pass: false, Detail: "parse %q: " + err.Error()} + } + diff := got - s.Expected + if diff < 0 { + diff = -diff + } + if diff <= s.Tolerance { + return ScoreResult{Pass: true} + } + return ScoreResult{Pass: false, Detail: fmt.Sprintf("want %g±%g, got %g", s.Expected, s.Tolerance, got)} +} + +// stripJSONFence removes a surrounding ```json … ``` fence if present. +func stripJSONFence(s string) string { + s = strings.TrimSpace(s) + if strings.HasPrefix(s, "```") { + // Skip ``` and optional language tag to end of line. + if nl := strings.Index(s, "\n"); nl > 0 { + s = s[nl+1:] + } + if idx := strings.LastIndex(s, "```"); idx >= 0 { + s = s[:idx] + } + } + return strings.TrimSpace(s) +} diff --git a/internal/eval/scorer_test.go b/internal/eval/scorer_test.go new file mode 100644 index 0000000..c574778 --- /dev/null +++ b/internal/eval/scorer_test.go @@ -0,0 +1,75 @@ +package eval + +import "testing" + +func TestExactTrimmed(t *testing.T) { + s := ExactTrimmed{Expected: "42"} + if !s.Score(" 42 ").Pass { + t.Fatalf("whitespace around exact match should pass") + } + if s.Score("42.0").Pass { + t.Fatalf("42.0 should not exactly equal 42") + } + if s.Score("the answer is 42").Pass { + t.Fatalf("substring should not satisfy exact match") + } +} + +func TestSubstring_CaseFold(t *testing.T) { + s := Substring{Expected: "Paris", CaseFold: true} + if !s.Score("the capital is paris, france").Pass { + t.Fatalf("case-fold substring should match") + } + if s.Score("london").Pass { + t.Fatalf("non-matching substring should fail") + } +} + +func 
TestContainsAll_ReportsMissing(t *testing.T) { + s := ContainsAll{Keywords: []string{"pivot", "partition", "divide"}, CaseFold: true} + r := s.Score("Quicksort partitions around a pivot element.") + if r.Pass { + t.Fatalf("expected failure when one keyword missing") + } + if r.Detail == "" { + t.Fatalf("expected detail listing missing keywords") + } + r2 := s.Score("Quicksort divides the array by partitioning around a pivot.") + if !r2.Pass { + t.Fatalf("expected pass when all keywords present; detail=%s", r2.Detail) + } +} + +func TestJSONHasKeys_ToleratesMarkdownFence(t *testing.T) { + s := JSONHasKeys{RequiredKeys: []string{"name", "email"}} + fenced := "```json\n{\"name\":\"Jane\",\"email\":\"j@x.co\"}\n```" + if !s.Score(fenced).Pass { + t.Fatalf("fenced JSON should parse and match") + } + missing := s.Score(`{"name":"Jane"}`) + if missing.Pass { + t.Fatalf("should fail when required key is missing") + } + bad := s.Score("not json at all") + if bad.Pass { + t.Fatalf("invalid JSON should not pass") + } +} + +func TestNumeric_WithinTolerance(t *testing.T) { + s := Numeric{Expected: 714, Tolerance: 0} + if !s.Score("The answer is 714.").Pass { + t.Fatalf("exact number in prose should pass") + } + if s.Score("715").Pass { + t.Fatalf("off-by-one with zero tolerance should fail") + } + + pi := Numeric{Expected: 3.14159, Tolerance: 0.001} + if !pi.Score("pi ≈ 3.1416").Pass { + t.Fatalf("within-tolerance float should pass") + } + if pi.Score("pi ≈ 3.15").Pass { + t.Fatalf("outside-tolerance float should fail") + } +} diff --git a/internal/eval/workload.go b/internal/eval/workload.go new file mode 100644 index 0000000..e292720 --- /dev/null +++ b/internal/eval/workload.go @@ -0,0 +1,18 @@ +package eval + +import "github.com/frugalsh/frugal/internal/types" + +// Query is a single prompt to route through the eval harness. 
type Query struct {
	Label   string
	Request *types.ChatCompletionRequest
}

// Workload is a named collection of queries representing a realistic usage profile.
// Real workload definitions live in separate files (e.g. workloads_claude_code.go)
// and can be registered here as the benchmark set grows.
type Workload struct {
	Name        string
	Description string
	Queries     []Query
}
diff --git a/internal/eval/workload_yaml.go b/internal/eval/workload_yaml.go
new file mode 100644
index 0000000..21862ec
--- /dev/null
+++ b/internal/eval/workload_yaml.go
@@ -0,0 +1,122 @@
package eval

import (
	"bytes"
	"fmt"
	"os"

	"gopkg.in/yaml.v3"
)

// Problem is one benchmark item: a prompt, optional system prompt, and one
// scorer selected from a small in-tree palette (see scorer.go). Keeping the
// scorer palette small on purpose — LLM-judge scorers can be added later
// behind a type: "judge" branch without breaking existing workloads.
type Problem struct {
	ID        string `yaml:"id"`
	Prompt    string `yaml:"prompt"`
	System    string `yaml:"system,omitempty"`
	JSONMode  bool   `yaml:"json_mode,omitempty"`
	MaxTokens int    `yaml:"max_tokens,omitempty"`

	// Exactly one of the Expected* fields should be set per problem. The YAML
	// loader infers the scorer type from whichever one is non-zero.
	ExpectedEquals      string   `yaml:"expected_equals,omitempty"`      // → ExactTrimmed
	ExpectedContains    string   `yaml:"expected_contains,omitempty"`    // → Substring
	ExpectedContainsAll []string `yaml:"expected_contains_all,omitempty"` // → ContainsAll
	ExpectedKeys        []string `yaml:"expected_keys,omitempty"`        // → JSONHasKeys
	ExpectedNumber      *float64 `yaml:"expected_number,omitempty"`      // → Numeric

	// CaseFold is consumed by the Substring and ContainsAll scorers;
	// Tolerance is consumed by the Numeric scorer (see Problem.Scorer).
	CaseFold  bool    `yaml:"case_fold,omitempty"`
	Tolerance float64 `yaml:"tolerance,omitempty"`
}

// Scorer builds the appropriate Scorer for this problem. Returns an error if
// zero or more than one Expected* field is set — workloads should fail loudly
// on malformed rows rather than silently scoring every response as passing.
+func (p Problem) Scorer() (Scorer, error) { + picks := 0 + if p.ExpectedEquals != "" { + picks++ + } + if p.ExpectedContains != "" { + picks++ + } + if len(p.ExpectedContainsAll) > 0 { + picks++ + } + if len(p.ExpectedKeys) > 0 { + picks++ + } + if p.ExpectedNumber != nil { + picks++ + } + if picks == 0 { + return nil, fmt.Errorf("problem %q has no expected_* scorer field", p.ID) + } + if picks > 1 { + return nil, fmt.Errorf("problem %q has multiple expected_* fields; pick one", p.ID) + } + + switch { + case p.ExpectedEquals != "": + return ExactTrimmed{Expected: p.ExpectedEquals}, nil + case p.ExpectedContains != "": + return Substring{Expected: p.ExpectedContains, CaseFold: p.CaseFold}, nil + case len(p.ExpectedContainsAll) > 0: + return ContainsAll{Keywords: p.ExpectedContainsAll, CaseFold: p.CaseFold}, nil + case len(p.ExpectedKeys) > 0: + return JSONHasKeys{RequiredKeys: p.ExpectedKeys}, nil + case p.ExpectedNumber != nil: + return Numeric{Expected: *p.ExpectedNumber, Tolerance: p.Tolerance}, nil + } + return nil, fmt.Errorf("problem %q: unreachable scorer branch", p.ID) +} + +// LiveWorkload is a YAML-authored set of benchmark problems. Distinct from +// Workload (simulation-only) so the benchmark harness can evolve its schema +// without breaking simulation consumers. +type LiveWorkload struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Baseline string `yaml:"baseline"` + Problems []Problem `yaml:"problems"` +} + +// LoadLiveWorkload reads a YAML workload from disk. All scorers are validated +// up front so a bad row is caught before any API calls happen. 
func LoadLiveWorkload(path string) (LiveWorkload, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return LiveWorkload{}, fmt.Errorf("read workload: %w", err)
	}
	var w LiveWorkload
	dec := yaml.NewDecoder(bytes.NewReader(data))
	// Strict decode: unknown YAML keys are an error, so a typo in a workload
	// file surfaces immediately instead of silently dropping a field.
	dec.KnownFields(true)
	if err := dec.Decode(&w); err != nil {
		return LiveWorkload{}, fmt.Errorf("parse workload: %w", err)
	}
	// Top-level requirements: a name, a baseline model, and at least one problem.
	if w.Name == "" {
		return LiveWorkload{}, fmt.Errorf("workload %q: missing name", path)
	}
	if w.Baseline == "" {
		return LiveWorkload{}, fmt.Errorf("workload %q: missing baseline model", path)
	}
	if len(w.Problems) == 0 {
		return LiveWorkload{}, fmt.Errorf("workload %q: no problems", path)
	}
	// Per-problem checks: non-empty unique IDs and a buildable scorer.
	seen := map[string]bool{}
	for i, p := range w.Problems {
		if p.ID == "" {
			return LiveWorkload{}, fmt.Errorf("workload %q: problem %d missing id", path, i)
		}
		if seen[p.ID] {
			return LiveWorkload{}, fmt.Errorf("workload %q: duplicate problem id %q", path, p.ID)
		}
		seen[p.ID] = true
		if _, err := p.Scorer(); err != nil {
			return LiveWorkload{}, fmt.Errorf("workload %q: %w", path, err)
		}
	}
	return w, nil
}
diff --git a/internal/eval/workload_yaml_test.go b/internal/eval/workload_yaml_test.go
new file mode 100644
index 0000000..a5cddaf
--- /dev/null
+++ b/internal/eval/workload_yaml_test.go
@@ -0,0 +1,79 @@
package eval

import (
	"os"
	"path/filepath"
	"testing"
)

func TestLoadLiveWorkload_StarterReturnsAllProblems(t *testing.T) {
	// Resolve relative to the repo root so tests are runnable from any cwd
	// that Go's test harness lands on.
	path := filepath.Join("..", "..", "config", "workloads", "starter.yaml")
	if _, err := os.Stat(path); err != nil {
		t.Skipf("starter workload not found: %v", err)
	}

	w, err := LoadLiveWorkload(path)
	if err != nil {
		t.Fatalf("LoadLiveWorkload: %v", err)
	}
	if w.Name != "starter" {
		t.Errorf("expected name=starter, got %q", w.Name)
	}
	if w.Baseline == "" {
		t.Errorf("expected non-empty baseline model")
	}
	if len(w.Problems) < 15 {
		t.Errorf("expected >= 15 problems in starter workload, got %d", len(w.Problems))
	}

	// Every problem must build a valid scorer.
	for _, p := range w.Problems {
		if _, err := p.Scorer(); err != nil {
			t.Errorf("problem %q scorer build: %v", p.ID, err)
		}
	}
}

func TestLoadLiveWorkload_RejectsMultipleExpectedFields(t *testing.T) {
	dir := t.TempDir()
	bad := filepath.Join(dir, "bad.yaml")
	if err := os.WriteFile(bad, []byte(`
name: bad
baseline: mock-model
problems:
  - id: p1
    prompt: hello
    expected_equals: "42"
    expected_contains: "forty-two"
`), 0o644); err != nil {
		t.Fatal(err)
	}

	if _, err := LoadLiveWorkload(bad); err == nil {
		t.Fatalf("expected error when multiple expected_* fields are set")
	}
}

func TestLoadLiveWorkload_RejectsDuplicateIDs(t *testing.T) {
	dir := t.TempDir()
	bad := filepath.Join(dir, "dup.yaml")
	if err := os.WriteFile(bad, []byte(`
name: dup
baseline: mock-model
problems:
  - id: same
    prompt: hello
    expected_equals: hi
  - id: same
    prompt: world
    expected_equals: world
`), 0o644); err != nil {
		t.Fatal(err)
	}

	if _, err := LoadLiveWorkload(bad); err == nil {
		t.Fatalf("expected error on duplicate problem ids")
	}
}
diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go
new file mode 100644
index 0000000..99b483a
--- /dev/null
+++ b/internal/metrics/metrics.go
@@ -0,0 +1,79 @@
// Package metrics exposes the Prometheus counters and histograms that the
// proxy emits on the hot path. Keep the label cardinality bounded: model and
// provider are controlled (config file); status bucketed by class; no
// raw-user labels ever.
package metrics

import (
	"net/http"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

var (
	// RequestsTotal counts chat completion requests per routing outcome.
	RequestsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "frugal_requests_total",
		Help: "Chat completion requests by selected model, provider, quality tier, and status class (2xx/4xx/5xx).",
	}, []string{"model", "provider", "quality", "status_class"})

	// RequestDurationSeconds tracks end-to-end latency; buckets span 50ms–120s.
	RequestDurationSeconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Name:    "frugal_request_duration_seconds",
		Help:    "Wall-clock time for each chat completion request, labeled by model/provider/stream-vs-nonstream.",
		Buckets: []float64{0.05, 0.1, 0.25, 0.5, 1, 2, 5, 10, 20, 30, 60, 120},
	}, []string{"model", "provider", "stream"})

	// TokensTotal splits token volume into prompt vs completion direction.
	TokensTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "frugal_tokens_total",
		Help: "Tokens routed through the proxy, split by direction (prompt/completion).",
	}, []string{"model", "provider", "direction"})

	// CostUSDTotal accumulates the routing-time cost estimate (not billed cost).
	CostUSDTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "frugal_cost_usd_total",
		Help: "Estimated dollar cost of served requests (based on the routing-time estimate).",
	}, []string{"model", "provider"})

	// RoutingRelaxedTotal counts quality-tier relaxations, labeled from→to.
	RoutingRelaxedTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "frugal_routing_relaxed_total",
		Help: "Count of requests where the router relaxed the requested quality tier to land a model.",
	}, []string{"from", "to"})
)

// Register wires every Frugal metric into the default registry. Safe to call
// more than once: duplicate registrations become a noop so tests can reuse.
+func Register() { + for _, c := range []prometheus.Collector{ + RequestsTotal, RequestDurationSeconds, TokensTotal, CostUSDTotal, RoutingRelaxedTotal, + } { + _ = prometheus.Register(c) // ignore AlreadyRegisteredError + } +} + +// Handler returns the /metrics HTTP handler. +func Handler() http.Handler { + return promhttp.Handler() +} + +// StatusClass bucketizes an HTTP status into 2xx/4xx/5xx so label cardinality +// stays sane (full status would explode the series count on misbehaving +// upstreams). +func StatusClass(status int) string { + switch { + case status >= 500: + return "5xx" + case status >= 400: + return "4xx" + case status >= 300: + return "3xx" + case status >= 200: + return "2xx" + default: + return "unknown" + } +} + +// ObserveDuration records a duration in seconds. +func ObserveDuration(h *prometheus.HistogramVec, model, provider, stream string, d time.Duration) { + h.WithLabelValues(model, provider, stream).Observe(d.Seconds()) +} diff --git a/internal/obs/obs.go b/internal/obs/obs.go new file mode 100644 index 0000000..055963d --- /dev/null +++ b/internal/obs/obs.go @@ -0,0 +1,99 @@ +// Package obs wires structured logging, request IDs, and a small set of +// cross-cutting observability primitives used by the rest of the codebase. +// The public surface is intentionally tiny so we can swap implementations +// (slog handler, ID generator) without touching callers. +package obs + +import ( + "context" + "crypto/rand" + "encoding/base32" + "log/slog" + "os" + "strings" +) + +type ctxKey int + +const ( + requestIDKey ctxKey = iota + loggerKey +) + +// InitLogger configures the process-wide default slog logger from +// FRUGAL_LOG_LEVEL (debug|info|warn|error) and FRUGAL_LOG_FORMAT (text|json). +// Text output preserves the human-readable local-dev experience; json is +// what deployers pipe into collectors. 
+func InitLogger() *slog.Logger { + level := parseLevel(os.Getenv("FRUGAL_LOG_LEVEL")) + opts := &slog.HandlerOptions{Level: level} + + var handler slog.Handler + switch strings.ToLower(os.Getenv("FRUGAL_LOG_FORMAT")) { + case "json": + handler = slog.NewJSONHandler(os.Stderr, opts) + default: + handler = slog.NewTextHandler(os.Stderr, opts) + } + + logger := slog.New(handler) + slog.SetDefault(logger) + return logger +} + +func parseLevel(s string) slog.Level { + switch strings.ToLower(strings.TrimSpace(s)) { + case "debug": + return slog.LevelDebug + case "warn", "warning": + return slog.LevelWarn + case "error": + return slog.LevelError + default: + return slog.LevelInfo + } +} + +// NewRequestID returns a random 16-byte ID in unpadded base32. Cryptographic +// randomness is overkill for a trace ID but free and avoids collisions +// entirely without a central sequencer. +func NewRequestID() string { + var buf [16]byte + if _, err := rand.Read(buf[:]); err != nil { + // Fall back to deterministic-but-unique ID only in the impossible + // event rand.Read fails. A bad ID is better than a dropped request. + return "req-00000000000000000000000000" + } + return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(buf[:]) +} + +// WithRequestID attaches a request ID to a context. Retrieved via RequestID. +func WithRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, requestIDKey, id) +} + +// RequestID returns the request ID stored in ctx, or "" when absent. +func RequestID(ctx context.Context) string { + if id, ok := ctx.Value(requestIDKey).(string); ok { + return id + } + return "" +} + +// WithLogger attaches a pre-scoped slog.Logger to ctx. Handlers pull this +// via L(ctx) so request-scoped attrs (request_id, model, provider) are +// never lost between helpers. 
+func WithLogger(ctx context.Context, l *slog.Logger) context.Context { + return context.WithValue(ctx, loggerKey, l) +} + +// L returns the request-scoped logger if one was attached via WithLogger, +// otherwise the process default. Safe to call with a nil or background ctx. +func L(ctx context.Context) *slog.Logger { + if ctx != nil { + if l, ok := ctx.Value(loggerKey).(*slog.Logger); ok && l != nil { + return l + } + } + return slog.Default() +} diff --git a/internal/provider/anthropic/anthropic.go b/internal/provider/anthropic/anthropic.go index e71ee34..021bc38 100644 --- a/internal/provider/anthropic/anthropic.go +++ b/internal/provider/anthropic/anthropic.go @@ -1,7 +1,6 @@ package anthropic import ( - "bufio" "bytes" "context" "encoding/json" @@ -17,6 +16,19 @@ import ( const anthropicVersion = "2023-06-01" +const errorBodyLimit = 8 << 10 // 8 KiB + +func readErrorBody(r io.Reader) string { + body, err := io.ReadAll(io.LimitReader(r, errorBodyLimit+1)) + if err != nil { + return "" + } + if len(body) > errorBodyLimit { + return string(body[:errorBodyLimit]) + "... 
(truncated)" + } + return string(body) +} + type Provider struct { apiKey string baseURL string @@ -29,7 +41,7 @@ func New(apiKey, baseURL string, models []string) *Provider { apiKey: apiKey, baseURL: baseURL, models: models, - client: &http.Client{}, + client: provider.NewHTTPClient(), } } @@ -40,17 +52,38 @@ func (p *Provider) Models() []string { return p.models } // -- Anthropic API types -- type messagesRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []anthropicMsg `json:"messages"` - Stream bool `json:"stream,omitempty"` - Tools []anthropicTool `json:"tools,omitempty"` + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []anthropicMsg `json:"messages"` + Stream bool `json:"stream,omitempty"` + Tools []anthropicTool `json:"tools,omitempty"` } type anthropicMsg struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content []anthropicContent `json:"content"` +} + +// anthropicContent is Anthropic's content block. Emitted block types today: +// text, image, tool_result. tool_use for assistant-origin tool calls is +// handled by the upstream model; Frugal does not synthesize it. +type anthropicContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *anthropicSource `json:"source,omitempty"` + ToolUseID string `json:"tool_use_id,omitempty"` + Content string `json:"content,omitempty"` // tool_result body + // CacheControl is forwarded verbatim so callers can opt into Anthropic + // prompt caching without Frugal stripping the hint. 
+ CacheControl json.RawMessage `json:"cache_control,omitempty"` +} + +type anthropicSource struct { + Type string `json:"type"` // "base64" or "url" + MediaType string `json:"media_type,omitempty"` // required for base64 + Data string `json:"data,omitempty"` // base64 payload + URL string `json:"url,omitempty"` // for type:"url" } type anthropicTool struct { @@ -60,12 +93,12 @@ type anthropicTool struct { } type messagesResponse struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []contentBlock `json:"content"` - Model string `json:"model"` - Usage anthropicUsage `json:"usage"` + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []contentBlock `json:"content"` + Model string `json:"model"` + Usage anthropicUsage `json:"usage"` StopReason string `json:"stop_reason"` } @@ -107,21 +140,35 @@ func translateRequest(req *types.ChatCompletionRequest, model string) *messagesR maxTokens := 4096 if req.MaxTokens != nil { maxTokens = *req.MaxTokens + } else if req.MaxCompletionTokens != nil { + maxTokens = *req.MaxCompletionTokens } ar.MaxTokens = maxTokens for _, msg := range req.Messages { if msg.Role == "system" { - ar.System = msg.ContentString() + ar.System = msg.ContentText() continue } - role := msg.Role - if role == "tool" { - role = "user" // Anthropic handles tool results differently, simplify for now + + // Tool results ride on a user message as a tool_result block so the + // upstream model can correlate them to the prior tool_use. OpenAI + // represents them as role="tool" with tool_call_id. 
+ if msg.Role == "tool" { + ar.Messages = append(ar.Messages, anthropicMsg{ + Role: "user", + Content: []anthropicContent{{ + Type: "tool_result", + ToolUseID: msg.ToolCallID, + Content: msg.ContentText(), + }}, + }) + continue } + ar.Messages = append(ar.Messages, anthropicMsg{ - Role: role, - Content: msg.ContentString(), + Role: msg.Role, + Content: toAnthropicContent(msg), }) } @@ -136,6 +183,58 @@ func translateRequest(req *types.ChatCompletionRequest, model string) *messagesR return ar } +// toAnthropicContent translates an OpenAI message into Anthropic's content +// block array. Text parts become {type:"text"}; image_url parts become +// {type:"image"} with either a base64 or url source depending on the input. +// Per-part cache_control hints forward verbatim so Anthropic prompt-caching +// works without Frugal stripping the marker. +func toAnthropicContent(msg types.Message) []anthropicContent { + parts := msg.ContentParts() + if len(parts) == 0 { + return []anthropicContent{{Type: "text", Text: ""}} + } + out := make([]anthropicContent, 0, len(parts)) + for _, p := range parts { + switch p.Type { + case "", "text": + out = append(out, anthropicContent{Type: "text", Text: p.Text, CacheControl: p.CacheControl}) + case "image_url": + if p.ImageURL == nil { + continue + } + src := imageURLToAnthropicSource(p.ImageURL.URL) + if src == nil { + continue + } + out = append(out, anthropicContent{Type: "image", Source: src, CacheControl: p.CacheControl}) + } + } + if len(out) == 0 { + out = append(out, anthropicContent{Type: "text", Text: ""}) + } + return out +} + +func imageURLToAnthropicSource(url string) *anthropicSource { + const dataPrefix = "data:" + if strings.HasPrefix(url, dataPrefix) { + // data:image/png;base64, + rest := url[len(dataPrefix):] + semi := strings.Index(rest, ";") + comma := strings.Index(rest, ",") + if semi < 0 || comma < 0 || semi > comma { + return nil + } + media := rest[:semi] + data := rest[comma+1:] + return &anthropicSource{Type: 
"base64", MediaType: media, Data: data} + } + if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") { + return &anthropicSource{Type: "url", URL: url} + } + return nil +} + func translateResponse(ar *messagesResponse) *types.ChatCompletionResponse { content := "" for _, block := range ar.Content { @@ -213,8 +312,7 @@ func (p *Provider) ChatCompletion(ctx context.Context, model string, req *types. defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("anthropic error %d: %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("anthropic error %d: %s", resp.StatusCode, readErrorBody(resp.Body)) } var result messagesResponse @@ -249,8 +347,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * if resp.StatusCode != http.StatusOK { defer resp.Body.Close() - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("anthropic error %d: %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("anthropic error %d: %s", resp.StatusCode, readErrorBody(resp.Body)) } ch := make(chan provider.StreamChunk, 8) @@ -259,7 +356,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * defer resp.Body.Close() chunkID := fmt.Sprintf("chatcmpl-%s", ar.Model) - scanner := bufio.NewScanner(resp.Body) + scanner := provider.NewSSEScanner(resp.Body) for scanner.Scan() { line := scanner.Text() diff --git a/internal/provider/anthropic/anthropic_test.go b/internal/provider/anthropic/anthropic_test.go new file mode 100644 index 0000000..bfd90b3 --- /dev/null +++ b/internal/provider/anthropic/anthropic_test.go @@ -0,0 +1,75 @@ +package anthropic + +import ( + "encoding/json" + "testing" + + "github.com/frugalsh/frugal/internal/types" +) + +func TestTranslateRequest_MultimodalImage_ProducesBase64Block(t *testing.T) { + req := &types.ChatCompletionRequest{ + Messages: []types.Message{ + {Role: "user", Content: 
json.RawMessage(`[ + {"type":"text","text":"describe"}, + {"type":"image_url","image_url":{"url":"data:image/png;base64,AAAA"}} + ]`)}, + }, + } + + ar := translateRequest(req, "claude-sonnet-4-20250514") + + if len(ar.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(ar.Messages)) + } + blocks := ar.Messages[0].Content + if len(blocks) != 2 { + t.Fatalf("expected 2 content blocks, got %d: %+v", len(blocks), blocks) + } + if blocks[0].Type != "text" || blocks[0].Text != "describe" { + t.Fatalf("block 0 wrong: %+v", blocks[0]) + } + if blocks[1].Type != "image" || blocks[1].Source == nil { + t.Fatalf("block 1 missing image source: %+v", blocks[1]) + } + src := blocks[1].Source + if src.Type != "base64" || src.MediaType != "image/png" || src.Data != "AAAA" { + t.Fatalf("base64 source mis-translated: %+v", src) + } +} + +func TestTranslateRequest_MultimodalImage_RemoteURL(t *testing.T) { + req := &types.ChatCompletionRequest{ + Messages: []types.Message{ + {Role: "user", Content: json.RawMessage(`[{"type":"image_url","image_url":{"url":"https://example.com/i.png"}}]`)}, + }, + } + + ar := translateRequest(req, "claude-sonnet-4-20250514") + + if len(ar.Messages) != 1 || len(ar.Messages[0].Content) != 1 { + t.Fatalf("expected single image block, got %+v", ar.Messages) + } + src := ar.Messages[0].Content[0].Source + if src == nil || src.Type != "url" || src.URL != "https://example.com/i.png" { + t.Fatalf("url source mis-translated: %+v", src) + } +} + +func TestTranslateRequest_PlainString_StillProducesTextBlock(t *testing.T) { + req := &types.ChatCompletionRequest{ + Messages: []types.Message{ + {Role: "user", Content: json.RawMessage(`"hello"`)}, + }, + } + + ar := translateRequest(req, "claude-sonnet-4-20250514") + + if len(ar.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(ar.Messages)) + } + blocks := ar.Messages[0].Content + if len(blocks) != 1 || blocks[0].Type != "text" || blocks[0].Text != "hello" { + t.Fatalf("string content did not 
round-trip as text block: %+v", blocks) + } +} diff --git a/internal/provider/anthropic/error_body_test.go b/internal/provider/anthropic/error_body_test.go new file mode 100644 index 0000000..3bec7ca --- /dev/null +++ b/internal/provider/anthropic/error_body_test.go @@ -0,0 +1,26 @@ +package anthropic + +import ( + "strings" + "testing" +) + +func TestReadErrorBody_TruncatesLargeBody(t *testing.T) { + in := strings.Repeat("a", errorBodyLimit+32) + out := readErrorBody(strings.NewReader(in)) + + if !strings.HasSuffix(out, "... (truncated)") { + t.Fatalf("expected truncation suffix") + } + if len(out) <= errorBodyLimit { + t.Fatalf("expected output to include suffix beyond limit, got len=%d", len(out)) + } +} + +func TestReadErrorBody_SmallBodyUnchanged(t *testing.T) { + in := "provider unavailable" + out := readErrorBody(strings.NewReader(in)) + if out != in { + t.Fatalf("expected %q, got %q", in, out) + } +} diff --git a/internal/provider/google/error_body_test.go b/internal/provider/google/error_body_test.go new file mode 100644 index 0000000..a1d9d61 --- /dev/null +++ b/internal/provider/google/error_body_test.go @@ -0,0 +1,26 @@ +package google + +import ( + "strings" + "testing" +) + +func TestReadErrorBody_TruncatesLargeBody(t *testing.T) { + in := strings.Repeat("a", errorBodyLimit+32) + out := readErrorBody(strings.NewReader(in)) + + if !strings.HasSuffix(out, "... 
(truncated)") { + t.Fatalf("expected truncation suffix") + } + if len(out) <= errorBodyLimit { + t.Fatalf("expected output to include suffix beyond limit, got len=%d", len(out)) + } +} + +func TestReadErrorBody_SmallBodyUnchanged(t *testing.T) { + in := "provider unavailable" + out := readErrorBody(strings.NewReader(in)) + if out != in { + t.Fatalf("expected %q, got %q", in, out) + } +} diff --git a/internal/provider/google/google.go b/internal/provider/google/google.go index 887c1b2..c6335a6 100644 --- a/internal/provider/google/google.go +++ b/internal/provider/google/google.go @@ -1,7 +1,6 @@ package google import ( - "bufio" "bytes" "context" "encoding/json" @@ -15,6 +14,19 @@ import ( "github.com/frugalsh/frugal/internal/types" ) +const errorBodyLimit = 8 << 10 // 8 KiB + +func readErrorBody(r io.Reader) string { + body, err := io.ReadAll(io.LimitReader(r, errorBodyLimit+1)) + if err != nil { + return "" + } + if len(body) > errorBodyLimit { + return string(body[:errorBodyLimit]) + "... (truncated)" + } + return string(body) +} + type Provider struct { apiKey string baseURL string @@ -27,7 +39,7 @@ func New(apiKey, baseURL string, models []string) *Provider { apiKey: apiKey, baseURL: baseURL, models: models, - client: &http.Client{}, + client: provider.NewHTTPClient(), } } @@ -38,9 +50,24 @@ func (p *Provider) Models() []string { return p.models } // -- Gemini API types -- type generateContentRequest struct { - Contents []geminiContent `json:"contents"` - SystemInstruction *geminiContent `json:"systemInstruction,omitempty"` - GenerationConfig *generationConfig `json:"generationConfig,omitempty"` + Contents []geminiContent `json:"contents"` + SystemInstruction *geminiContent `json:"systemInstruction,omitempty"` + GenerationConfig *generationConfig `json:"generationConfig,omitempty"` + Tools []geminiToolDecl `json:"tools,omitempty"` + // CachedContent points at a Gemini cached content resource. 
Forwarded + // verbatim when the client supplies `frugal_cached_content` in their + // Metadata so Gemini context caching works without Frugal stripping it. + CachedContent string `json:"cachedContent,omitempty"` +} + +type geminiToolDecl struct { + FunctionDeclarations []geminiFuncDecl `json:"functionDeclarations"` +} + +type geminiFuncDecl struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` } type geminiContent struct { @@ -49,15 +76,33 @@ type geminiContent struct { } type geminiPart struct { - Text string `json:"text,omitempty"` + Text string `json:"text,omitempty"` + InlineData *geminiInlineData `json:"inlineData,omitempty"` + FileData *geminiFileData `json:"fileData,omitempty"` + FunctionCall *geminiFuncCall `json:"functionCall,omitempty"` +} + +type geminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type geminiFileData struct { + MimeType string `json:"mimeType,omitempty"` + FileURI string `json:"fileUri"` +} + +type geminiFuncCall struct { + Name string `json:"name"` + Args json.RawMessage `json:"args,omitempty"` } type generationConfig struct { - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"topP,omitempty"` - MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` - ResponseMimeType string `json:"responseMimeType,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` + ResponseMimeType string `json:"responseMimeType,omitempty"` } type generateContentResponse struct { @@ -84,7 +129,7 @@ func translateRequest(req *types.ChatCompletionRequest) *generateContentRequest for _, msg := range req.Messages { if msg.Role == "system" { gr.SystemInstruction = &geminiContent{ - 
Parts: []geminiPart{{Text: msg.ContentString()}}, + Parts: toGeminiParts(msg), } continue } @@ -95,10 +140,18 @@ func translateRequest(req *types.ChatCompletionRequest) *generateContentRequest } gr.Contents = append(gr.Contents, geminiContent{ Role: role, - Parts: []geminiPart{{Text: msg.ContentString()}}, + Parts: toGeminiParts(msg), }) } + if tools := toGeminiTools(req.Tools); len(tools) > 0 { + gr.Tools = tools + } + + if cached := cachedContentFromMetadata(req.Metadata); cached != "" { + gr.CachedContent = cached + } + gc := &generationConfig{} hasConfig := false if req.Temperature != nil { @@ -112,6 +165,9 @@ func translateRequest(req *types.ChatCompletionRequest) *generateContentRequest if req.MaxTokens != nil { gc.MaxOutputTokens = req.MaxTokens hasConfig = true + } else if req.MaxCompletionTokens != nil { + gc.MaxOutputTokens = req.MaxCompletionTokens + hasConfig = true } if req.ResponseFormat != nil && req.ResponseFormat.Type == "json_object" { gc.ResponseMimeType = "application/json" @@ -124,6 +180,93 @@ func translateRequest(req *types.ChatCompletionRequest) *generateContentRequest return gr } +// toGeminiParts converts OpenAI content parts to Gemini parts. Text parts map +// to {text}; image_url with a data URL maps to {inlineData}; image_url with a +// remote URL maps to {fileData}. Empty content produces a single empty-text +// part so the request remains valid. 
+func toGeminiParts(msg types.Message) []geminiPart { + parts := msg.ContentParts() + if len(parts) == 0 { + return []geminiPart{{Text: ""}} + } + out := make([]geminiPart, 0, len(parts)) + for _, p := range parts { + switch p.Type { + case "", "text": + out = append(out, geminiPart{Text: p.Text}) + case "image_url": + if p.ImageURL == nil { + continue + } + if strings.HasPrefix(p.ImageURL.URL, "data:") { + if part := dataURLToInlinePart(p.ImageURL.URL); part != nil { + out = append(out, *part) + } + continue + } + if strings.HasPrefix(p.ImageURL.URL, "http://") || strings.HasPrefix(p.ImageURL.URL, "https://") { + out = append(out, geminiPart{FileData: &geminiFileData{FileURI: p.ImageURL.URL}}) + } + } + } + if len(out) == 0 { + return []geminiPart{{Text: ""}} + } + return out +} + +// cachedContentFromMetadata reads a Gemini cached-content resource name from +// the request's OpenAI-style metadata field. Clients opt in with +// `metadata: {"frugal_cached_content": "projects/.../cachedContents/..."}`; +// absent or non-string values skip the field entirely. +func cachedContentFromMetadata(raw []byte) string { + if len(raw) == 0 { + return "" + } + var m map[string]string + if err := json.Unmarshal(raw, &m); err != nil { + return "" + } + return m["frugal_cached_content"] +} + +// toGeminiTools maps OpenAI tool declarations to a single Gemini tool entry +// containing all function declarations. Gemini supports exactly one tools[] +// element carrying many functionDeclarations. 
+func toGeminiTools(tools []types.Tool) []geminiToolDecl { + if len(tools) == 0 { + return nil + } + decls := make([]geminiFuncDecl, 0, len(tools)) + for _, t := range tools { + if t.Type != "" && t.Type != "function" { + continue + } + decls = append(decls, geminiFuncDecl{ + Name: t.Function.Name, + Description: t.Function.Description, + Parameters: t.Function.Parameters, + }) + } + if len(decls) == 0 { + return nil + } + return []geminiToolDecl{{FunctionDeclarations: decls}} +} + +func dataURLToInlinePart(url string) *geminiPart { + rest := url[len("data:"):] + semi := strings.Index(rest, ";") + comma := strings.Index(rest, ",") + if semi < 0 || comma < 0 || semi > comma { + return nil + } + return &geminiPart{InlineData: &geminiInlineData{ + MimeType: rest[:semi], + Data: rest[comma+1:], + }} +} + func translateResponse(gr *generateContentResponse, model string) *types.ChatCompletionResponse { resp := &types.ChatCompletionResponse{ ID: fmt.Sprintf("chatcmpl-gemini-%d", time.Now().UnixNano()), @@ -134,15 +277,36 @@ func translateResponse(gr *generateContentResponse, model string) *types.ChatCom for i, cand := range gr.Candidates { content := "" + var toolCalls []types.ToolCall for _, part := range cand.Content.Parts { - content += part.Text + if part.Text != "" { + content += part.Text + } + if part.FunctionCall != nil { + args := string(part.FunctionCall.Args) + if args == "" { + args = "{}" + } + toolCalls = append(toolCalls, types.ToolCall{ + ID: fmt.Sprintf("call_%s_%d_%d", part.FunctionCall.Name, i, len(toolCalls)), + Type: "function", + Function: types.ToolCallFunction{ + Name: part.FunctionCall.Name, + Arguments: args, + }, + }) + } } finishReason := mapFinishReason(cand.FinishReason) + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } resp.Choices = append(resp.Choices, types.Choice{ Index: i, Message: types.Message{ - Role: "assistant", - Content: mustMarshal(content), + Role: "assistant", + Content: mustMarshal(content), + ToolCalls: toolCalls, 
}, FinishReason: &finishReason, }) @@ -201,8 +365,7 @@ func (p *Provider) ChatCompletion(ctx context.Context, model string, req *types. defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("gemini error %d: %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("gemini error %d: %s", resp.StatusCode, readErrorBody(resp.Body)) } var result generateContentResponse @@ -235,8 +398,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * if resp.StatusCode != http.StatusOK { defer resp.Body.Close() - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("gemini error %d: %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("gemini error %d: %s", resp.StatusCode, readErrorBody(resp.Body)) } ch := make(chan provider.StreamChunk, 8) @@ -244,7 +406,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * defer close(ch) defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) + scanner := provider.NewSSEScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { @@ -259,8 +421,26 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * for _, cand := range gr.Candidates { text := "" + var toolDeltas []types.ToolCallDelta for _, part := range cand.Content.Parts { - text += part.Text + if part.Text != "" { + text += part.Text + } + if part.FunctionCall != nil { + args := string(part.FunctionCall.Args) + if args == "" { + args = "{}" + } + toolDeltas = append(toolDeltas, types.ToolCallDelta{ + Index: len(toolDeltas), + ID: fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, len(toolDeltas)), + Type: "function", + Function: &types.ToolCallFunction{ + Name: part.FunctionCall.Name, + Arguments: args, + }, + }) + } } ch <- provider.StreamChunk{ Data: &types.ChatCompletionChunk{ @@ -271,7 +451,10 @@ func (p *Provider) ChatCompletionStream(ctx 
context.Context, model string, req * Choices: []types.ChunkChoice{ { Index: 0, - Delta: types.MessageDelta{Content: text}, + Delta: types.MessageDelta{ + Content: text, + ToolCalls: toolDeltas, + }, }, }, }, diff --git a/internal/provider/google/google_test.go b/internal/provider/google/google_test.go new file mode 100644 index 0000000..77b744f --- /dev/null +++ b/internal/provider/google/google_test.go @@ -0,0 +1,101 @@ +package google + +import ( + "encoding/json" + "testing" + + "github.com/frugalsh/frugal/internal/types" +) + +func TestTranslateRequest_PassesToolsThrough(t *testing.T) { + params := json.RawMessage(`{"type":"object","properties":{"q":{"type":"string"}}}`) + req := &types.ChatCompletionRequest{ + Model: "gemini-2.5-pro", + Messages: []types.Message{ + {Role: "user", Content: json.RawMessage(`"search the web"`)}, + }, + Tools: []types.Tool{ + {Type: "function", Function: types.ToolFunction{ + Name: "web_search", Description: "search", Parameters: params, + }}, + }, + } + + gr := translateRequest(req) + + if len(gr.Tools) != 1 { + t.Fatalf("expected 1 Gemini tool entry, got %d", len(gr.Tools)) + } + decls := gr.Tools[0].FunctionDeclarations + if len(decls) != 1 || decls[0].Name != "web_search" { + t.Fatalf("expected web_search declaration, got %+v", decls) + } + if string(decls[0].Parameters) != string(params) { + t.Fatalf("parameters not preserved: got %s", string(decls[0].Parameters)) + } +} + +func TestTranslateRequest_MultimodalImage_ProducesInlineData(t *testing.T) { + req := &types.ChatCompletionRequest{ + Messages: []types.Message{ + {Role: "user", Content: json.RawMessage(`[ + {"type":"text","text":"what is this"}, + {"type":"image_url","image_url":{"url":"data:image/png;base64,AAAA"}} + ]`)}, + }, + } + + gr := translateRequest(req) + + if len(gr.Contents) != 1 { + t.Fatalf("expected 1 content block, got %d", len(gr.Contents)) + } + parts := gr.Contents[0].Parts + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %+v", len(parts), 
parts) + } + if parts[0].Text != "what is this" { + t.Fatalf("first part text = %q", parts[0].Text) + } + if parts[1].InlineData == nil { + t.Fatalf("second part missing inlineData: %+v", parts[1]) + } + if parts[1].InlineData.MimeType != "image/png" || parts[1].InlineData.Data != "AAAA" { + t.Fatalf("inlineData mis-translated: %+v", parts[1].InlineData) + } +} + +func TestTranslateResponse_FunctionCall_MapsToToolCalls(t *testing.T) { + gr := &generateContentResponse{ + Candidates: []candidate{ + { + Content: geminiContent{ + Parts: []geminiPart{ + {FunctionCall: &geminiFuncCall{ + Name: "web_search", + Args: json.RawMessage(`{"q":"go"}`), + }}, + }, + }, + FinishReason: "STOP", + }, + }, + } + + resp := translateResponse(gr, "gemini-2.5-pro") + + if len(resp.Choices) != 1 { + t.Fatalf("expected 1 choice, got %d", len(resp.Choices)) + } + choice := resp.Choices[0] + if choice.FinishReason == nil || *choice.FinishReason != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, got %v", choice.FinishReason) + } + if len(choice.Message.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %+v", choice.Message.ToolCalls) + } + tc := choice.Message.ToolCalls[0] + if tc.Function.Name != "web_search" || tc.Function.Arguments != `{"q":"go"}` { + t.Fatalf("tool call mis-translated: %+v", tc) + } +} diff --git a/internal/provider/httpclient.go b/internal/provider/httpclient.go new file mode 100644 index 0000000..5f290f0 --- /dev/null +++ b/internal/provider/httpclient.go @@ -0,0 +1,37 @@ +package provider + +import ( + "net" + "net/http" + "time" +) + +const ( + defaultDialTimeout = 10 * time.Second + defaultKeepAlive = 30 * time.Second + defaultTLSHandshakeTimeout = 10 * time.Second + defaultResponseHeaderTimeout = 30 * time.Second + defaultExpectContinueTimeout = 1 * time.Second + defaultIdleConnTimeout = 90 * time.Second +) + +// NewHTTPClient returns an HTTP client with defensive network timeouts suitable +// for outbound provider API calls. 
It intentionally avoids Client.Timeout so +// streaming responses can remain open while still bounding connection/setup. +func NewHTTPClient() *http.Client { + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: defaultDialTimeout, + KeepAlive: defaultKeepAlive, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: defaultIdleConnTimeout, + TLSHandshakeTimeout: defaultTLSHandshakeTimeout, + ExpectContinueTimeout: defaultExpectContinueTimeout, + ResponseHeaderTimeout: defaultResponseHeaderTimeout, + } + + return &http.Client{Transport: transport} +} diff --git a/internal/provider/httpclient_test.go b/internal/provider/httpclient_test.go new file mode 100644 index 0000000..46729d2 --- /dev/null +++ b/internal/provider/httpclient_test.go @@ -0,0 +1,28 @@ +package provider + +import ( + "net/http" + "testing" +) + +func TestNewHTTPClient_ConfiguresDefensiveTransport(t *testing.T) { + client := NewHTTPClient() + if client == nil { + t.Fatal("expected non-nil client") + } + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("expected *http.Transport, got %T", client.Transport) + } + + if transport.ResponseHeaderTimeout != defaultResponseHeaderTimeout { + t.Fatalf("expected ResponseHeaderTimeout %s, got %s", defaultResponseHeaderTimeout, transport.ResponseHeaderTimeout) + } + if transport.TLSHandshakeTimeout != defaultTLSHandshakeTimeout { + t.Fatalf("expected TLSHandshakeTimeout %s, got %s", defaultTLSHandshakeTimeout, transport.TLSHandshakeTimeout) + } + if transport.IdleConnTimeout != defaultIdleConnTimeout { + t.Fatalf("expected IdleConnTimeout %s, got %s", defaultIdleConnTimeout, transport.IdleConnTimeout) + } +} diff --git a/internal/provider/openai/error_body_test.go b/internal/provider/openai/error_body_test.go new file mode 100644 index 0000000..1ec93c6 --- /dev/null +++ b/internal/provider/openai/error_body_test.go @@ -0,0 +1,26 @@ +package openai + 
+import ( + "strings" + "testing" +) + +func TestReadErrorBody_TruncatesLargeBody(t *testing.T) { + in := strings.Repeat("a", errorBodyLimit+32) + out := readErrorBody(strings.NewReader(in)) + + if !strings.HasSuffix(out, "... (truncated)") { + t.Fatalf("expected truncation suffix, got %q", out[len(out)-20:]) + } + if len(out) <= errorBodyLimit { + t.Fatalf("expected output to include suffix beyond limit, got len=%d", len(out)) + } +} + +func TestReadErrorBody_SmallBodyUnchanged(t *testing.T) { + in := "provider unavailable" + out := readErrorBody(strings.NewReader(in)) + if out != in { + t.Fatalf("expected %q, got %q", in, out) + } +} diff --git a/internal/provider/openai/openai.go b/internal/provider/openai/openai.go index 86e207b..7dd429c 100644 --- a/internal/provider/openai/openai.go +++ b/internal/provider/openai/openai.go @@ -1,7 +1,6 @@ package openai import ( - "bufio" "bytes" "context" "encoding/json" @@ -14,6 +13,19 @@ import ( "github.com/frugalsh/frugal/internal/types" ) +const errorBodyLimit = 8 << 10 // 8 KiB + +func readErrorBody(r io.Reader) string { + body, err := io.ReadAll(io.LimitReader(r, errorBodyLimit+1)) + if err != nil { + return "" + } + if len(body) > errorBodyLimit { + return string(body[:errorBodyLimit]) + "... (truncated)" + } + return string(body) +} + type Provider struct { apiKey string baseURL string @@ -26,7 +38,7 @@ func New(apiKey, baseURL string, models []string) *Provider { apiKey: apiKey, baseURL: baseURL, models: models, - client: &http.Client{}, + client: provider.NewHTTPClient(), } } @@ -58,8 +70,7 @@ func (p *Provider) ChatCompletion(ctx context.Context, model string, req *types. 
defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("openai error %d: %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("openai error %d: %s", resp.StatusCode, readErrorBody(resp.Body)) } var result types.ChatCompletionResponse @@ -94,8 +105,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * if resp.StatusCode != http.StatusOK { defer resp.Body.Close() - respBody, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("openai error %d: %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("openai error %d: %s", resp.StatusCode, readErrorBody(resp.Body)) } ch := make(chan provider.StreamChunk, 8) @@ -103,7 +113,7 @@ func (p *Provider) ChatCompletionStream(ctx context.Context, model string, req * defer close(ch) defer resp.Body.Close() - scanner := bufio.NewScanner(resp.Body) + scanner := provider.NewSSEScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if !strings.HasPrefix(line, "data: ") { diff --git a/internal/provider/openai/openai_stream_test.go b/internal/provider/openai/openai_stream_test.go new file mode 100644 index 0000000..d8da943 --- /dev/null +++ b/internal/provider/openai/openai_stream_test.go @@ -0,0 +1,52 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/frugalsh/frugal/internal/types" +) + +func TestChatCompletionStream_AllowsLargeSSELines(t *testing.T) { + large := strings.Repeat("x", 70*1024) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "data: {\"id\":\"chatcmpl-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o-mini\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"%s\"}}]}\n\n", large) + fmt.Fprint(w, "data: [DONE]\n\n") + })) + defer ts.Close() + + p := New("test-key", ts.URL, 
[]string{"gpt-4o-mini"}) + p.client = ts.Client() + + ch, err := p.ChatCompletionStream(context.Background(), "gpt-4o-mini", &types.ChatCompletionRequest{}) + if err != nil { + t.Fatalf("ChatCompletionStream returned error: %v", err) + } + + var gotData, gotDone bool + for chunk := range ch { + if chunk.Err != nil { + t.Fatalf("stream chunk error: %v", chunk.Err) + } + if chunk.Data != nil { + gotData = true + } + if chunk.Done { + gotDone = true + } + } + + if !gotData { + t.Fatal("expected at least one data chunk") + } + if !gotDone { + t.Fatal("expected done chunk") + } +} + diff --git a/internal/provider/retry.go b/internal/provider/retry.go new file mode 100644 index 0000000..e99d447 --- /dev/null +++ b/internal/provider/retry.go @@ -0,0 +1,106 @@ +package provider + +import ( + "context" + "errors" + "regexp" + "strconv" + "strings" + "time" + + "github.com/frugalsh/frugal/internal/obs" + "github.com/frugalsh/frugal/internal/types" +) + +// retryBackoff is capped and intentionally tight: Frugal proxies user-facing +// latency-sensitive calls, so we'd rather surface a failure quickly than +// stretch a bad upstream window. Streaming is never retried past the first +// handshake — once bytes flow, the router owns fallback. +var retryBackoff = []time.Duration{50 * time.Millisecond, 200 * time.Millisecond, 800 * time.Millisecond} + +// WithRetry wraps a Provider so non-streaming ChatCompletion calls retry +// on transient upstream failures (429 / 502 / 503 / 504), honoring +// Retry-After when the provider includes it in the error string. 
+func WithRetry(p Provider) Provider { + return &retryingProvider{inner: p} +} + +type retryingProvider struct{ inner Provider } + +func (r *retryingProvider) Name() string { return r.inner.Name() } +func (r *retryingProvider) Models() []string { return r.inner.Models() } + +func (r *retryingProvider) ChatCompletion(ctx context.Context, model string, req *types.ChatCompletionRequest) (*types.ChatCompletionResponse, error) { + var lastErr error + for attempt := 0; attempt <= len(retryBackoff); attempt++ { + resp, err := r.inner.ChatCompletion(ctx, model, req) + if err == nil { + return resp, nil + } + lastErr = err + if !isRetryable(err) || attempt == len(retryBackoff) { + return nil, err + } + delay := retryBackoff[attempt] + if hint := parseRetryAfter(err); hint > 0 && hint < 30*time.Second { + delay = hint + } + obs.L(ctx).Warn("upstream retry", + "provider", r.inner.Name(), + "model", model, + "attempt", attempt+1, + "delay_ms", delay.Milliseconds(), + "err", err, + ) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + } + return nil, lastErr +} + +func (r *retryingProvider) ChatCompletionStream(ctx context.Context, model string, req *types.ChatCompletionRequest) (<-chan StreamChunk, error) { + // Streams are not retried here. Fallback chain in the proxy handler owns + // the handshake-level retry; once a chunk has been written to the client, + // retry is no longer safe. + return r.inner.ChatCompletionStream(ctx, model, req) +} + +// isRetryable classifies provider errors. Matches on the error string because +// every provider.ChatCompletion error is formatted as ` error : +// ` today; parsing the numeric status keeps us from coupling to the +// concrete error types. 
+func isRetryable(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + msg := strings.ToLower(err.Error()) + for _, needle := range []string{" 429", " 502", " 503", " 504", "rate limit", "temporarily unavailable"} { + if strings.Contains(msg, needle) { + return true + } + } + return false +} + +var retryAfterRe = regexp.MustCompile(`retry[- ]after[: ]+(\d+)`) + +func parseRetryAfter(err error) time.Duration { + if err == nil { + return 0 + } + m := retryAfterRe.FindStringSubmatch(strings.ToLower(err.Error())) + if len(m) != 2 { + return 0 + } + secs, perr := strconv.Atoi(m[1]) + if perr != nil || secs <= 0 { + return 0 + } + return time.Duration(secs) * time.Second +} diff --git a/internal/provider/retry_test.go b/internal/provider/retry_test.go new file mode 100644 index 0000000..2b98a7a --- /dev/null +++ b/internal/provider/retry_test.go @@ -0,0 +1,102 @@ +package provider + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/frugalsh/frugal/internal/types" +) + +type mockInner struct { + calls atomic.Int32 + errs []error + response *types.ChatCompletionResponse +} + +func (m *mockInner) Name() string { return "mock" } +func (m *mockInner) Models() []string { return nil } + +func (m *mockInner) ChatCompletion(ctx context.Context, model string, req *types.ChatCompletionRequest) (*types.ChatCompletionResponse, error) { + n := int(m.calls.Add(1)) - 1 + if n < len(m.errs) && m.errs[n] != nil { + return nil, m.errs[n] + } + return m.response, nil +} + +func (m *mockInner) ChatCompletionStream(ctx context.Context, model string, req *types.ChatCompletionRequest) (<-chan StreamChunk, error) { + return nil, errors.New("stream not tested") +} + +func TestWithRetry_RetriesOn503UntilSuccess(t *testing.T) { + // Shrink backoff so the test completes instantly. 
+ orig := retryBackoff + retryBackoff = []time.Duration{time.Millisecond, time.Millisecond, time.Millisecond} + defer func() { retryBackoff = orig }() + + inner := &mockInner{ + errs: []error{errors.New("openai error 503: unavailable"), errors.New("openai error 503: unavailable")}, + response: &types.ChatCompletionResponse{ID: "ok"}, + } + p := WithRetry(inner) + + resp, err := p.ChatCompletion(context.Background(), "m", &types.ChatCompletionRequest{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.ID != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + if got := inner.calls.Load(); got != 3 { + t.Fatalf("expected 3 calls (2 retries + success), got %d", got) + } +} + +func TestWithRetry_GivesUpAfterMaxAttempts(t *testing.T) { + orig := retryBackoff + retryBackoff = []time.Duration{time.Millisecond, time.Millisecond, time.Millisecond} + defer func() { retryBackoff = orig }() + + inner := &mockInner{ + errs: []error{ + errors.New("openai error 429: rate limit"), + errors.New("openai error 429: rate limit"), + errors.New("openai error 429: rate limit"), + errors.New("openai error 429: rate limit"), + }, + } + p := WithRetry(inner) + + _, err := p.ChatCompletion(context.Background(), "m", &types.ChatCompletionRequest{}) + if err == nil { + t.Fatalf("expected error after giving up") + } + if got := inner.calls.Load(); got != 4 { + t.Fatalf("expected 4 total attempts (initial + 3 retries), got %d", got) + } +} + +func TestWithRetry_DoesNotRetryNonTransient(t *testing.T) { + inner := &mockInner{ + errs: []error{errors.New("openai error 400: bad request")}, + } + p := WithRetry(inner) + + _, err := p.ChatCompletion(context.Background(), "m", &types.ChatCompletionRequest{}) + if err == nil { + t.Fatalf("expected error") + } + if got := inner.calls.Load(); got != 1 { + t.Fatalf("expected exactly 1 call for non-retryable error, got %d", got) + } +} + +func TestParseRetryAfter_ExtractsSeconds(t *testing.T) { + err := 
errors.New("openai error 429: rate limit exceeded, retry-after: 5") + if got := parseRetryAfter(err); got != 5*time.Second { + t.Fatalf("parseRetryAfter = %s, want 5s", got) + } +} diff --git a/internal/provider/sse.go b/internal/provider/sse.go new file mode 100644 index 0000000..5df7803 --- /dev/null +++ b/internal/provider/sse.go @@ -0,0 +1,18 @@ +package provider + +import ( + "bufio" + "io" +) + +const maxSSELineBytes = 1024 * 1024 // 1 MiB + +// NewSSEScanner returns a scanner configured for larger-than-default SSE lines. +// Provider APIs can emit large JSON chunks that exceed bufio.Scanner's 64 KiB default. +func NewSSEScanner(r io.Reader) *bufio.Scanner { + s := bufio.NewScanner(r) + buf := make([]byte, 64*1024) + s.Buffer(buf, maxSSELineBytes) + return s +} + diff --git a/internal/provider/sse_test.go b/internal/provider/sse_test.go new file mode 100644 index 0000000..04c9d2e --- /dev/null +++ b/internal/provider/sse_test.go @@ -0,0 +1,33 @@ +package provider + +import ( + "bytes" + "strings" + "testing" +) + +func TestNewSSEScanner_AllowsLargeLineBeyondDefaultScannerLimit(t *testing.T) { + payload := "data: " + strings.Repeat("a", 100*1024) + "\n\n" + s := NewSSEScanner(bytes.NewBufferString(payload)) + + if !s.Scan() { + t.Fatalf("expected scanner to read large SSE line, err=%v", s.Err()) + } + + if got := s.Text(); got != strings.TrimSuffix(payload, "\n\n") { + t.Fatalf("unexpected scanned text length: got=%d", len(got)) + } +} + +func TestNewSSEScanner_ErrorsWhenLineExceedsConfiguredMax(t *testing.T) { + payload := "data: " + strings.Repeat("b", maxSSELineBytes+1) + "\n\n" + s := NewSSEScanner(bytes.NewBufferString(payload)) + + if s.Scan() { + t.Fatal("expected scan to fail for oversized SSE line") + } + + if err := s.Err(); err == nil { + t.Fatal("expected scanner error for oversized SSE line") + } +} diff --git a/internal/proxy/auth_test.go b/internal/proxy/auth_test.go new file mode 100644 index 0000000..90704aa --- /dev/null +++ 
b/internal/proxy/auth_test.go @@ -0,0 +1,114 @@ +package proxy + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newTestOKHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`ok`)) + }) +} + +func TestAuthMiddleware_NoTokenIsNoOp(t *testing.T) { + h := AuthMiddleware("")(newTestOKHandler()) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 with empty token, got %d", rec.Code) + } +} + +func TestAuthMiddleware_RejectsMissingHeader(t *testing.T) { + h := AuthMiddleware("secret-token")(newTestOKHandler()) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + h.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 without header, got %d", rec.Code) + } + if got := rec.Header().Get("WWW-Authenticate"); got == "" { + t.Fatalf("expected WWW-Authenticate challenge, got empty") + } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("expected JSON error body, got %q", rec.Body.String()) + } +} + +func TestAuthMiddleware_RejectsWrongToken(t *testing.T) { + h := AuthMiddleware("secret-token")(newTestOKHandler()) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + req.Header.Set("Authorization", "Bearer nope") + h.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 with wrong token, got %d", rec.Code) + } +} + +func TestAuthMiddleware_AcceptsCorrectToken(t *testing.T) { + h := AuthMiddleware("secret-token")(newTestOKHandler()) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + req.Header.Set("Authorization", 
"Bearer secret-token") + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 with correct token, got %d", rec.Code) + } +} + +func TestAuthMiddleware_CaseInsensitiveBearerPrefix(t *testing.T) { + h := AuthMiddleware("secret")(newTestOKHandler()) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + req.Header.Set("Authorization", "bearer secret") + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 with lowercase bearer, got %d", rec.Code) + } +} + +func TestRateLimitMiddleware_TrivialRpsDisables(t *testing.T) { + h := RateLimitMiddleware(0, 0)(newTestOKHandler()) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 with disabled limiter, got %d", rec.Code) + } +} + +func TestRateLimitMiddleware_RejectsOverBurst(t *testing.T) { + h := RateLimitMiddleware(1, 1)(newTestOKHandler()) + + // First request consumes the single burst token. + { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + h.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("first request: expected 200, got %d", rec.Code) + } + } + + // Second request in rapid succession is rejected. 
+ rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + h.ServeHTTP(rec, req) + if rec.Code != http.StatusTooManyRequests { + t.Fatalf("second request: expected 429, got %d", rec.Code) + } + if !strings.Contains(rec.Body.String(), "rate_limited") { + t.Fatalf("expected rate_limited code in body, got %q", rec.Body.String()) + } + if rec.Header().Get("Retry-After") == "" { + t.Fatalf("expected Retry-After header") + } +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 488df0a..db28831 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -1,65 +1,201 @@ package proxy import ( + "context" "encoding/json" + "errors" "fmt" - "log" + "io" "net/http" + "os" + "strconv" + "strings" "sync" "time" + "github.com/go-chi/chi/v5" + "github.com/frugalsh/frugal/internal/classifier" + "github.com/frugalsh/frugal/internal/metrics" + "github.com/frugalsh/frugal/internal/obs" "github.com/frugalsh/frugal/internal/provider" "github.com/frugalsh/frugal/internal/router" "github.com/frugalsh/frugal/internal/types" + "github.com/frugalsh/frugal/internal/usecase" +) + +const ( + maxFallbackAttempts = 3 + defaultMaxCostPerRequestUSD = 1.0 ) +// maxCostPerRequestUSD reads the per-request spend cap once per process. +// A non-positive value disables the cap. +var maxCostPerRequestUSD = func() float64 { + raw := os.Getenv("FRUGAL_MAX_COST_PER_REQUEST_USD") + if raw == "" { + return defaultMaxCostPerRequestUSD + } + v, err := strconv.ParseFloat(raw, 64) + if err != nil || v < 0 { + return defaultMaxCostPerRequestUSD + } + return v +}() + +const defaultDecisionBufferSize = 1000 + // Handler serves the OpenAI-compatible API endpoints. type Handler struct { classifier classifier.Classifier router *router.Router registry *provider.Registry + // useCases is optional — a nil or empty registry disables use-case + // routing and the /v1/bundles endpoints. 
The handler still works + // exactly as before in that mode. + useCases *usecase.Registry - // Ring buffer of recent routing decisions for /v1/routing/explain - mu sync.Mutex - decisions []types.RoutingDecision - decisionIdx int - lastDecision *types.RoutingDecision + // Decision storage: the hot path posts to decisionCh (non-blocking send + // with a drop-on-full policy so a slow /routing/explain consumer never + // back-pressures chat requests). A single background goroutine drains + // the channel into the ring buffer under mu. + decisionCh chan types.RoutingDecision + mu sync.Mutex + decisions []types.RoutingDecision + decisionIdx int + lastDecision *types.RoutingDecision } func NewHandler(cls classifier.Classifier, rtr *router.Router, reg *provider.Registry) *Handler { - return &Handler{ + return NewHandlerWithUseCases(cls, rtr, reg, nil) +} + +// NewHandlerWithUseCases is the same as NewHandler but wires in a +// use-case registry. Passing nil preserves the legacy (chat-routing-only) +// behavior. +func NewHandlerWithUseCases(cls classifier.Classifier, rtr *router.Router, reg *provider.Registry, uc *usecase.Registry) *Handler { + size := envIntOrDefault("FRUGAL_DECISION_BUFFER", defaultDecisionBufferSize) + if size <= 0 { + size = defaultDecisionBufferSize + } + h := &Handler{ classifier: cls, router: rtr, registry: reg, + useCases: uc, + decisionCh: make(chan types.RoutingDecision, size), decisions: make([]types.RoutingDecision, 100), } + go h.drainDecisions() + return h } +// drainDecisions runs for the life of the handler, pumping decisions from the +// hot-path channel into the ring buffer. Runs on a single goroutine so the +// mutex never contends with request handling. 
+func (h *Handler) drainDecisions() { + for d := range h.decisionCh { + h.mu.Lock() + h.decisions[h.decisionIdx%len(h.decisions)] = d + h.decisionIdx++ + last := d + h.lastDecision = &last + h.mu.Unlock() + } +} + +// envIntOrDefault mirrors the CLI helper so the package is self-contained. +// Duplicated here rather than exported because the CLI version logs via slog +// which would introduce a cycle if imported. +func envIntOrDefault(key string, fallback int) int { + if s, ok := lookupEnv(key); ok { + if v, err := strconv.Atoi(s); err == nil && v > 0 { + return v + } + } + return fallback +} + +// lookupEnv is a tiny shim so tests can stub os.Getenv without a global. +var lookupEnv = func(key string) (string, bool) { + v, ok := os.LookupEnv(key) + return v, ok +} + +// allowedFallbackModels returns a set of registered model names for +// allowlisting caller-supplied fallback chains. +func (h *Handler) allowedFallbackModels() map[string]struct{} { + models := h.registry.AllModels() + set := make(map[string]struct{}, len(models)) + for _, m := range models { + set[m] = struct{}{} + } + return set +} + +// recordDecision enqueues d for the background drain. The send is +// non-blocking: a slow drain or a packed channel drops the decision rather +// than stalling the hot path, which is the right trade-off — losing an +// observability point is cheaper than losing request latency. 
func (h *Handler) recordDecision(d types.RoutingDecision) { - h.mu.Lock() - defer h.mu.Unlock() - h.decisions[h.decisionIdx%len(h.decisions)] = d - h.decisionIdx++ - h.lastDecision = &d + select { + case h.decisionCh <- d: + default: + } } +const maxChatCompletionsBodyBytes int64 = 1 << 20 // 1 MiB + // ChatCompletions handles POST /v1/chat/completions func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { - var req types.ChatCompletionRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + req, err := decodeChatCompletionRequest(w, r) + if err != nil { writeError(w, http.StatusBadRequest, "invalid request body: "+err.Error()) return } quality := QualityFromContext(r.Context()) fallbacks := FallbacksFromContext(r.Context()) + useCaseID := UseCaseFromContext(r.Context()) var decision types.RoutingDecision var prov provider.Provider - // Model pinning: if model is not "auto" and not empty, try to resolve directly - if req.Model != "" && req.Model != "auto" { + // Use-case routing: when X-Frugal-Use-Case is set, look up the bundle's + // chat model for the requested quality tier and pin to it. Unknown use + // case → 400 so caller typos surface immediately. Unknown tier or an + // unregistered bundle model → fall through to the classifier rather than + // hard-failing (degrades gracefully if a use case references a model + // whose provider key isn't configured). 
+ if useCaseID != "" { + if h.useCases == nil || h.useCases.Len() == 0 { + writeError(w, http.StatusBadRequest, "X-Frugal-Use-Case set but no use cases are configured on this server") + return + } + if _, ok := h.useCases.Get(useCaseID); !ok { + known := strings.Join(h.useCases.IDs(), ", ") + writeError(w, http.StatusBadRequest, fmt.Sprintf("unknown use case %q; known: %s", useCaseID, known)) + return + } + if bundle, ok := h.useCases.Bundle(useCaseID, string(quality)); ok && bundle.Chat != "" { + if p, err := h.registry.Resolve(bundle.Chat); err == nil { + prov = p + decision = types.RoutingDecision{ + SelectedModel: bundle.Chat, + SelectedProvider: p.Name(), + Quality: string(quality), + Pinned: true, + Reason: fmt.Sprintf("pinned by use case %q at %s tier: %s", + useCaseID, quality, strings.TrimSpace(bundle.Reason)), + } + w.Header().Set("X-Frugal-Use-Case", useCaseID) + } + } + } + + // Model pinning: if model is not "auto" and not empty, try to resolve directly. + // Skipped if use-case routing already resolved a provider. + if prov == nil && req.Model != "" && req.Model != "auto" { if p, err := h.registry.Resolve(req.Model); err == nil { prov = p decision = types.RoutingDecision{ @@ -74,7 +210,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { // Route via classifier if not pinned if prov == nil { - features := h.classifier.Classify(&req) + features := h.classifier.Classify(req) decision = h.router.Route(features, quality, fallbacks) if decision.SelectedModel == "" { @@ -90,24 +226,87 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { } } + // Per-request cost cap. Skip when the caller pinned a model (they know + // what they asked for) and when the cap is disabled. The router already + // records decision.EstimatedCost, so this is effectively free. 
+ if !decision.Pinned && maxCostPerRequestUSD > 0 && decision.EstimatedCost > maxCostPerRequestUSD { + obs.L(r.Context()).Warn("rejecting request over cost cap", + "estimated_cost_usd", decision.EstimatedCost, + "cap_usd", maxCostPerRequestUSD, + "model", decision.SelectedModel, + ) + writeError(w, http.StatusPaymentRequired, "estimated request cost exceeds configured cap") + return + } + h.recordDecision(decision) // Add routing info header w.Header().Set("X-Frugal-Model", decision.SelectedModel) w.Header().Set("X-Frugal-Provider", decision.SelectedProvider) + if decision.RelaxedFrom != "" { + w.Header().Set("X-Frugal-Relaxed-From", decision.RelaxedFrom) + metrics.RoutingRelaxedTotal.WithLabelValues(decision.RelaxedFrom, decision.Quality).Inc() + } + start := time.Now() + streamLabel := "nonstream" + if req.Stream { + streamLabel = "stream" + } + sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} if req.Stream { - h.handleStream(w, r, prov, decision, &req, fallbacks) + h.handleStream(sw, r, prov, decision, req, fallbacks) } else { - h.handleNonStream(w, r, prov, decision, &req, fallbacks) + h.handleNonStream(sw, r, prov, decision, req, fallbacks) + } + metrics.RequestsTotal.WithLabelValues( + decision.SelectedModel, decision.SelectedProvider, decision.Quality, metrics.StatusClass(sw.status), + ).Inc() + metrics.ObserveDuration(metrics.RequestDurationSeconds, decision.SelectedModel, decision.SelectedProvider, streamLabel, time.Since(start)) + if decision.EstimatedCost > 0 { + metrics.CostUSDTotal.WithLabelValues(decision.SelectedModel, decision.SelectedProvider).Add(decision.EstimatedCost) } } +func decodeChatCompletionRequest(w http.ResponseWriter, r *http.Request) (*types.ChatCompletionRequest, error) { + r.Body = http.MaxBytesReader(w, r.Body, maxChatCompletionsBodyBytes) + defer r.Body.Close() + + // Unknown fields are accepted and forwarded to the OpenAI provider verbatim. 
+ // Real OpenAI SDKs routinely send fields the proxy would otherwise reject + // (parallel_tool_calls, seed, reasoning_effort, service_tier, etc.), which + // would break Frugal's "no code changes" promise. + dec := json.NewDecoder(r.Body) + + var req types.ChatCompletionRequest + if err := dec.Decode(&req); err != nil { + var syntaxErr *json.SyntaxError + var typeErr *json.UnmarshalTypeError + switch { + case errors.As(err, &syntaxErr): + return nil, fmt.Errorf("malformed JSON") + case errors.Is(err, io.EOF): + return nil, fmt.Errorf("empty request body") + case errors.As(err, &typeErr): + return nil, fmt.Errorf("invalid value for field %q", typeErr.Field) + default: + return nil, err + } + } + + if err := dec.Decode(&struct{}{}); err != io.EOF { + return nil, fmt.Errorf("request body must contain a single JSON object") + } + + return &req, nil +} + func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, prov provider.Provider, decision types.RoutingDecision, req *types.ChatCompletionRequest, fallbacks []string) { resp, err := prov.ChatCompletion(r.Context(), decision.SelectedModel, req) if err != nil { // Try fallback chain - for _, fb := range fallbacks { + for _, fb := range boundedFallbacks(fallbacks, decision.SelectedModel, h.allowedFallbackModels()) { fbProv, fbErr := h.registry.Resolve(fb) if fbErr != nil { continue @@ -116,10 +315,11 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, prov p if err == nil { break } - log.Printf("fallback %s failed: %v", fb, err) + obs.L(r.Context()).Warn("fallback failed", "model", fb, "err", err) } if err != nil { - writeError(w, http.StatusBadGateway, "upstream error: "+err.Error()) + obs.L(r.Context()).Error("upstream error", "model", decision.SelectedModel, "err", err) + writeError(w, http.StatusBadGateway, sanitizedUpstreamMessage(err)) return } } @@ -128,30 +328,126 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, prov p 
json.NewEncoder(w).Encode(resp) } +// sanitizedUpstreamMessage returns a stable, operator-safe summary of an +// upstream failure. The full error — which can include the provider's +// response body and, in pathological cases, echoed request data — is logged +// but never written to the wire. +func sanitizedUpstreamMessage(err error) string { + if err == nil { + return "upstream error" + } + msg := strings.ToLower(err.Error()) + switch { + case strings.Contains(msg, "context deadline exceeded"), strings.Contains(msg, "timeout"): + return "upstream timeout" + case strings.Contains(msg, "429"), strings.Contains(msg, "rate limit"): + return "upstream rate limited" + case strings.Contains(msg, "401"), strings.Contains(msg, "403"): + return "upstream rejected credentials" + case strings.Contains(msg, "503"), strings.Contains(msg, "502"), strings.Contains(msg, "504"): + return "upstream unavailable" + default: + return "upstream error" + } +} + func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, prov provider.Provider, decision types.RoutingDecision, req *types.ChatCompletionRequest, fallbacks []string) { - ch, err := prov.ChatCompletionStream(r.Context(), decision.SelectedModel, req) + ch, first, err := openStreamWithFirstChunk(r.Context(), prov, decision.SelectedModel, req) if err != nil { - // Try fallback chain - for _, fb := range fallbacks { + // Handshake or first-chunk failure: walk the fallback chain. Once the + // first chunk has been written to the client, any further error is + // surfaced in-band (see streaming.go) — retry is no longer safe. 
+ for _, fb := range boundedFallbacks(fallbacks, decision.SelectedModel, h.allowedFallbackModels()) { fbProv, fbErr := h.registry.Resolve(fb) if fbErr != nil { continue } - ch, err = fbProv.ChatCompletionStream(r.Context(), fb, req) + ch, first, err = openStreamWithFirstChunk(r.Context(), fbProv, fb, req) if err == nil { break } - log.Printf("fallback stream %s failed: %v", fb, err) + obs.L(r.Context()).Warn("fallback stream failed", "model", fb, "err", err) } if err != nil { - writeError(w, http.StatusBadGateway, "upstream stream error: "+err.Error()) + obs.L(r.Context()).Error("upstream stream error", "model", decision.SelectedModel, "err", err) + writeError(w, http.StatusBadGateway, sanitizedUpstreamMessage(err)) return } } - if err := streamResponse(w, ch); err != nil { - log.Printf("stream error: %v", err) + if err := streamResponseWithFirst(r.Context(), w, first, ch); err != nil { + obs.L(r.Context()).Warn("stream write error", "err", err) + } +} + +// openStreamWithFirstChunk opens an upstream stream AND reads the first chunk +// synchronously so handshake-success-but-immediate-Err is still handled by +// the fallback chain. If the first chunk carries a Done (empty stream) that's +// still considered a success so the DONE terminator reaches the client. +func openStreamWithFirstChunk(ctx context.Context, prov provider.Provider, model string, req *types.ChatCompletionRequest) (<-chan provider.StreamChunk, *provider.StreamChunk, error) { + ch, err := prov.ChatCompletionStream(ctx, model, req) + if err != nil { + return nil, nil, err + } + select { + case first, ok := <-ch: + if !ok { + // Upstream closed the channel without emitting anything. + // Treat as handshake failure so fallback runs. 
+ return nil, nil, fmt.Errorf("upstream closed stream before first chunk") + } + if first.Err != nil { + return nil, nil, first.Err + } + return ch, &first, nil + case <-ctx.Done(): + return nil, nil, ctx.Err() + } +} + +// boundedFallbacks trims the caller-supplied fallback chain to registered +// models only, deduplicated, capped at maxFallbackAttempts, and skipping the +// routed model. Allow-listing against the registry prevents a client from +// crafting an `X-Frugal-Fallback` header that steers traffic to an expensive +// model (or a never-configured one) the operator did not authorize. +func boundedFallbacks(fallbacks []string, selectedModel string, allowed map[string]struct{}) []string { + if len(fallbacks) == 0 { + return nil + } + + bounded := make([]string, 0, maxFallbackAttempts) + seen := make(map[string]struct{}, len(fallbacks)) + for _, fb := range fallbacks { + if len(bounded) >= maxFallbackAttempts { + break + } + + trimmed := strings.TrimSpace(fb) + if trimmed == "" { + continue + } + + if strings.EqualFold(trimmed, selectedModel) { + continue + } + + key := strings.ToLower(trimmed) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + + if allowed != nil { + if _, ok := allowed[trimmed]; !ok { + obs.L(context.TODO()).Warn("ignoring unregistered fallback", "model", trimmed) + continue + } + } + + bounded = append(bounded, trimmed) } + + return bounded } // ListModels handles GET /v1/models @@ -190,6 +486,82 @@ func (h *Handler) ListModels(w http.ResponseWriter, r *http.Request) { } // RoutingExplain handles GET /v1/routing/explain +// ListBundles handles GET /v1/bundles — returns the set of known use cases +// with their bundles at every tier. Frontend-friendly for a "what can I +// route?" UI or a CLI lister. 
+func (h *Handler) ListBundles(w http.ResponseWriter, r *http.Request) { + if h.useCases == nil || h.useCases.Len() == 0 { + writeError(w, http.StatusNotFound, "no use cases configured") + return + } + type bundleOut struct { + Chat string `json:"chat"` + Search string `json:"search,omitempty"` + Rerank string `json:"rerank,omitempty"` + Reason string `json:"reason,omitempty"` + } + type caseOut struct { + ID string `json:"id"` + Description string `json:"description"` + Source string `json:"source"` + AsOf string `json:"as_of"` + Confidence string `json:"confidence"` + Bundles map[string]bundleOut `json:"bundles"` + } + out := make([]caseOut, 0, h.useCases.Len()) + for _, id := range h.useCases.IDs() { + uc, _ := h.useCases.Get(id) + bundles := map[string]bundleOut{} + for tier, b := range uc.Bundles { + bundles[tier] = bundleOut{Chat: b.Chat, Search: b.Search, Rerank: b.Rerank, Reason: b.Reason} + } + out = append(out, caseOut{ + ID: uc.ID, Description: uc.Description, Source: uc.Source, + AsOf: uc.AsOf, Confidence: uc.Confidence, Bundles: bundles, + }) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"data": out}) +} + +// GetBundle handles GET /v1/bundles/{use-case}?quality=TIER — returns the +// (capability → model) map for one use case at one tier. Tier defaults to +// balanced when the query param is absent. 
+func (h *Handler) GetBundle(w http.ResponseWriter, r *http.Request) { + if h.useCases == nil || h.useCases.Len() == 0 { + writeError(w, http.StatusNotFound, "no use cases configured") + return + } + id := chi.URLParam(r, "useCase") + uc, ok := h.useCases.Get(id) + if !ok { + known := strings.Join(h.useCases.IDs(), ", ") + writeError(w, http.StatusNotFound, fmt.Sprintf("unknown use case %q; known: %s", id, known)) + return + } + tier := r.URL.Query().Get("quality") + if tier == "" { + tier = "balanced" + } + bundle, ok := uc.Bundles[tier] + if !ok { + writeError(w, http.StatusNotFound, fmt.Sprintf("use case %q has no %q tier", id, tier)) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "use_case": uc.ID, + "quality": tier, + "chat": bundle.Chat, + "search": bundle.Search, + "rerank": bundle.Rerank, + "reason": strings.TrimSpace(bundle.Reason), + "source": uc.Source, + "as_of": uc.AsOf, + "confidence": uc.Confidence, + }) +} + func (h *Handler) RoutingExplain(w http.ResponseWriter, r *http.Request) { h.mu.Lock() d := h.lastDecision diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index b6719af..3819287 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -4,10 +4,12 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -22,13 +24,29 @@ type mockProvider struct { name string models []string response *types.ChatCompletionResponse + chatErr error streamErr error + + mu sync.Mutex + chatCalls int + streamCalls int + lastChatModel string + lastStreamModel string } func (m *mockProvider) Name() string { return m.name } func (m *mockProvider) Models() []string { return m.models } func (m *mockProvider) ChatCompletion(ctx context.Context, model string, req *types.ChatCompletionRequest) (*types.ChatCompletionResponse, error) { + m.mu.Lock() + m.chatCalls++ + m.lastChatModel 
= model + m.mu.Unlock() + + if m.chatErr != nil { + return nil, m.chatErr + } + if m.response != nil { return m.response, nil } @@ -51,6 +69,11 @@ func (m *mockProvider) ChatCompletion(ctx context.Context, model string, req *ty } func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, req *types.ChatCompletionRequest) (<-chan provider.StreamChunk, error) { + m.mu.Lock() + m.streamCalls++ + m.lastStreamModel = model + m.mu.Unlock() + if m.streamErr != nil { return nil, m.streamErr } @@ -89,6 +112,18 @@ func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, r return ch, nil } +func (m *mockProvider) ChatCallCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.chatCalls +} + +func (m *mockProvider) LastChatModel() string { + m.mu.Lock() + defer m.mu.Unlock() + return m.lastChatModel +} + func setupHandler() (*Handler, *httptest.Server) { reg := provider.NewRegistry() mock := &mockProvider{ @@ -108,7 +143,7 @@ func setupHandler() (*Handler, *httptest.Server) { Name: "mock-premium", Provider: "mock", CostPer1KInput: 0.003, CostPer1KOutput: 0.015, Reasoning: 0.95, Coding: 0.93, Creative: 0.90, InstructFollowing: 0.95, - ToolUse: true, JSONMode: true, MaxContext: 200000, + ToolUse: true, JSONMode: true, Vision: true, MaxContext: 200000, }, } thresholds := map[string]router.Threshold{ @@ -225,6 +260,127 @@ func TestChatCompletions_ModelPinning(t *testing.T) { } } +func TestChatCompletions_AcceptsUnknownFields(t *testing.T) { + // Real OpenAI SDKs send fields Frugal does not explicitly model + // (parallel_tool_calls, seed, reasoning_effort, service_tier, etc.). + // The proxy must accept them so the "no code changes" promise holds. 
+ _, ts := setupHandler() + defer ts.Close() + + body := []byte(`{"model":"auto","messages":[{"role":"user","content":"hello"}],"unexpected":true,"seed":42,"parallel_tool_calls":false}`) + resp, err := http.Post(ts.URL+"/v1/chat/completions", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200, got %d: %s", resp.StatusCode, string(b)) + } +} + +func TestChatCompletions_RejectsOversizedBody(t *testing.T) { + _, ts := setupHandler() + defer ts.Close() + + oversized := bytes.Repeat([]byte("a"), int(maxChatCompletionsBodyBytes)+1) + body := []byte(`{"model":"auto","messages":[{"role":"user","content":"`) + body = append(body, oversized...) + body = append(body, []byte(`"}]}`)...) + + resp, err := http.Post(ts.URL+"/v1/chat/completions", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 400, got %d: %s", resp.StatusCode, string(b)) + } +} + +func TestChatCompletions_RelaxedFromHeader_EmittedWhenDowngraded(t *testing.T) { + // Force a downgrade: no model clears the "high" threshold for a simple + // English prompt (both mock models have reasoning below the high bar + // of 0.88 on non-coding/non-math queries, except mock-premium at 0.95). + // The premium model does clear it, so we instead construct a setup that + // cannot clear high but can clear balanced. 
+ reg := newStubMockRegistry() + models := []router.ModelEntry{ + { + Name: "cheap", Provider: "mock", + CostPer1KInput: 0.001, CostPer1KOutput: 0.002, + Reasoning: 0.70, Coding: 0.70, Creative: 0.70, InstructFollowing: 0.70, + ToolUse: true, JSONMode: true, MaxContext: 100000, + }, + } + thresholds := map[string]router.Threshold{ + "high": {MinReasoning: 0.99, MinCoding: 0.99, MinCreative: 0.99, MinInstructFollowing: 0.99}, + "balanced": {MinReasoning: 0.60, MinCoding: 0.60, MinCreative: 0.60, MinInstructFollowing: 0.60}, + "cost": {}, + } + cls := classifier.NewRuleBased() + rtr := router.New(models, thresholds) + h := NewHandler(cls, rtr, reg) + + mux := http.NewServeMux() + mux.HandleFunc("POST /v1/chat/completions", h.ChatCompletions) + ts := httptest.NewServer(HeaderExtractionMiddleware(mux)) + defer ts.Close() + + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("Hello")}}, + }) + req, _ := http.NewRequest("POST", ts.URL+"/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Frugal-Quality", "high") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if got := resp.Header.Get("X-Frugal-Relaxed-From"); got != "high" { + t.Fatalf("expected X-Frugal-Relaxed-From=high, got %q", got) + } +} + +// newStubMockRegistry returns a registry with a single mock model used when +// the default setupHandler wiring is too opinionated for a specific test. +func newStubMockRegistry() *provider.Registry { + reg := provider.NewRegistry() + mock := &mockProvider{name: "mock", models: []string{"cheap"}} + reg.Register(mock) + return reg +} + +func TestChatCompletions_MultimodalContent_IsAccepted(t *testing.T) { + // Router accepts array-typed Content and forwards the request. 
The + // classifier must still see a meaningful text feature via ContentText; + // the provider must still be invoked. + _, ts := setupHandler() + defer ts.Close() + + body := []byte(`{"model":"auto","messages":[{"role":"user","content":[ + {"type":"text","text":"what is in this picture"}, + {"type":"image_url","image_url":{"url":"data:image/png;base64,AAAA"}} + ]}]}`) + resp, err := http.Post(ts.URL+"/v1/chat/completions", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 200 for multimodal request, got %d: %s", resp.StatusCode, string(b)) + } +} + func TestChatCompletions_QualityHeader(t *testing.T) { _, ts := setupHandler() defer ts.Close() @@ -332,3 +488,100 @@ func mustMarshalJSON(v any) json.RawMessage { b, _ := json.Marshal(v) return b } + +func TestChatCompletions_FallbackAttemptsAreBounded(t *testing.T) { + reg := provider.NewRegistry() + + failing := &mockProvider{name: "failing", models: []string{"primary", "fb1", "fb2", "fb3", "fb4"}, chatErr: errors.New("boom")} + reg.Register(failing) + + models := []router.ModelEntry{{ + Name: "primary", Provider: "failing", + CostPer1KInput: 0.0001, CostPer1KOutput: 0.0002, + Reasoning: 0.8, Coding: 0.8, Creative: 0.8, InstructFollowing: 0.8, + ToolUse: true, JSONMode: true, MaxContext: 128000, + }} + thresholds := map[string]router.Threshold{ + "balanced": {MinReasoning: 0.1, MinCoding: 0.1, MinCreative: 0.1, MinInstructFollowing: 0.1}, + } + + h := NewHandler(classifier.NewRuleBased(), router.New(models, thresholds), reg) + mux := http.NewServeMux() + mux.HandleFunc("POST /v1/chat/completions", h.ChatCompletions) + ts := httptest.NewServer(HeaderExtractionMiddleware(mux)) + defer ts.Close() + + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("Hello")}}, + 
}) + req, _ := http.NewRequest("POST", ts.URL+"/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Frugal-Fallback", "fb1,fb2,fb3,fb4") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 502, got %d: %s", resp.StatusCode, string(b)) + } + + // 1 primary attempt + maxFallbackAttempts fallbacks. + wantCalls := 1 + maxFallbackAttempts + if got := failing.ChatCallCount(); got != wantCalls { + t.Fatalf("expected %d total attempts, got %d", wantCalls, got) + } + + if got := failing.LastChatModel(); got != "fb3" { + t.Fatalf("expected last attempted fallback model fb3, got %s", got) + } +} + +func TestBoundedFallbacks_SkipsSelectedModelAndDuplicates(t *testing.T) { + // nil allowed => no allowlist, so all non-empty deduped entries pass. + got := boundedFallbacks([]string{" gpt-4o ", "gpt-4o", "", "claude-3-5-sonnet", "CLAUDE-3-5-SONNET", "gemini-2.5-flash"}, "gpt-4o", nil) + want := []string{"claude-3-5-sonnet", "gemini-2.5-flash"} + + if len(got) != len(want) { + t.Fatalf("expected %d fallbacks, got %d (%v)", len(want), len(got), got) + } + + for i := range want { + if got[i] != want[i] { + t.Fatalf("expected fallback %d = %s, got %s", i, want[i], got[i]) + } + } +} + +func TestBoundedFallbacks_AllowlistRejectsUnknownModels(t *testing.T) { + allowed := map[string]struct{}{"mock-cheap": {}, "mock-premium": {}} + got := boundedFallbacks([]string{"mock-cheap", "claude-opus", "mock-premium"}, "", allowed) + want := []string{"mock-cheap", "mock-premium"} + if len(got) != len(want) { + t.Fatalf("expected %v, got %v", want, got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("entry %d: expected %s, got %s", i, want[i], got[i]) + } + } +} + +func TestBoundedFallbacks_DedupesBeforeApplyingAttemptLimit(t *testing.T) { + 
got := boundedFallbacks([]string{"a", "a", "b", "c", "d"}, "", nil) + want := []string{"a", "b", "c"} + + if len(got) != len(want) { + t.Fatalf("expected %d fallbacks, got %d (%v)", len(want), len(got), got) + } + + for i := range want { + if got[i] != want[i] { + t.Fatalf("expected fallback %d = %s, got %s", i, want[i], got[i]) + } + } +} diff --git a/internal/proxy/middleware.go b/internal/proxy/middleware.go index c86a351..7ff6224 100644 --- a/internal/proxy/middleware.go +++ b/internal/proxy/middleware.go @@ -2,11 +2,16 @@ package proxy import ( "context" - "log" + "crypto/subtle" + "encoding/json" "net/http" + "runtime/debug" "strings" "time" + "golang.org/x/time/rate" + + "github.com/frugalsh/frugal/internal/obs" "github.com/frugalsh/frugal/internal/types" ) @@ -15,6 +20,7 @@ type contextKey string const ( qualityKey contextKey = "frugal_quality" fallbackKey contextKey = "frugal_fallback" + useCaseKey contextKey = "frugal_use_case" ) // QualityFromContext extracts the quality threshold from the request context. @@ -33,13 +39,124 @@ func FallbacksFromContext(ctx context.Context) []string { return nil } -// HeaderExtractionMiddleware extracts X-Frugal-* headers into the request context. +// UseCaseFromContext extracts the caller-declared use case (from +// X-Frugal-Use-Case header). Returns "" when the header was absent — the +// handler then falls through to non-use-case routing. +func UseCaseFromContext(ctx context.Context) string { + if v, ok := ctx.Value(useCaseKey).(string); ok { + return v + } + return "" +} + +// RequestIDMiddleware propagates or generates an X-Request-ID header, attaches +// it to the request context, and echoes it on the response. The value flows +// through obs.L so every downstream log line (including panics) can be tied +// back to a single request. 
+func RequestIDMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := strings.TrimSpace(r.Header.Get("X-Request-ID")) + if id == "" || len(id) > 128 { + id = obs.NewRequestID() + } + w.Header().Set("X-Request-ID", id) + ctx := obs.WithRequestID(r.Context(), id) + ctx = obs.WithLogger(ctx, obs.L(ctx).With("request_id", id)) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// RateLimitMiddleware enforces a global token-bucket on the proxy's serve +// entrypoints. rps <= 0 disables the limiter entirely (local dev). Exceeded +// requests receive a 429 with a stable error body and no upstream call is +// issued, protecting the operator's provider keys from loops or abuse. +func RateLimitMiddleware(rps, burst int) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + if rps <= 0 { + return next + } + if burst < rps { + burst = rps + } + limiter := rate.NewLimiter(rate.Limit(rps), burst) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !limiter.Allow() { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "rate limit exceeded", + "type": "frugal_rate_limit_error", + "code": "rate_limited", + }, + }) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// AuthMiddleware gates the proxy behind a shared bearer token. When the token +// is empty, the middleware is a no-op (local single-user deployments). When +// set, requests must carry `Authorization: Bearer `; the comparison is +// constant-time. Missing or mismatched tokens return 401 with a stable error +// shape; the request body and headers are never logged. 
+func AuthMiddleware(token string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + if token == "" { + return next + } + want := []byte(token) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got := bearerFromHeader(r.Header.Get("Authorization")) + if got == "" || subtle.ConstantTimeCompare([]byte(got), want) != 1 { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("WWW-Authenticate", `Bearer realm="frugal"`) + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "missing or invalid authorization", + "type": "frugal_auth_error", + "code": "unauthorized", + }, + }) + return + } + next.ServeHTTP(w, r) + }) + } +} + +func bearerFromHeader(h string) string { + const prefix = "Bearer " + if len(h) <= len(prefix) { + return "" + } + // Case-insensitive prefix match per RFC 6750. + if !strings.EqualFold(h[:len(prefix)], prefix) { + return "" + } + return strings.TrimSpace(h[len(prefix):]) +} + +// HeaderExtractionMiddleware extracts X-Frugal-* headers into the request +// context. Unknown X-Frugal-Quality values return 400 up front so typos +// surface to the caller rather than silently coercing to balanced. 
func HeaderExtractionMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() if q := r.Header.Get("X-Frugal-Quality"); q != "" { - ctx = context.WithValue(ctx, qualityKey, types.ParseQualityThreshold(q)) + qt, ok := types.ParseQualityThreshold(q) + if !ok { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"X-Frugal-Quality must be one of: high, balanced, cost","type":"frugal_error","code":"invalid_quality"}}`)) + return + } + ctx = context.WithValue(ctx, qualityKey, qt) } else { ctx = context.WithValue(ctx, qualityKey, types.QualityBalanced) } @@ -52,17 +169,55 @@ func HeaderExtractionMiddleware(next http.Handler) http.Handler { ctx = context.WithValue(ctx, fallbackKey, parts) } + // Use case header is validated against the registry by the handler, + // not here — middleware shouldn't need the registry reference. + if uc := strings.TrimSpace(r.Header.Get("X-Frugal-Use-Case")); uc != "" { + ctx = context.WithValue(ctx, useCaseKey, uc) + } + next.ServeHTTP(w, r.WithContext(ctx)) }) } -// LoggingMiddleware logs request method, path, status, and duration. +// RecoverMiddleware catches panics from handlers and returns a structured 500. 
+func RecoverMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + obs.L(r.Context()).Error("panic recovered", + "method", r.Method, + "path", r.URL.Path, + "panic", rec, + "stack", string(debug.Stack()), + ) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "internal server error", + "type": "frugal_error", + }, + }) + } + }() + + next.ServeHTTP(w, r) + }) +} + +// LoggingMiddleware emits a single structured log line per request with +// method, path, status, duration, and any attrs added downstream. func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(sw, r) - log.Printf("%s %s %d %s", r.Method, r.URL.Path, sw.status, time.Since(start).Round(time.Millisecond)) + obs.L(r.Context()).Info("request", + "method", r.Method, + "path", r.URL.Path, + "status", sw.status, + "duration_ms", time.Since(start).Milliseconds(), + ) }) } diff --git a/internal/proxy/middleware_test.go b/internal/proxy/middleware_test.go new file mode 100644 index 0000000..4b15abb --- /dev/null +++ b/internal/proxy/middleware_test.go @@ -0,0 +1,62 @@ +package proxy + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestRecoverMiddleware_RecoversPanic(t *testing.T) { + h := RecoverMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("boom") + })) + + req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", rr.Code) + } + + if ct := 
rr.Header().Get("Content-Type"); !strings.Contains(ct, "application/json") { + t.Fatalf("expected json content-type, got %q", ct) + } + + var body struct { + Error struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil { + t.Fatalf("decode body: %v", err) + } + + if body.Error.Message != "internal server error" { + t.Fatalf("unexpected error message: %q", body.Error.Message) + } + if body.Error.Type != "frugal_error" { + t.Fatalf("unexpected error type: %q", body.Error.Type) + } +} + +func TestRecoverMiddleware_PassesThrough(t *testing.T) { + h := RecoverMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = w.Write([]byte("ok")) + })) + + req := httptest.NewRequest(http.MethodGet, "/health", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusTeapot { + t.Fatalf("expected 418, got %d", rr.Code) + } + if rr.Body.String() != "ok" { + t.Fatalf("expected body ok, got %q", rr.Body.String()) + } +} diff --git a/internal/proxy/requestid_test.go b/internal/proxy/requestid_test.go new file mode 100644 index 0000000..acbd19b --- /dev/null +++ b/internal/proxy/requestid_test.go @@ -0,0 +1,74 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/frugalsh/frugal/internal/obs" +) + +func TestRequestIDMiddleware_GeneratesWhenMissing(t *testing.T) { + var captured string + handler := RequestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = obs.RequestID(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + handler.ServeHTTP(rec, req) + + if captured == "" { + t.Fatalf("expected a generated request ID in context") + } + if got := rec.Header().Get("X-Request-ID"); got == "" { + t.Fatalf("expected X-Request-ID response 
header") + } + if rec.Header().Get("X-Request-ID") != captured { + t.Fatalf("header %q != context %q", rec.Header().Get("X-Request-ID"), captured) + } +} + +func TestRequestIDMiddleware_PropagatesInbound(t *testing.T) { + var captured string + handler := RequestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = obs.RequestID(r.Context()) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-ID", "upstream-trace-abc") + handler.ServeHTTP(rec, req) + + if captured != "upstream-trace-abc" { + t.Fatalf("expected propagated ID, got %q", captured) + } + if rec.Header().Get("X-Request-ID") != "upstream-trace-abc" { + t.Fatalf("expected echoed header, got %q", rec.Header().Get("X-Request-ID")) + } +} + +func TestRequestIDMiddleware_RejectsOverlongInbound(t *testing.T) { + oversized := make([]byte, 200) + for i := range oversized { + oversized[i] = 'x' + } + + var captured string + handler := RequestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured = obs.RequestID(r.Context()) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-ID", string(oversized)) + handler.ServeHTTP(rec, req) + + if captured == string(oversized) { + t.Fatalf("expected oversized inbound ID to be replaced") + } + if captured == "" { + t.Fatalf("expected a fresh generated ID when inbound is rejected") + } +} diff --git a/internal/proxy/sanitize_test.go b/internal/proxy/sanitize_test.go new file mode 100644 index 0000000..584a25f --- /dev/null +++ b/internal/proxy/sanitize_test.go @@ -0,0 +1,36 @@ +package proxy + +import ( + "errors" + "strings" + "testing" +) + +func TestSanitizedUpstreamMessage_DoesNotLeakRawError(t *testing.T) { + // Any provider error we pass through must be collapsed into one of a + // small set of stable strings — never the raw upstream body. 
+ tests := []struct { + name string + err error + want string + }{ + {"timeout", errors.New("anthropic request: context deadline exceeded"), "upstream timeout"}, + {"rate limited", errors.New("openai error 429: rate limit exceeded, retry in 1s"), "upstream rate limited"}, + {"credentials", errors.New("gemini error 401: invalid api key sk-leaked-key"), "upstream rejected credentials"}, + {"unavailable", errors.New("openai error 503: service unavailable"), "upstream unavailable"}, + {"generic", errors.New("anthropic error 500: {\"echo\":\"sensitive-data\"}"), "upstream error"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sanitizedUpstreamMessage(tt.err) + if got != tt.want { + t.Fatalf("sanitizedUpstreamMessage = %q, want %q", got, tt.want) + } + // Crucial: the raw error text must not appear in the sanitized + // output even for the generic bucket. + if strings.Contains(got, "sk-leaked-key") || strings.Contains(got, "echo") || strings.Contains(got, "sensitive-data") { + t.Fatalf("sanitized output leaked raw error: %q", got) + } + }) + } +} diff --git a/internal/proxy/streaming.go b/internal/proxy/streaming.go index cc554ae..edccff8 100644 --- a/internal/proxy/streaming.go +++ b/internal/proxy/streaming.go @@ -1,15 +1,20 @@ package proxy import ( + "context" "encoding/json" "fmt" "net/http" + "github.com/frugalsh/frugal/internal/obs" "github.com/frugalsh/frugal/internal/provider" ) -// streamResponse writes SSE chunks from a provider stream channel to the HTTP response. -func streamResponse(w http.ResponseWriter, ch <-chan provider.StreamChunk) error { +// streamResponseWithFirst writes a buffered first chunk followed by the rest +// of the stream. Splitting the first chunk out lets the handler retry to a +// fallback model on first-chunk failure before any bytes are sent to the +// client; once the first chunk lands, the stream is irrevocable. 
+func streamResponseWithFirst(ctx context.Context, w http.ResponseWriter, first *provider.StreamChunk, ch <-chan provider.StreamChunk) error { flusher, ok := w.(http.Flusher) if !ok { return fmt.Errorf("streaming not supported") @@ -21,16 +26,80 @@ func streamResponse(w http.ResponseWriter, ch <-chan provider.StreamChunk) error w.WriteHeader(http.StatusOK) flusher.Flush() + bytesWritten := false + if first != nil { + if err := writeSSEChunk(ctx, w, flusher, *first); err != nil { + return err + } + bytesWritten = true + if first.Done { + return nil + } + } + + gotTerminator := false for chunk := range ch { + if err := writeSSEChunk(ctx, w, flusher, chunk); err != nil { + return err + } + bytesWritten = true + if chunk.Done { + gotTerminator = true + return chunk.Err + } if chunk.Err != nil { - errData := fmt.Sprintf(`{"error":{"message":%q}}`, chunk.Err.Error()) - fmt.Fprintf(w, "data: %s\n\n", errData) + return chunk.Err + } + } + + // Channel closed. If the provider never sent a Done/Err terminator and + // we already shipped some bytes, the client would otherwise hang waiting + // for the next chunk. Emit an error frame so the client knows the stream + // was cut short, followed by the standard [DONE] marker so its SSE + // parser terminates cleanly. + if bytesWritten && !gotTerminator { + errFrame := `{"error":{"message":"upstream closed stream prematurely","type":"frugal_error"}}` + if _, err := fmt.Fprintf(w, "data: %s\n\n", errFrame); err != nil { + return err + } + flusher.Flush() + } + if _, err := fmt.Fprint(w, "data: [DONE]\n\n"); err != nil { + return err + } + flusher.Flush() + return nil +} + +// streamResponse is retained for tests that don't need the split first-chunk +// path. Prefer streamResponseWithFirst in production handler code. 
+func streamResponse(ctx context.Context, w http.ResponseWriter, ch <-chan provider.StreamChunk) error { + flusher, ok := w.(http.Flusher) + if !ok { + return fmt.Errorf("streaming not supported") + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + flusher.Flush() + + for chunk := range ch { + if chunk.Err != nil { + obs.L(ctx).Error("stream upstream error", "err", chunk.Err) + errData := fmt.Sprintf(`{"error":{"message":%q,"type":"frugal_error"}}`, sanitizedUpstreamMessage(chunk.Err)) + if _, err := fmt.Fprintf(w, "data: %s\n\n", errData); err != nil { + return err + } flusher.Flush() return chunk.Err } if chunk.Done { - fmt.Fprint(w, "data: [DONE]\n\n") + if _, err := fmt.Fprint(w, "data: [DONE]\n\n"); err != nil { + return err + } flusher.Flush() return nil } @@ -39,9 +108,48 @@ func streamResponse(w http.ResponseWriter, ch <-chan provider.StreamChunk) error if err != nil { continue } - fmt.Fprintf(w, "data: %s\n\n", data) + if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { + return err + } flusher.Flush() } + // Some providers close the stream channel without sending an explicit Done chunk. + // Emit the OpenAI-compatible terminator so clients don't wait indefinitely. + if _, err := fmt.Fprint(w, "data: [DONE]\n\n"); err != nil { + return err + } + flusher.Flush() + return nil +} + +// writeSSEChunk serializes a single StreamChunk to the wire. Error chunks +// collapse to a sanitized frame; Done triggers a [DONE] terminator; normal +// chunks emit a JSON event. 
+func writeSSEChunk(ctx context.Context, w http.ResponseWriter, flusher http.Flusher, chunk provider.StreamChunk) error { + if chunk.Err != nil { + obs.L(ctx).Error("stream upstream error", "err", chunk.Err) + errData := fmt.Sprintf(`{"error":{"message":%q,"type":"frugal_error"}}`, sanitizedUpstreamMessage(chunk.Err)) + if _, err := fmt.Fprintf(w, "data: %s\n\n", errData); err != nil { + return err + } + flusher.Flush() + return nil + } + if chunk.Done { + if _, err := fmt.Fprint(w, "data: [DONE]\n\n"); err != nil { + return err + } + flusher.Flush() + return nil + } + data, err := json.Marshal(chunk.Data) + if err != nil { + return nil + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { + return err + } + flusher.Flush() return nil } diff --git a/internal/proxy/streaming_test.go b/internal/proxy/streaming_test.go new file mode 100644 index 0000000..c5c6ad3 --- /dev/null +++ b/internal/proxy/streaming_test.go @@ -0,0 +1,38 @@ +package proxy + +import ( + "context" + "strings" + "testing" + + "net/http/httptest" + + "github.com/frugalsh/frugal/internal/provider" + "github.com/frugalsh/frugal/internal/types" +) + +func TestStreamResponse_EmitsDoneWhenChannelClosesWithoutDoneChunk(t *testing.T) { + w := httptest.NewRecorder() + ch := make(chan provider.StreamChunk, 1) + + ch <- provider.StreamChunk{ + Data: &types.ChatCompletionChunk{ + ID: "chatcmpl-stream", + Object: "chat.completion.chunk", + Model: "mock-model", + Choices: []types.ChunkChoice{ + {Index: 0, Delta: types.MessageDelta{Content: "hello"}}, + }, + }, + } + close(ch) + + if err := streamResponse(context.Background(), w, ch); err != nil { + t.Fatalf("streamResponse returned error: %v", err) + } + + body := w.Body.String() + if !strings.Contains(body, "data: [DONE]") { + t.Fatalf("expected [DONE] terminator, got body: %s", body) + } +} diff --git a/internal/proxy/usecase_test.go b/internal/proxy/usecase_test.go new file mode 100644 index 0000000..53bb5ba --- /dev/null +++ 
b/internal/proxy/usecase_test.go @@ -0,0 +1,347 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "github.com/frugalsh/frugal/internal/classifier" + "github.com/frugalsh/frugal/internal/provider" + "github.com/frugalsh/frugal/internal/router" + "github.com/frugalsh/frugal/internal/types" + "github.com/frugalsh/frugal/internal/usecase" +) + +// setupUseCaseHandler builds a handler wired with two mock models plus a +// small use-case registry that routes specific tiers to each. Chi is used +// here (unlike setupHandler) so the /v1/bundles/{useCase} path param is +// parsed by the actual production router. +func setupUseCaseHandler(t *testing.T) (*Handler, *httptest.Server) { + t.Helper() + + reg := provider.NewRegistry() + reg.Register(&mockProvider{ + name: "mock", + models: []string{"mock-cheap", "mock-premium"}, + }) + + models := []router.ModelEntry{ + {Name: "mock-cheap", Provider: "mock", CostPer1KInput: 0.0001, CostPer1KOutput: 0.0004, + Reasoning: 0.7, Coding: 0.68, Creative: 0.65, InstructFollowing: 0.72, MaxContext: 128000}, + {Name: "mock-premium", Provider: "mock", CostPer1KInput: 0.003, CostPer1KOutput: 0.015, + Reasoning: 0.95, Coding: 0.93, Creative: 0.9, InstructFollowing: 0.95, MaxContext: 200000}, + } + thresholds := map[string]router.Threshold{ + "high": {MinReasoning: 0.88, MinCoding: 0.85, MinCreative: 0.82, MinInstructFollowing: 0.88}, + "balanced": {MinReasoning: 0.70, MinCoding: 0.68, MinCreative: 0.65, MinInstructFollowing: 0.72}, + "cost": {}, + } + + // Build a tiny in-tmpdir use-case registry: two use cases, known tiers. 
+ dir := t.TempDir() + writeYAML(t, filepath.Join(dir, "heavy-lift.yaml"), ` +id: heavy-lift +description: uses the premium model +source: curated +as_of: "2026-04-21" +confidence: high +bundles: + high: { chat: mock-premium } + balanced: { chat: mock-premium } + cost: { chat: mock-cheap } +`) + writeYAML(t, filepath.Join(dir, "tight-budget.yaml"), ` +id: tight-budget +description: always cheap +source: curated +as_of: "2026-04-21" +confidence: high +bundles: + high: { chat: mock-cheap } + balanced: { chat: mock-cheap } + cost: { chat: mock-cheap } +`) + useCases, err := usecase.Load(dir) + if err != nil { + t.Fatalf("Load use cases: %v", err) + } + + cls := classifier.NewRuleBased() + rtr := router.New(models, thresholds) + h := NewHandlerWithUseCases(cls, rtr, reg, useCases) + + r := chi.NewRouter() + r.Use(HeaderExtractionMiddleware) + r.Post("/v1/chat/completions", h.ChatCompletions) + r.Get("/v1/bundles", h.ListBundles) + r.Get("/v1/bundles/{useCase}", h.GetBundle) + + ts := httptest.NewServer(r) + return h, ts +} + +func writeYAML(t *testing.T, path, body string) { + t.Helper() + if err := os.WriteFile(path, []byte(body), 0o644); err != nil { + t.Fatal(err) + } +} + +func TestChatCompletions_UseCaseRoutesToBundleModel(t *testing.T) { + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("hi")}}, + }) + req, _ := http.NewRequest("POST", ts.URL+"/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Frugal-Use-Case", "heavy-lift") + req.Header.Set("X-Frugal-Quality", "balanced") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("status=%d body=%s", resp.StatusCode, string(b)) + } + if got := 
resp.Header.Get("X-Frugal-Model"); got != "mock-premium" { + t.Errorf("expected X-Frugal-Model=mock-premium, got %q", got) + } + if got := resp.Header.Get("X-Frugal-Use-Case"); got != "heavy-lift" { + t.Errorf("expected X-Frugal-Use-Case echo, got %q", got) + } +} + +func TestChatCompletions_UseCaseQualityTiers(t *testing.T) { + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + cases := []struct { + tier, wantModel string + }{ + {"high", "mock-premium"}, + {"balanced", "mock-premium"}, + {"cost", "mock-cheap"}, + } + for _, tc := range cases { + t.Run(tc.tier, func(t *testing.T) { + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("hi")}}, + }) + req, _ := http.NewRequest("POST", ts.URL+"/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Frugal-Use-Case", "heavy-lift") + req.Header.Set("X-Frugal-Quality", tc.tier) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if got := resp.Header.Get("X-Frugal-Model"); got != tc.wantModel { + t.Errorf("tier=%s: expected %s, got %s", tc.tier, tc.wantModel, got) + } + }) + } +} + +func TestChatCompletions_UnknownUseCaseReturns400(t *testing.T) { + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("hi")}}, + }) + req, _ := http.NewRequest("POST", ts.URL+"/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Frugal-Use-Case", "legal-research") // not registered + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + b, _ := 
io.ReadAll(resp.Body) + if !strings.Contains(string(b), "unknown use case") { + t.Errorf("expected 'unknown use case' in body, got %s", string(b)) + } + if !strings.Contains(string(b), "heavy-lift") { + t.Errorf("expected body to list known use cases including heavy-lift, got %s", string(b)) + } +} + +func TestChatCompletions_AbsentUseCaseHeaderFallsThrough(t *testing.T) { + // Absent header → existing classifier/router path runs. Default quality + // is balanced; mock-cheap and mock-premium both clear it, so the router + // picks cheapest (mock-cheap). + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("hi")}}, + }) + resp, err := http.Post(ts.URL+"/v1/chat/completions", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + if got := resp.Header.Get("X-Frugal-Model"); got == "" { + t.Errorf("expected routed model on fallthrough, got empty") + } + if got := resp.Header.Get("X-Frugal-Use-Case"); got != "" { + t.Errorf("expected no use-case echo when header absent, got %q", got) + } +} + +func TestGetBundle_ReturnsJSONForKnownUseCase(t *testing.T) { + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/bundles/heavy-lift?quality=balanced") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + var got map[string]any + if err := json.NewDecoder(resp.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if got["use_case"] != "heavy-lift" { + t.Errorf("wrong use_case: %v", got["use_case"]) + } + if got["chat"] != "mock-premium" { + t.Errorf("wrong chat: %v", got["chat"]) + } + if got["quality"] 
!= "balanced" { + t.Errorf("wrong quality: %v", got["quality"]) + } + if got["source"] != "curated" { + t.Errorf("wrong source: %v", got["source"]) + } +} + +func TestGetBundle_DefaultsToBalanced(t *testing.T) { + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/bundles/heavy-lift") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + var got map[string]any + _ = json.NewDecoder(resp.Body).Decode(&got) + if got["quality"] != "balanced" { + t.Errorf("default tier should be balanced, got %v", got["quality"]) + } +} + +func TestGetBundle_UnknownUseCaseReturns404(t *testing.T) { + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/bundles/legal-research") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } +} + +func TestListBundles_ReturnsEveryRegisteredUseCase(t *testing.T) { + _, ts := setupUseCaseHandler(t) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/v1/bundles") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + var got map[string][]map[string]any + if err := json.NewDecoder(resp.Body).Decode(&got); err != nil { + t.Fatalf("decode: %v", err) + } + if len(got["data"]) != 2 { + t.Fatalf("expected 2 entries, got %d", len(got["data"])) + } + // Sorted by ID, so heavy-lift < tight-budget. 
+ if got["data"][0]["id"] != "heavy-lift" { + t.Errorf("unexpected order: %v", got["data"][0]["id"]) + } + if got["data"][1]["id"] != "tight-budget" { + t.Errorf("unexpected order: %v", got["data"][1]["id"]) + } +} + +func TestUseCase_NoRegistryConfiguredReturns400ForChatHeader(t *testing.T) { + reg := provider.NewRegistry() + reg.Register(&mockProvider{name: "mock", models: []string{"mock-cheap"}}) + models := []router.ModelEntry{{Name: "mock-cheap", Provider: "mock", + CostPer1KInput: 0.0001, CostPer1KOutput: 0.0004, + Reasoning: 0.9, Coding: 0.9, Creative: 0.9, InstructFollowing: 0.9, MaxContext: 128000}} + thresholds := map[string]router.Threshold{"balanced": {}} + h := NewHandler(classifier.NewRuleBased(), router.New(models, thresholds), reg) + + r := chi.NewRouter() + r.Use(HeaderExtractionMiddleware) + r.Post("/v1/chat/completions", h.ChatCompletions) + ts := httptest.NewServer(r) + defer ts.Close() + + body, _ := json.Marshal(types.ChatCompletionRequest{ + Model: "auto", + Messages: []types.Message{{Role: "user", Content: mustMarshalJSON("hi")}}, + }) + req, _ := http.NewRequest("POST", ts.URL+"/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Frugal-Use-Case", "heavy-lift") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + b, _ := io.ReadAll(resp.Body) + t.Fatalf("expected 400 when use-case header is set but no registry, got %d: %s", resp.StatusCode, string(b)) + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 7d8f18a..6a27a00 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -19,7 +19,7 @@ func New(models []ModelEntry, thresholds map[string]Threshold) *Router { // Route selects a model based on query features and quality threshold. 
func (r *Router) Route(features types.QueryFeatures, quality types.QualityThreshold, fallbacks []string) types.RoutingDecision { - threshold := r.thresholds[string(quality)] + threshold := r.thresholdForQuality(quality) // Filter candidates var candidates []ModelEntry @@ -30,31 +30,42 @@ func (r *Router) Route(features types.QueryFeatures, quality types.QualityThresh candidates = append(candidates, m) } + var relaxedFrom string if len(candidates) == 0 { - // Fallback: relax threshold to balanced, then cost - for _, fallbackQuality := range []string{"balanced", "cost"} { - ft := r.thresholds[fallbackQuality] + // Fallback: relax threshold to balanced, then cost. Record the + // original quality so callers can surface a degraded-routing signal. + for _, fallbackQuality := range []types.QualityThreshold{types.QualityBalanced, types.QualityCost} { + if fallbackQuality == quality { + continue + } + ft := r.thresholdForQuality(fallbackQuality) for _, m := range r.models { if r.meetsRequirements(m, features, ft) { candidates = append(candidates, m) } } if len(candidates) > 0 { + relaxedFrom = string(quality) break } } } if len(candidates) == 0 { - // Last resort: pick the cheapest model that has hard requirements + // Last resort: pick the cheapest model that satisfies only the hard + // requirements. This is strictly weaker than any threshold, so + // RelaxedFrom is set if it wasn't already. 
candidates = r.filterHardRequirements(features) + if len(candidates) > 0 && relaxedFrom == "" { + relaxedFrom = string(quality) + } } if len(candidates) == 0 { return types.RoutingDecision{ - Quality: string(quality), + Quality: string(quality), Features: features, - Reason: "no models available", + Reason: "no models available", } } @@ -68,6 +79,7 @@ func (r *Router) Route(features types.QueryFeatures, quality types.QualityThresh SelectedModel: selected.Name, SelectedProvider: selected.Provider, Quality: string(quality), + RelaxedFrom: relaxedFrom, Features: features, Candidates: len(candidates), Reason: r.buildReason(selected, features, quality), @@ -76,6 +88,16 @@ func (r *Router) Route(features types.QueryFeatures, quality types.QualityThresh } } +func (r *Router) thresholdForQuality(quality types.QualityThreshold) Threshold { + if t, ok := r.thresholds[string(quality)]; ok { + return t + } + if t, ok := r.thresholds[string(types.QualityBalanced)]; ok { + return t + } + return Threshold{} +} + func (r *Router) meetsRequirements(m ModelEntry, f types.QueryFeatures, t Threshold) bool { // Hard requirements if f.RequiresToolUse && !m.ToolUse { @@ -84,6 +106,15 @@ func (r *Router) meetsRequirements(m ModelEntry, f types.QueryFeatures, t Thresh if f.RequiresJSON && !m.JSONMode { return false } + if f.RequiresVision && !m.Vision { + return false + } + if f.RequiresMultipleCompletions && m.Provider == "anthropic" { + // Anthropic's Messages API does not support N > 1. Rather than + // silently return a single completion, drop Anthropic from + // candidate set and let the router pick a provider that honors it. 
+ return false + } if f.EstimatedInputTokens > m.MaxContext { return false } @@ -104,6 +135,12 @@ func (r *Router) filterHardRequirements(f types.QueryFeatures) []ModelEntry { if f.RequiresJSON && !m.JSONMode { continue } + if f.RequiresVision && !m.Vision { + continue + } + if f.RequiresMultipleCompletions && m.Provider == "anthropic" { + continue + } if f.EstimatedInputTokens > m.MaxContext { continue } diff --git a/internal/router/router_test.go b/internal/router/router_test.go index 99615b5..f746007 100644 --- a/internal/router/router_test.go +++ b/internal/router/router_test.go @@ -162,3 +162,35 @@ func TestRoute_NoModels_ReturnsEmpty(t *testing.T) { t.Errorf("expected empty model when no models available, got %s", d.SelectedModel) } } + +func TestRoute_UnknownQualityDefaultsToBalancedThreshold(t *testing.T) { + r := New(testModels(), testThresholds()) + + features := types.QueryFeatures{ + EstimatedInputTokens: 100, + EstimatedOutputTokens: 100, + } + + d := r.Route(features, types.QualityThreshold("unknown"), nil) + + if d.SelectedModel != "mid-model" { + t.Errorf("expected balanced fallback model mid-model for unknown quality, got %s", d.SelectedModel) + } +} + +func TestRoute_MissingRequestedThresholdFallsBackToBalanced(t *testing.T) { + thresholds := testThresholds() + delete(thresholds, "high") + r := New(testModels(), thresholds) + + features := types.QueryFeatures{ + EstimatedInputTokens: 100, + EstimatedOutputTokens: 100, + } + + d := r.Route(features, types.QualityHigh, nil) + + if d.SelectedModel != "mid-model" { + t.Errorf("expected balanced fallback model mid-model when high threshold missing, got %s", d.SelectedModel) + } +} diff --git a/internal/router/taxonomy.go b/internal/router/taxonomy.go index c66e081..f8ee05a 100644 --- a/internal/router/taxonomy.go +++ b/internal/router/taxonomy.go @@ -4,17 +4,18 @@ import "github.com/frugalsh/frugal/internal/config" // ModelEntry is a routing-friendly view of a model from the config. 
type ModelEntry struct { - Name string - Provider string - CostPer1KInput float64 - CostPer1KOutput float64 - Reasoning float64 - Coding float64 - Creative float64 + Name string + Provider string + CostPer1KInput float64 + CostPer1KOutput float64 + Reasoning float64 + Coding float64 + Creative float64 InstructFollowing float64 - ToolUse bool - JSONMode bool - MaxContext int + ToolUse bool + JSONMode bool + Vision bool + MaxContext int } // Threshold holds the minimum capability scores for a quality level. @@ -42,6 +43,7 @@ func BuildTaxonomy(cfg *config.Config) ([]ModelEntry, map[string]Threshold) { InstructFollowing: mc.Capabilities.InstructionFollowing, ToolUse: mc.Capabilities.ToolUse, JSONMode: mc.Capabilities.JSONMode, + Vision: mc.Capabilities.Vision, MaxContext: mc.Capabilities.MaxContext, }) } diff --git a/internal/sync/sync.go b/internal/sync/sync.go index c46e4e5..a7cb694 100644 --- a/internal/sync/sync.go +++ b/internal/sync/sync.go @@ -1,12 +1,17 @@ package sync import ( + "context" "encoding/json" "fmt" "net/http" + "time" ) -const modelsDevURL = "https://models.dev/api.json" +const ( + modelsDevURL = "https://models.dev/api.json" + fetchTimeout = 5 * time.Second +) // ModelsDevProvider represents a provider entry from the models.dev API. type ModelsDevProvider struct { @@ -48,10 +53,20 @@ type Modalities struct { Output []string `json:"output"` } -// FetchModels fetches the full model catalog from models.dev. +// FetchModels fetches the full model catalog from models.dev. It enforces a +// bounded timeout so a hung models.dev never blocks Frugal startup. // Returns a flat map of "provider/model" → entry, plus bare model names. 
-func FetchModels() (map[string]ModelsDevEntry, error) { - resp, err := http.Get(modelsDevURL) +func FetchModels(ctx context.Context) (map[string]ModelsDevEntry, error) { + ctx, cancel := context.WithTimeout(ctx, fetchTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelsDevURL, nil) + if err != nil { + return nil, fmt.Errorf("building models.dev request: %w", err) + } + + client := &http.Client{Timeout: fetchTimeout} + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("fetching models.dev: %w", err) } diff --git a/internal/types/openai.go b/internal/types/openai.go index 9b110de..cfe98bf 100644 --- a/internal/types/openai.go +++ b/internal/types/openai.go @@ -1,23 +1,45 @@ package types -import "encoding/json" +import ( + "encoding/json" + "strings" +) // ChatCompletionRequest is the OpenAI-compatible inbound request. +// +// Fields not used by Frugal's classifier or router are still decoded and +// forwarded verbatim to the upstream OpenAI provider so that clients using +// newer SDK features ("no code changes") continue to work. Anthropic and +// Google translators only read the subset of fields they can map. 
type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - N *int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop json.RawMessage `json:"stop,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - PresencePenalty *float64 `json:"presence_penalty,omitempty"` - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - User string `json:"user,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N *int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + StreamOptions json.RawMessage `json:"stream_options,omitempty"` + Stop json.RawMessage `json:"stop,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias json.RawMessage `json:"logit_bias,omitempty"` + LogProbs *bool `json:"logprobs,omitempty"` + TopLogprobs *int `json:"top_logprobs,omitempty"` + Seed *int `json:"seed,omitempty"` + User string `json:"user,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + ReasoningEffort *string `json:"reasoning_effort,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + Store *bool `json:"store,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + Modalities []string 
`json:"modalities,omitempty"` + Prediction json.RawMessage `json:"prediction,omitempty"` + Audio json.RawMessage `json:"audio,omitempty"` } type Message struct { @@ -28,17 +50,71 @@ type Message struct { ToolCallID string `json:"tool_call_id,omitempty"` } -// ContentString extracts a plain string from the Content field. -// Returns empty string if Content is not a simple string. -func (m *Message) ContentString() string { +// ContentPart mirrors an OpenAI message content element. Content may be either +// a plain string (classic) or an array of parts (multimodal, tool results). +// Unknown part fields are preserved via the Raw field so translators can +// forward them without loss. CacheControl holds an Anthropic-style hint +// ({"type":"ephemeral"}) that the Anthropic translator forwards verbatim. +type ContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` + InputAudio json.RawMessage `json:"input_audio,omitempty"` + CacheControl json.RawMessage `json:"cache_control,omitempty"` + Raw json.RawMessage `json:"-"` +} + +type ImageURL struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` +} + +// ContentParts normalizes Content into a slice of parts. A string Content is +// returned as a single text part. Returns nil if Content is absent or opaque +// (e.g. structured tool-result objects that aren't arrays) — callers should +// fall back to the raw bytes in that case. +func (m *Message) ContentParts() []ContentPart { if len(m.Content) == 0 { - return "" + return nil } var s string if err := json.Unmarshal(m.Content, &s); err == nil { - return s + return []ContentPart{{Type: "text", Text: s}} + } + var parts []ContentPart + if err := json.Unmarshal(m.Content, &parts); err == nil { + return parts + } + return nil +} + +// ContentText joins the text of every text part in Content. Non-text parts +// (images, audio) are skipped. 
This is the right accessor for classifier +// feature extraction and for providers that do not support multimodal input. +func (m *Message) ContentText() string { + parts := m.ContentParts() + if len(parts) == 0 { + return "" + } + var b strings.Builder + for _, p := range parts { + if p.Type == "text" || p.Type == "" { + b.WriteString(p.Text) + } + } + return b.String() +} + +// HasNonTextContent reports whether Content includes any part that is not +// plain text (image_url, input_audio, tool results). Used by the router to +// exclude non-vision models. +func (m *Message) HasNonTextContent() bool { + for _, p := range m.ContentParts() { + if p.Type != "" && p.Type != "text" { + return true + } } - return "" + return false } type ResponseFormat struct { diff --git a/internal/types/openai_test.go b/internal/types/openai_test.go new file mode 100644 index 0000000..c161d91 --- /dev/null +++ b/internal/types/openai_test.go @@ -0,0 +1,57 @@ +package types + +import ( + "encoding/json" + "testing" +) + +func TestContentText_StringContent(t *testing.T) { + m := Message{Role: "user", Content: json.RawMessage(`"hello world"`)} + if got := m.ContentText(); got != "hello world" { + t.Fatalf("ContentText = %q, want %q", got, "hello world") + } + if m.HasNonTextContent() { + t.Fatalf("HasNonTextContent = true for string content") + } +} + +func TestContentText_ArrayContent_JoinsTextParts(t *testing.T) { + m := Message{ + Role: "user", + Content: json.RawMessage(`[ + {"type":"text","text":"describe this"}, + {"type":"image_url","image_url":{"url":"data:image/png;base64,AAAA"}}, + {"type":"text","text":" please"} + ]`), + } + if got := m.ContentText(); got != "describe this please" { + t.Fatalf("ContentText = %q, want %q", got, "describe this please") + } + if !m.HasNonTextContent() { + t.Fatalf("HasNonTextContent = false; expected true for array with image part") + } +} + +func TestContentParts_PreservesImageURL(t *testing.T) { + m := Message{ + Role: "user", + Content: 
json.RawMessage(`[{"type":"image_url","image_url":{"url":"https://example.com/i.png","detail":"low"}}]`), + } + parts := m.ContentParts() + if len(parts) != 1 || parts[0].ImageURL == nil { + t.Fatalf("expected 1 image part with ImageURL set, got %#v", parts) + } + if parts[0].ImageURL.URL != "https://example.com/i.png" || parts[0].ImageURL.Detail != "low" { + t.Fatalf("unexpected image URL: %+v", parts[0].ImageURL) + } +} + +func TestContentText_EmptyContent(t *testing.T) { + m := Message{Role: "user"} + if got := m.ContentText(); got != "" { + t.Fatalf("ContentText = %q, want empty", got) + } + if m.HasNonTextContent() { + t.Fatalf("HasNonTextContent = true for empty content") + } +} diff --git a/internal/types/routing.go b/internal/types/routing.go index bb815c9..4dd148d 100644 --- a/internal/types/routing.go +++ b/internal/types/routing.go @@ -1,5 +1,7 @@ package types +import "strings" + // QualityThreshold controls how aggressively Frugal routes to cheaper models. type QualityThreshold string @@ -9,31 +11,38 @@ const ( QualityCost QualityThreshold = "cost" ) -// ParseQualityThreshold parses a string into a QualityThreshold, defaulting to balanced. -func ParseQualityThreshold(s string) QualityThreshold { - switch s { +// ParseQualityThreshold parses a string into a QualityThreshold. It recognises +// high/balanced/cost (case- and whitespace-insensitive). Returns (value, true) +// on a known value and (QualityBalanced, false) on anything else so callers +// can distinguish "client sent a typo" from "client omitted the header". +func ParseQualityThreshold(s string) (QualityThreshold, bool) { + switch strings.ToLower(strings.TrimSpace(s)) { case "high": - return QualityHigh + return QualityHigh, true + case "balanced": + return QualityBalanced, true case "cost": - return QualityCost + return QualityCost, true default: - return QualityBalanced + return QualityBalanced, false } } // QueryFeatures is the output of the classifier's feature extraction. 
type QueryFeatures struct { - EstimatedInputTokens int `json:"estimated_input_tokens"` - EstimatedOutputTokens int `json:"estimated_output_tokens"` - HasCode bool `json:"has_code"` - HasMath bool `json:"has_math"` - HasSystemPrompt bool `json:"has_system_prompt"` - SystemPromptLength int `json:"system_prompt_length"` - ConversationTurns int `json:"conversation_turns"` - RequiresJSON bool `json:"requires_json"` - RequiresToolUse bool `json:"requires_tool_use"` - DomainHints []string `json:"domain_hints"` - ComplexityScore float64 `json:"complexity_score"` // 0.0 - 1.0 + EstimatedInputTokens int `json:"estimated_input_tokens"` + EstimatedOutputTokens int `json:"estimated_output_tokens"` + HasCode bool `json:"has_code"` + HasMath bool `json:"has_math"` + HasSystemPrompt bool `json:"has_system_prompt"` + SystemPromptLength int `json:"system_prompt_length"` + ConversationTurns int `json:"conversation_turns"` + RequiresJSON bool `json:"requires_json"` + RequiresToolUse bool `json:"requires_tool_use"` + RequiresVision bool `json:"requires_vision"` + RequiresMultipleCompletions bool `json:"requires_multiple_completions"` + DomainHints []string `json:"domain_hints"` + ComplexityScore float64 `json:"complexity_score"` // 0.0 - 1.0 } // RoutingDecision captures why a particular model was chosen. @@ -41,6 +50,10 @@ type RoutingDecision struct { SelectedModel string `json:"selected_model"` SelectedProvider string `json:"selected_provider"` Quality string `json:"quality_threshold"` + // RelaxedFrom is set when no model met the requested threshold and the + // router fell through to a lower tier. Clients can alarm on it instead + // of treating degraded routing as a clean win. 
+ RelaxedFrom string `json:"relaxed_from,omitempty"` Features QueryFeatures `json:"features"` Candidates int `json:"candidates_considered"` Reason string `json:"reason"` diff --git a/internal/types/routing_test.go b/internal/types/routing_test.go new file mode 100644 index 0000000..6417a85 --- /dev/null +++ b/internal/types/routing_test.go @@ -0,0 +1,29 @@ +package types + +import "testing" + +func TestParseQualityThreshold_NormalizedInputs(t *testing.T) { + tests := []struct { + name string + in string + want QualityThreshold + wantOK bool + }{ + {name: "exact high", in: "high", want: QualityHigh, wantOK: true}, + {name: "uppercase high", in: "HIGH", want: QualityHigh, wantOK: true}, + {name: "spaced high", in: " high ", want: QualityHigh, wantOK: true}, + {name: "mixed-case cost", in: "CoSt", want: QualityCost, wantOK: true}, + {name: "spaced balanced", in: " balanced ", want: QualityBalanced, wantOK: true}, + {name: "unknown returns not-ok", in: "fast", want: QualityBalanced, wantOK: false}, + {name: "empty returns not-ok", in: "", want: QualityBalanced, wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := ParseQualityThreshold(tt.in) + if got != tt.want || ok != tt.wantOK { + t.Fatalf("ParseQualityThreshold(%q) = (%q, %v), want (%q, %v)", tt.in, got, ok, tt.want, tt.wantOK) + } + }) + } +} diff --git a/internal/usecase/usecase.go b/internal/usecase/usecase.go new file mode 100644 index 0000000..9252206 --- /dev/null +++ b/internal/usecase/usecase.go @@ -0,0 +1,164 @@ +// Package usecase loads named use cases (research-synthesis, code-dev, …) +// and the (capability → model) bundles the router picks per quality tier. +// +// Use cases are Frugal's primary product abstraction: callers declare +// what kind of work they're doing, Frugal delivers the bundle that's +// proven best for that work. See the project vision in +// .claude/projects/.../memory/project_vision.md for the full rationale. 
+//
+// This package is capability-agnostic — it stores `search` and `rerank`
+// fields on Bundle but doesn't care whether those capabilities have
+// providers wired yet. Ring 1a ships chat-only; Rings 1b/1c populate
+// the other fields without touching this package.
+package usecase
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+	"sync"
+
+	"gopkg.in/yaml.v3"
+)
+
+// Bundle is the recommended mapping of capability → model for one
+// (use case, quality tier) pair. Fields not yet populated by curation
+// for a given tier are left zero-valued; consumers should check for
+// empty strings rather than treat them as "no opinion."
+type Bundle struct {
+	Chat   string `yaml:"chat"`
+	Search string `yaml:"search,omitempty"`
+	Rerank string `yaml:"rerank,omitempty"`
+	Reason string `yaml:"reason,omitempty"`
+}
+
+// UseCase is the full record loaded from one YAML file.
+type UseCase struct {
+	ID          string            `yaml:"id"`
+	Description string            `yaml:"description"`
+	Source      string            `yaml:"source"`
+	AsOf        string            `yaml:"as_of"`
+	Confidence  string            `yaml:"confidence"`
+	Bundles     map[string]Bundle `yaml:"bundles"`
+	Workload    string            `yaml:"workload,omitempty"`
+}
+
+// ValidTiers are the quality tiers every use case must define. Missing a
+// tier, or declaring an empty chat model for one, is fatal at load time:
+// Load returns an error rather than registering a partially-usable record.
+var ValidTiers = []string{"high", "balanced", "cost"}
+
+// Registry is a read-only lookup table of known use cases. Construct via
+// Load; zero values are invalid.
+type Registry struct {
+	mu    sync.RWMutex
+	cases map[string]UseCase
+}
+
+// Load reads every *.yaml file in dir, parses it as a UseCase, and
+// returns a populated Registry. An empty dir (no files) returns an empty
+// registry and no error — use-case routing is opt-in, and running
+// without any use cases is a valid configuration.
+func Load(dir string) (*Registry, error) { + r := &Registry{cases: map[string]UseCase{}} + if dir == "" { + return r, nil + } + entries, err := os.ReadDir(dir) + if err != nil { + if os.IsNotExist(err) { + return r, nil + } + return nil, fmt.Errorf("usecase: read dir %q: %w", dir, err) + } + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !strings.HasSuffix(name, ".yaml") && !strings.HasSuffix(name, ".yml") { + continue + } + path := filepath.Join(dir, name) + uc, err := loadFile(path) + if err != nil { + return nil, err + } + if _, dup := r.cases[uc.ID]; dup { + return nil, fmt.Errorf("usecase: duplicate id %q (in %s)", uc.ID, path) + } + r.cases[uc.ID] = uc + } + return r, nil +} + +func loadFile(path string) (UseCase, error) { + data, err := os.ReadFile(path) + if err != nil { + return UseCase{}, fmt.Errorf("usecase: read %q: %w", path, err) + } + var uc UseCase + dec := yaml.NewDecoder(strings.NewReader(string(data))) + dec.KnownFields(true) + if err := dec.Decode(&uc); err != nil { + return UseCase{}, fmt.Errorf("usecase: parse %q: %w", path, err) + } + if uc.ID == "" { + return UseCase{}, fmt.Errorf("usecase: %q: missing id", path) + } + if len(uc.Bundles) == 0 { + return UseCase{}, fmt.Errorf("usecase: %q: no bundles declared", path) + } + for _, tier := range ValidTiers { + b, ok := uc.Bundles[tier] + if !ok { + return UseCase{}, fmt.Errorf("usecase: %q: missing %q tier", path, tier) + } + if b.Chat == "" { + return UseCase{}, fmt.Errorf("usecase: %q: tier %q has empty chat model", path, tier) + } + } + return uc, nil +} + +// Get returns the full UseCase record for id. Second return is false +// when id is unknown. +func (r *Registry) Get(id string) (UseCase, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + uc, ok := r.cases[id] + return uc, ok +} + +// Bundle returns the bundle for (id, tier). Second return is false when +// the use case is unknown OR the tier isn't defined for that use case. 
+func (r *Registry) Bundle(id, tier string) (Bundle, bool) { + uc, ok := r.Get(id) + if !ok { + return Bundle{}, false + } + b, ok := uc.Bundles[tier] + return b, ok +} + +// IDs returns a sorted list of registered use-case ids. Useful for the +// "unknown use case" 400 response and /v1/bundles index rendering. +func (r *Registry) IDs() []string { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]string, 0, len(r.cases)) + for id := range r.cases { + out = append(out, id) + } + sort.Strings(out) + return out +} + +// Len reports how many use cases are loaded. Zero is a valid state. +func (r *Registry) Len() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.cases) +} diff --git a/internal/usecase/usecase_test.go b/internal/usecase/usecase_test.go new file mode 100644 index 0000000..9f06a2f --- /dev/null +++ b/internal/usecase/usecase_test.go @@ -0,0 +1,163 @@ +package usecase + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoad_StarterSetRegisters(t *testing.T) { + // Resolve the in-tree starter directory so the test proves the shipped + // config loads cleanly. + dir := filepath.Join("..", "..", "config", "use_cases") + if _, err := os.Stat(dir); err != nil { + t.Skipf("starter use_cases dir not present: %v", err) + } + + r, err := Load(dir) + if err != nil { + t.Fatalf("Load: %v", err) + } + + want := []string{"code-dev", "factual-qa", "research-synthesis", "structured-extraction"} + got := r.IDs() + if len(got) != len(want) { + t.Fatalf("expected %d use cases, got %d (%v)", len(want), len(got), got) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("ids[%d]: got %q want %q", i, got[i], want[i]) + } + } + + // Every starter case must declare all three tiers with a chat model. 
+ for _, id := range want { + uc, ok := r.Get(id) + if !ok { + t.Fatalf("expected %q to be registered", id) + } + for _, tier := range ValidTiers { + b, ok := uc.Bundles[tier] + if !ok { + t.Errorf("%s: missing tier %q", id, tier) + continue + } + if b.Chat == "" { + t.Errorf("%s@%s: empty chat model", id, tier) + } + } + } +} + +func TestBundle_UnknownUseCaseReturnsFalse(t *testing.T) { + r, err := Load(filepath.Join("..", "..", "config", "use_cases")) + if err != nil { + t.Skipf("starter dir not loadable: %v", err) + } + if _, ok := r.Bundle("does-not-exist", "balanced"); ok { + t.Fatalf("expected false for unknown use case") + } +} + +func TestBundle_UnknownTierReturnsFalse(t *testing.T) { + r, err := Load(filepath.Join("..", "..", "config", "use_cases")) + if err != nil { + t.Skipf("starter dir not loadable: %v", err) + } + if _, ok := r.Bundle("factual-qa", "premium"); ok { + t.Fatalf("expected false for unknown tier") + } +} + +func TestLoad_EmptyDirNoError(t *testing.T) { + // Running without any use-case configs should produce an empty registry, + // not an error — use-case routing is opt-in. 
+	tmp := t.TempDir()
+	r, err := Load(tmp)
+	if err != nil {
+		t.Fatalf("Load empty dir: %v", err)
+	}
+	if r.Len() != 0 {
+		t.Fatalf("expected empty registry, got %d entries", r.Len())
+	}
+}
+
+func TestLoad_NonexistentDirNoError(t *testing.T) {
+	r, err := Load("/this/does/not/exist/anywhere")
+	if err != nil {
+		t.Fatalf("expected no error for missing dir, got %v", err)
+	}
+	if r.Len() != 0 {
+		t.Fatalf("expected empty registry, got %d entries", r.Len())
+	}
+}
+
+func TestLoad_RejectsMissingTier(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "bad.yaml")
+	if err := os.WriteFile(path, []byte(`
+id: bad
+description: only-balanced-tier
+source: curated
+as_of: "2026-04-21"
+confidence: low
+bundles:
+  balanced:
+    chat: gpt-4o-mini
+`), 0o644); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := Load(dir); err == nil {
+		t.Fatalf("expected error when required tier is missing")
+	}
+}
+
+func TestLoad_RejectsEmptyChatModel(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "bad.yaml")
+	if err := os.WriteFile(path, []byte(`
+id: bad
+description: empty-chat
+source: curated
+as_of: "2026-04-21"
+confidence: low
+bundles:
+  high:
+    chat: ""
+  balanced:
+    chat: gpt-4o-mini
+  cost:
+    chat: gpt-4o-mini
+`), 0o644); err != nil {
+		t.Fatal(err)
+	}
+
+	if _, err := Load(dir); err == nil {
+		t.Fatalf("expected error when chat model is empty at any tier")
+	}
+}
+
+func TestLoad_RejectsDuplicateID(t *testing.T) {
+	dir := t.TempDir()
+	body := []byte(`
+id: dup
+description: d
+source: curated
+as_of: "2026-04-21"
+confidence: low
+bundles:
+  high: { chat: gpt-4o-mini }
+  balanced: { chat: gpt-4o-mini }
+  cost: { chat: gpt-4o-mini }
+`)
+	if err := os.WriteFile(filepath.Join(dir, "a.yaml"), body, 0o644); err != nil {
+		t.Fatal(err)
+	}
+	if err := os.WriteFile(filepath.Join(dir, "b.yaml"), body, 0o644); err != nil {
+		t.Fatal(err)
+	}
+	if _, err := Load(dir); err == nil {
+		t.Fatalf("expected error on duplicate ids across files")
+	}
+}