diff --git a/.env.example b/.env.example index e31ccaa..aabbfe3 100644 --- a/.env.example +++ b/.env.example @@ -1,13 +1,24 @@ -# amp-proxy configuration +# amp-proxy configuration (environment variables) +# These override values from the config file. # Listening LISTEN_PORT=18317 -LISTEN_ADDR=0.0.0.0 +LISTEN_ADDR=127.0.0.1 -# Upstream targets +# Upstream AMPCODE_URL=https://ampcode.com -VIBEPROXY_URL=http://localhost:8317 + +# Config file path (optional, auto-detected if not set) +# AMP_PROXY_CONFIG= + +# Debug logging +# AMP_PROXY_DEBUG=1 # Exa API for web_search and read_web_page (optional) # Get your key at https://dashboard.exa.ai/api-keys EXA_API_KEY= + +# Direct API keys (optional, skip OAuth) +# ANTHROPIC_API_KEY= +# OPENAI_API_KEY= +# GEMINI_API_KEY= diff --git a/.gitignore b/.gitignore index f029fa1..1401b7a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ bin/ dist/ amp-proxy +amp-proxy-v2 *.exe *.dll *.so diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..af612d9 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,351 @@ +# amp-proxy v2: Eliminating vibeproxy — Revised Architecture + +## Flaws Found in v1 Plan & Corrections + +### Flaw 1: Over-engineered config for a local dev tool +The original YAML schema was ~60 lines with sections for thinking, gateway, cliproxy subtrees, OAuth aliases, exclusions, etc. For a tool whose killer feature is "clone, build, run," this is a UX regression. **90% of users need zero config.** + +**Fix:** Config file is optional. Sensible defaults for everything. Flags/env override the few things that vary. Config file only needed for advanced use (model remaps, API keys, multi-account). + +### Flaw 2: Embedding CLIProxyAPIPlus via Builder/Service drags in gin +The `cliproxy.Builder` and `cliproxy.Service` types import `internal/api` which imports gin — even if you never call `Run()`. But `coreauth.Manager` and the executor layer have **zero gin dependency** and can be used standalone. 
+ +**Fix:** Accept the gin dependency: embed CLIProxyAPIPlus via `cliproxy.Builder`/`Service` on a localhost ephemeral port instead of wiring up `coreauth.Manager` and the executors by hand. While this pulls in gin as a transitive dependency, it reuses the full battle-tested auth pipeline (token refresh, retry, credential selection) without reimplementing it. The gin dependency is acceptable since it only adds ~10MB and is never exposed externally. + +### Flaw 3: Default bind address `0.0.0.0` is too open +A local dev tool with auth tokens should not default to all interfaces. + +**Fix:** Default to `127.0.0.1`. + +### Flaw 4: ThinkingProxy features may not be needed +The `-thinking-N` model suffix is a vibeproxy convention, not an Amp CLI feature. We shouldn't implement middleware for conventions we haven't verified in real traffic. + +**Fix:** Cut thinking middleware from scope. Add logging to capture real model names. Implement only if verified. + +### Flaw 5: Vercel gateway bypass is vibeproxy-specific +Niche, high-coupling, hard to explain in UX. + +**Fix:** Cut from scope entirely. + +### Flaw 6: Cold-start problem +If nothing works until you create a config AND complete OAuth, the first-run experience is terrible. + +**Fix:** `amp-proxy` starts and works immediately for non-provider routes (ampcode.com, auth, tools). Provider requests without auth return a clear actionable error: *"Not authenticated. Run `amp-proxy login claude` first."* + +### Flaw 7: Auth from request handlers is dangerous +Opening browsers or printing device codes from inside an HTTP handler causes races (multiple concurrent requests → multiple auth attempts). + +**Fix:** Auth is always an explicit CLI action, never triggered from request serving. + +### Flaw 8: Hot-reload complexity for no benefit +Config hot-reload via fsnotify adds complexity. Restarting a local tool takes <1 second. + +**Fix:** No hot-reload. Read config once on startup. + +--- + +## Revised Architecture + +### Design Principles + +1. 
**Zero config for common use** — clone, build, `amp-proxy login claude`, `amp-proxy serve`, done +2. **Embedded auth server** — `cliproxy.Builder`/`Service` on localhost ephemeral port; gin is a transitive dep but never exposed externally +3. **Fail gracefully** — provider down? auth expired? Non-provider routes keep working, provider routes return actionable errors +4. **Reuse existing tokens** — `~/.cli-proxy-api/` works as-is for users coming from vibeproxy +5. **One binary, one port, one process** + +### Current Flow (before) + +``` +Amp CLI → amp-proxy (:18317) → vibeproxy (:8317) → CLIProxyAPIPlus (:8318) → providers + ↘ ampcode.com +``` + +### New Flow (after) + +``` +Amp CLI → amp-proxy (:18317) + ├─ /auth/* → 302 redirect to ampcode.com + ├─ /api/internal → tool stubs / Exa + ├─ /api/provider/* → provider pipeline → embedded CLIProxyAPIPlus (127.0.0.1:ephemeral) → providers + ├─ /v1/*, /api/v1/* → provider pipeline → embedded CLIProxyAPIPlus (127.0.0.1:ephemeral) → providers + └─ everything else → ampcode.com reverse proxy +``` + +**Embedded CLIProxyAPIPlus on localhost.** A `cliproxy.Service` runs on an ephemeral `127.0.0.1` port inside the process. Provider requests are reverse-proxied to it, reusing the full auth pipeline. + +--- + +## How CLIProxyAPIPlus Is Embedded + +Using `cliproxy.Builder`/`Service` to run the full auth server in-process: + +```go +import ( + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +// Load config, override host/port to localhost ephemeral +cfg, _ := config.LoadConfig(configPath) +cfg.Host = "127.0.0.1" +cfg.Port = ephemeralPort + +// Build the service +service, _ := cliproxy.NewBuilder(). + WithConfig(cfg). + WithConfigPath(configPath). 
+ Build() + +// Run in background goroutine +go service.Run(ctx) + +// Provider requests are reverse-proxied to this address +proxyURL := fmt.Sprintf("http://127.0.0.1:%d", ephemeralPort) +``` + +### What we get for free from the embedded CLIProxyAPIPlus +- OAuth token refresh (background loop) +- Round-robin / fill-first credential selection +- Multi-account failover with cooldown +- Quota tracking per model per credential +- Retry across credentials on 429/401 + +### What we still own in amp-proxy +- HTTP server (`net/http`) +- Request routing (ampcode vs. provider) +- Google ↔ Anthropic/OpenAI protocol translation +- Exa/tool stub interception +- Request body mutations (strip cache_control, strip openai fields) +- Graceful shutdown, logging, metrics + +--- + +## Config: Optional, Minimal, Progressive + +### Zero config (most users) + +```bash +amp-proxy login claude # OAuth flow, tokens saved to ~/.cli-proxy-api/ +amp-proxy serve # starts on 127.0.0.1:18317, auto-detects tokens +``` + +That's it. Model remaps use built-in defaults. Single provider works immediately. + +### Flags/env for common overrides + +```bash +amp-proxy serve --port 9000 --addr 0.0.0.0 +# or +LISTEN_PORT=9000 amp-proxy serve +``` + +### Config file for advanced use (optional) + +Only needed for: custom model remaps, API keys, multi-account, provider toggles. + +**Location:** auto-detected at `amp-proxy/config.yaml` inside the directory returned by `os.UserConfigDir()` (e.g. `~/.config/amp-proxy/config.yaml` on Linux, `~/Library/Application Support/amp-proxy/config.yaml` on macOS) +Override: `AMP_PROXY_CONFIG=/path/to/config.yaml` or `--config` flag. + +```yaml +# Only include what you want to change. Everything has defaults. 
+ +# Optional: override defaults +server: + addr: 127.0.0.1 + port: 18317 + +# Optional: Exa API for web_search tool +exa_api_key: ${EXA_API_KEY} + +# Optional: auth token directory (default: ~/.cli-proxy-api) +auth_dir: ~/.cli-proxy-api + +# Optional: provider routing strategy +routing: round-robin # or "fill-first" + +# Optional: direct API keys (skip OAuth entirely) +api_keys: + anthropic: + - key: ${ANTHROPIC_API_KEY} + openai: + - key: ${OPENAI_API_KEY} + base_url: https://api.openai.com + +# Optional: enable/disable providers +providers: + claude: true # default: true + openai: true # default: true + gemini: false + copilot: false + +# Optional: custom model remapping (overrides built-in defaults) +model_remaps: + - from: gemini-3-flash-preview + to: claude-sonnet-4-6 + provider: anthropic + + - from: gemini-3-flash + to: claude-sonnet-4-6 + provider: anthropic + + - from: gemini-3-pro + to: gpt-5.4 + provider: openai + +# Optional: fallback for unmapped models +fallback_model: + name: claude-sonnet-4-6 + provider: anthropic +``` + +**Key difference from v1 plan:** this config is ~30 lines max, everything is optional, and most users never create it. + +--- + +## CLI Commands + +``` +amp-proxy serve # start the proxy (default if no subcommand) +amp-proxy login <provider> # claude | openai | gemini | copilot | qwen +amp-proxy logout <provider> # remove saved tokens for that provider +amp-proxy status # show auth status for all providers +amp-proxy config init # generate commented config.yaml +amp-proxy config validate # validate config file +amp-proxy version # print version +``` + +### Login flow +`amp-proxy login claude` spawns the CLIProxyAPIPlus auth flow for the specified provider. This opens a browser for OAuth or prints a device code. Tokens are saved to `auth_dir`. The running proxy (if any) picks them up on next request (no hot-reload needed — just re-scan the token dir on demand or at a short interval). 
+ +### Status output +``` +$ amp-proxy status +Provider Status Account Expires +───────────────────────────────────────────────────────── +claude ✓ active user@gmail.com 2025-04-15 +openai ✓ active user@gmail.com 2025-04-12 +gemini ✗ not authed +copilot ✗ not authed +``` + +--- + +## Provider Request Pipeline + +``` +incoming request + │ + ├─ [1] is this a provider request? (/api/provider/*, /v1/*, /api/v1/*) + │ no → ampcode proxy / tool stubs / etc. + │ + ├─ [2] unsupported provider? (not anthropic/openai) + │ yes → protocol translate (Google → Anthropic/OpenAI) + │ no → pass through + │ + ├─ [3] request mutations: + │ ├─ strip cache_control from body (prevents 400 via OAuth) + │ └─ strip unsupported OpenAI fields (stream_options) + │ + ├─ [4] reverse-proxy to embedded CLIProxyAPIPlus + │ ├─ selects credential (round-robin/fill-first) + │ ├─ injects auth headers + │ └─ sends to upstream provider + │ + ├─ [5] response handling: + │ ├─ success → reverse translate if remapped, forward to client + │ ├─ 401/429 → retry with different credential + │ └─ error → clear error message + │ + └─ [6] if no auth available: + → HTTP 503: {"error": "not_authenticated", "message": "Run `amp-proxy login claude` to authenticate"} +``` + +--- + +## Package Structure (Simplified) + +``` +amp-proxy/ +├── main.go # CLI entry, subcommands (serve/login/logout/status) +├── config.go # YAML loader (optional file), env expansion, defaults +├── proxy.go # HTTP handler, routing +├── gateway.go # embedded CLIProxyAPIPlus (Builder/Service/writeConfig) +├── auth.go # login/logout/status CLI, token scanning +├── remap.go # model mapping, Google protocol translation +├── translate_anthropic.go # Google ↔ Anthropic +├── translate_openai.go # Google ↔ OpenAI +├── tool_call_tracker.go # tool result ID deduplication +├── config.example.yaml +├── go.mod +└── go.sum +``` + +**Deliberately flat.** No `internal/gateway/`, no `internal/middleware/`, no deep package tree. 
This is a local dev tool, not a framework. + +--- + +## What's In vs. Out of Scope + +### In scope (v1) +- Embed CLIProxyAPIPlus via `cliproxy.Builder`/`Service` for credential management +- `login` / `logout` / `status` CLI subcommands +- Optional YAML config for model remaps, API keys, provider toggles +- `cache_control` stripping (needed to prevent 400s via OAuth route) +- `stream_options` stripping (already exists) +- Reuse `~/.cli-proxy-api/` token dir +- Keep all existing functionality (ampcode proxy, auth redirect, Exa tools, Google remap) + +### Out of scope (cut) +- ~~ThinkingProxy middleware~~ — verify if needed first +- ~~Vercel gateway bypass~~ — vibeproxy-specific +- ~~Config hot-reload~~ — restart is fast enough +- ~~Management dashboard~~ — not needed for local tool +- ~~Standalone gin HTTP server~~ — gin is present as a transitive dep of CLIProxyAPIPlus but only listens on localhost + +### Future (if proven needed) +- Thinking param injection (only if Amp CLI actually sends `-thinking-N`) +- Config hot-reload (only if users demand it) +- TUI for interactive provider management + +--- + +## Migration Path + +### From vibeproxy + amp-proxy (current setup) + +1. Build new amp-proxy +2. `amp-proxy status` — sees existing tokens in `~/.cli-proxy-api/` automatically +3. `amp-proxy serve` — works immediately, no config needed +4. Stop vibeproxy — no longer needed + +**Zero config migration.** Existing tokens are reused in-place. No copying, no format changes. + +### From scratch (new user) + +1. `go install github.com/johannhipp/amp-proxy@latest` +2. `amp-proxy login claude` — opens browser, OAuth completes +3. `amp-proxy serve` — ready +4. Point Amp CLI: `amp.url = http://127.0.0.1:18317` + +--- + +## Risks & Mitigations (Revised) + +| Risk | Mitigation | +|------|-----------| +| gin pulled in as transitive dep via `cliproxy.Builder` | Acceptable: adds ~10MB, only listens on 127.0.0.1 ephemeral port, never exposed externally. 
| +| Token refresh races during serving | CLIProxyAPIPlus handles this internally with singleflight | +| Provider down takes out whole proxy | Non-provider routes always work. Provider errors return 503 with clear message. | +| CLIProxyAPIPlus API changes (v6→v7) | Pin to specific version. Adapter layer is thin enough to update. | +| Binary size increase | Acceptable tradeoff for auth management. Measure before/after. | +| Users confused by subcommands | `amp-proxy` without args = `amp-proxy serve` (backwards compatible) | + +--- + +## Build Safety (Development) + +- **Worktree:** all dev in `../amp-proxy-v2/` (branch `feat/embed-cliproxy`) +- **Binary:** output as `amp-proxy-v2` (Makefile target), never overwrites `amp-proxy` +- **Ports:** dev/test use `:28317`, never `:18317` or `:8317` +- **Running tmux session:** never touched diff --git a/Makefile b/Makefile index 87cc880..a5fa249 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,17 @@ -.PHONY: build run dev clean test +.PHONY: build run dev clean test fmt vet help + +# Build as amp-proxy-v2 to never overwrite the running binary +BINARY := bin/amp-proxy-v2 +VERSION := $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") build: - go build -o bin/amp-proxy . + go build -ldflags "-X main.version=$(VERSION)" -o $(BINARY) . run: build - ./bin/amp-proxy + $(BINARY) dev: - go run . + go run -ldflags "-X main.version=$(VERSION)" . clean: rm -rf bin/ dist/ @@ -22,10 +26,10 @@ vet: go vet ./... 
help: - @echo "amp-proxy - Conditional HTTP Proxy Server" + @echo "amp-proxy v2 - Smart proxy for Amp CLI with built-in provider auth" @echo "" @echo "Available commands:" - @echo " make build - Build the binary" + @echo " make build - Build the binary (bin/amp-proxy-v2)" @echo " make run - Build and run" @echo " make dev - Run in development mode" @echo " make clean - Clean build artifacts" diff --git a/README.md b/README.md index 657c6e4..3d54a6d 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,25 @@ # amp-proxy -## Why +Use [Amp](https://ampcode.com) with your existing Claude Max, ChatGPT Plus, or Gemini subscriptions — no API billing. -I love [Amp](https://ampcode.com) but just not that big of a fan of API billing, and there's this nice MacOS statusbar app called [vibeproxy](https://github.com/automazeio/vibeproxy) that exposes existing Claude Max, ChatGPT Plus, or Gemini subscriptions via an API. So I built [amp-proxy](https://github.com/johannhipp/amp-proxy) which sits in between. +## Quick Start -[Doesn't vibeproxy already do this?](https://github.com/automazeio/vibeproxy/blob/main/AMPCODE_SETUP.md). Yes, but only partially. It only works for LLM calls of the providers you've authenticated with in vibeproxy, but +```bash +git clone https://github.com/johannhipp/amp-proxy && cd amp-proxy +./bin/install.sh +amp-proxy login claude +amp-proxy +``` + +Then point Amp at the proxy: + +```bash +echo '{"amp.url": "http://localhost:18317"}' > ~/.config/amp/settings.json +``` -1) not for other providers and models that amp will request -2) does not support server-side tools like `web_search` and `read_web_page`, where Amp has its own proprietary implementation -3) does not support OAuth, which is required for auth, GitHub, threads, settings, etc. and a few other amp-native features. See all [changes needed to make this work](#features). +## What It Does + +amp-proxy sits between the Amp CLI and the world. 
It routes LLM calls through your consumer subscriptions (via [CLIProxyAPIPlus](https://github.com/router-for-me/CLIProxyAPIPlus)) while keeping everything else — OAuth, threads, settings, GitHub — on ampcode.com. ``` Amp CLI @@ -16,106 +27,184 @@ Amp CLI ▼ amp-proxy (:18317) │ - ├── /api/provider/* ──► vibeproxy (:8317) LLM calls - ├── /v1/*, /api/v1/* ──► vibeproxy (:8317) LLM streaming - ├── /auth/* ──► ampcode.com OAuth (302 redirect) - └── everything else ──► ampcode.com threads, settings, GitHub, etc. + ├── /api/provider/* ──► built-in provider gateway LLM calls (Claude, GPT, Gemini, ...) + ├── /v1/*, /api/v1/* ──► built-in provider gateway LLM streaming + ├── /auth/* ──► ampcode.com OAuth (302 redirect) + ├── /api/internal ──► Exa API / stubs web_search, read_web_page + └── everything else ──► ampcode.com threads, settings, GitHub ``` +Single binary. Single port. No vibeproxy needed. + ## Features -- **Request routing** -- LLM calls go to vibeproxy, while auth, threads, and settings stay on ampcode.com -- **OAuth redirect** -- auth paths get 302'd to ampcode.com so cookies land on the right domain -- **Model remapping** -- swap unsupported models (Gemini) for ones you have (Claude, GPT), with full request/response translation across Google GenAI, Anthropic Messages, and OpenAI Chat Completions formats, including streaming -- **Stable subagent tool translation** -- fixes repeated same-name tool calls (like Finder's multiple `shell_command` steps) by tracking tool call/result IDs one-to-one, preventing duplicate `tool_result` protocol errors -- **Web search via Exa** -- intercepts Amp's `web_search` and `read_web_page` server-side tools and routes them through the [Exa API](https://exa.ai) instead of ampcode.com's credit-gated backend -- **Credit gate bypass** -- fakes `getUserFreeTierStatus` so the CLI doesn't block server-side tool dispatch -- **Structured logging** -- every request logged with slog (request ID, headers redacted, JSON body previews, route 
decisions, response timing) -- **Graceful shutdown** -- SIGINT/SIGTERM trigger a clean drain of in-flight requests (10s timeout) -- **Health check** -- `/healthz` endpoint for Docker, k8s, or monitoring -- **Request metrics** -- `/metrics` endpoint with JSON counters per route type +- **Built-in provider auth** — `amp-proxy login claude` authenticates with your subscription, no separate app needed +- **Multi-provider support** — Claude, OpenAI, Gemini, GitHub Copilot, Qwen — with round-robin and failover +- **Model remapping** — unsupported models (Gemini) get translated to ones you have (Claude, GPT), with full request/response protocol translation across Google GenAI, Anthropic Messages, and OpenAI Chat Completions formats, including streaming +- **Configurable** — model mappings, providers, API keys, routing strategy — all via optional YAML config +- **Web search via Exa** — intercepts Amp's `web_search` and `read_web_page` tools and routes them through the [Exa API](https://exa.ai) +- **Credit gate bypass** — fakes `getUserFreeTierStatus` so server-side tools aren't blocked +- **Stable tool translation** — fixes repeated same-name tool calls with one-to-one ID tracking +- **Zero config for common use** — works out of the box after `login`, config file only needed for advanced use -## Setup +## Install + +Requires [Go 1.21+](https://go.dev/dl/). Works on macOS and Linux (amd64/arm64). ```bash -# build -make build +git clone https://github.com/johannhipp/amp-proxy +cd amp-proxy +./bin/install.sh +``` -# run (defaults: listen on :18317, vibeproxy on :8317) -./bin/amp-proxy +This builds amp-proxy and downloads `cli-proxy-api-plus` (needed for auth flows) into `~/.local/bin`. 
-# with Exa web search enabled -EXA_API_KEY=your-key ./bin/amp-proxy +Override the install directory: -# with custom targets -./bin/amp-proxy --port 18317 --vibeproxy http://localhost:8317 --ampcode https://ampcode.com +```bash +AMP_PROXY_INSTALL_DIR=/usr/local/bin ./bin/install.sh ``` -Point Amp at the proxy: +## Auth + +Authenticate with one or more providers: ```bash -echo '{"amp.url": "http://localhost:18317"}' > ~/.config/amp/settings.json +amp-proxy login claude # Claude Max / Pro +amp-proxy login openai # ChatGPT Plus / Pro +amp-proxy login gemini # Gemini +amp-proxy login copilot # GitHub Copilot +amp-proxy login qwen # Qwen ``` -You'll also need the Amp API key registered for this URL. Copy your existing key: +Check status: ```bash -# check your existing keys -cat ~/.local/share/amp/secrets.json - -# add an entry for the proxy URL (same key, different URL) +amp-proxy status ``` -See `.env.example` for all configuration options. +``` +Provider Status Account Expires +──────── ────── ─────── ─────── +claude ✓ active user@gmail.com 2025-04-15 +openai ✓ active user@gmail.com 2025-04-12 +gemini ✗ not authed - - +``` -## Model remapping +Tokens are stored in `~/.cli-proxy-api/` — compatible with vibeproxy, so existing tokens carry over. -If you don't have a subscription for every provider Amp uses, amp-proxy can swap unsupported models for ones you do have access to. When Amp requests a Gemini model (and you don't have Google OAuth in vibeproxy), the proxy translates the request to a supported provider automatically. 
+## Usage -Default mappings in `remap.go`: +```bash +# Start with defaults (127.0.0.1:18317) +amp-proxy -| Amp requests | Gets served by | Provider | -|---|---|---| -| gemini-3-flash-preview | claude-sonnet-4-6 | Anthropic | -| gemini-3-flash | claude-sonnet-4-6 | Anthropic | -| gemini-3-pro | gpt-5.4 | OpenAI | -| gemini-3-pro-image | gpt-image-1 | OpenAI | -| anything else unsupported | claude-sonnet-4-6 | Anthropic | +# Custom port +amp-proxy serve --port 9000 -The translation handles request/response format conversion between Google GenAI, Anthropic Messages, and OpenAI Chat Completions APIs. Streaming works too. +# With Exa web search +EXA_API_KEY=your-key amp-proxy -To change the mappings, edit the `modelMappings` slice in `remap.go`. Unmapped models fall back to Sonnet 4.6 with a warning in the logs so you know what to add. +# Debug logging +amp-proxy serve --debug +``` -## Server-side tool interception +## Configuration -Amp executes `web_search` and `read_web_page` server-side on ampcode.com, gated by a credit check. If your account has no credits, both tools fail. amp-proxy intercepts the `/api/internal` RPC calls and routes them to the Exa API instead. +Config is **optional**. amp-proxy works with zero config for most users. An optional YAML file unlocks advanced features. -Three intercepts, in order: +Auto-detected at `~/.config/amp-proxy/config.yaml` (macOS: `~/Library/Application Support/amp-proxy/config.yaml`), or set `AMP_PROXY_CONFIG`. -1. **`getUserFreeTierStatus` fake** -- The CLI polls this every 30s. ampcode.com returns `canUseAmpFree: false` when credits are exhausted, blocking tool dispatch. The proxy returns `canUseAmpFree: true` to unblock the client-side gate. +```yaml +# Only include what you want to change. Everything has defaults. -2. **`webSearch2` → Exa `/search`** -- `web_search` tool calls become `POST /api/internal?webSearch2` with `{"method":"webSearch2","params":{"objective":"...","maxResults":N}}`. 
The proxy translates this to an Exa search request and returns results in the schema the CLI expects (`result.results[]` with `title`, `url`, `text`). +# Exa API for web_search tool +exa_api_key: ${EXA_API_KEY} -3. **`extractWebPageContent` → Exa `/contents`** -- `read_web_page` calls become `POST /api/internal?extractWebPageContent` with `{"method":"extractWebPageContent","params":{"url":"...","objective":"..."}}`. The proxy fetches the page via Exa and returns content as `result.excerpts[]`. +# Provider routing: "round-robin" (default) or "fill-first" +routing: fill-first -Set `EXA_API_KEY` to enable. Without it, both tools return a stub message instead of failing. +# Custom model remapping +model_remaps: + - from: gemini-3-flash-preview + to: claude-sonnet-4-6 + provider: anthropic -## Logging + - from: gemini-3-pro + to: gpt-5.4 + provider: openai -Structured logs via `slog`. Every request includes request ID, method, path, route decision, and response timing. +# Direct API keys (skip OAuth) +anthropic_api_keys: + - key: ${ANTHROPIC_API_KEY} +# Enable/disable providers +providers: + claude: true + openai: true + gemini: false ``` -INFO request reqID=14 method=POST path=/api/provider/anthropic/v1/messages -INFO route reqID=14 label=VIBEPROXY method=POST path=/api/provider/anthropic/v1/messages target=http://localhost:8317 rule=provider-to-vibeproxy -INFO response reqID=14 label=VIBEPROXY status=200 statusText=OK bytes=8450 elapsed=15622ms + +See [`config.example.yaml`](config.example.yaml) for all options. + +## Prompt Caching + +By default, amp-proxy strips `cache_control` fields from request bodies before forwarding to providers. This prevents 400 errors on some OAuth routes, but **disables Anthropic prompt caching** — which can significantly increase token usage in long sessions. 
+ +To enable prompt caching, set `strip_cache_control: false` in your config: + +```yaml +strip_cache_control: false ``` +If you see 400 errors from Anthropic after turning stripping off, set `strip_cache_control: true` again (the default). + +## Model Remapping + +When Amp requests a model you don't have (e.g., Gemini), amp-proxy translates the request to a provider you do have. This includes full protocol translation — request body, response body, and streaming — across Google GenAI, Anthropic Messages, and OpenAI Chat Completions formats. + +Default mappings (configurable via YAML): + +| Amp requests | Served by | Provider | +|---|---|---| +| gemini-3-flash-preview | claude-sonnet-4-6 | Anthropic | +| gemini-3-flash | claude-sonnet-4-6 | Anthropic | +| gemini-3-pro | gpt-5.4 | OpenAI | +| gemini-3-pro-image | gpt-image-1 | OpenAI | +| anything else unsupported | claude-sonnet-4-6 | Anthropic | + +## Web Search + +Amp's `web_search` and `read_web_page` are server-side tools on ampcode.com, gated by credits. amp-proxy intercepts them and routes through the [Exa API](https://exa.ai) instead. + +Set `EXA_API_KEY` to enable. Without it, tools return a stub message. 
+ ## Endpoints | Path | Method | Description | |------|--------|-------------| -| `/healthz` | GET | Health check, returns `{"status":"ok"}` | -| `/metrics` | GET | Request counters (total, vibeproxy, ampcode, remap, exa, errors) | +| `/healthz` | GET | Health check → `{"status":"ok"}` | +| `/metrics` | GET | Request counters (JSON) | + +## CLI Reference + +``` +amp-proxy [command] [flags] + +Commands: + serve Start the proxy server (default) + login Authenticate with a provider + logout Remove saved auth for a provider + status Show auth status for all providers + version Print version info + +Flags (serve): + --port Listen port (default: 18317, env: LISTEN_PORT) + --addr Listen address (default: 127.0.0.1, env: LISTEN_ADDR) + --config Config file path (env: AMP_PROXY_CONFIG) + --ampcode Ampcode URL (default: https://ampcode.com) + --debug Enable debug logging (env: AMP_PROXY_DEBUG) +``` ## Development @@ -126,26 +215,20 @@ make fmt # format code make vet # run go vet ``` -## Releases +## Migrating from vibeproxy -Releases are fully automated via [`.github/workflows/release.yml`](.github/workflows/release.yml) and [`.goreleaser.yml`](.goreleaser.yml). +If you're already using vibeproxy + the old amp-proxy: -- Push a tag like `v0.1.0` to build and publish a GitHub Release. -- Run the workflow manually (`workflow_dispatch`) to validate a snapshot release without publishing artifacts. +1. Install the new amp-proxy: `./bin/install.sh` +2. `amp-proxy status` — your existing tokens in `~/.cli-proxy-api/` are picked up automatically +3. `amp-proxy` — start the proxy, vibeproxy is no longer needed +4. Stop vibeproxy -Tag-based release: +## Releases + +Automated via [`.github/workflows/release.yml`](.github/workflows/release.yml) and [`.goreleaser.yml`](.goreleaser.yml). 
```bash -git tag v0.1.0 -git push origin v0.1.0 +git tag v0.2.0 +git push origin v0.2.0 ``` - -## Dependency Updates - -Renovate is configured in [`.github/renovate.json`](.github/renovate.json) for this repository. - -- Go module updates are grouped and run `go mod tidy` after updates. -- GitHub Actions updates are grouped to reduce PR noise. -- Major updates are labeled separately for easier review. - -To enable it, install the Renovate GitHub App for this repository. diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..bcc2165 --- /dev/null +++ b/auth.go @@ -0,0 +1,250 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "text/tabwriter" + "time" +) + +type providerDef struct { + name string + loginFlag string + tokenType string +} + +var providerDefs = []providerDef{ + {"claude", "-claude-login", "claude"}, + {"openai", "-codex-login", "codex"}, + {"gemini", "-login", "gemini-cli"}, + {"copilot", "-github-copilot-login", "github-copilot"}, + {"qwen", "-qwen-login", "qwen"}, + {"antigravity", "-antigravity-login", "antigravity"}, +} + +var providerLoginFlags = map[string]string{} +var providerTokenTypes = map[string]string{} + +func init() { + for _, p := range providerDefs { + providerLoginFlags[p.name] = p.loginFlag + providerTokenTypes[p.name] = p.tokenType + } + // Aliases + providerLoginFlags["codex"] = providerLoginFlags["openai"] + providerTokenTypes["codex"] = providerTokenTypes["openai"] +} + +// tokenFile represents a parsed auth token file +type tokenFile struct { + Type string `json:"type"` + Email string `json:"email"` + Disabled bool `json:"disabled"` + Expired string `json:"expired"` + Filename string `json:"-"` +} + +// RunLogin runs the OAuth login flow for a provider +func RunLogin(provider, authDir string) error { + flag, ok := providerLoginFlags[provider] + if !ok { + return fmt.Errorf("unknown provider %q\nSupported: claude, openai, gemini, copilot, qwen, antigravity", 
provider) + } + + binary := findCLIProxyBinary() + if binary == "" { + return fmt.Errorf("cli-proxy-api-plus binary not found\n\nPlease install it:\n go install github.com/router-for-me/CLIProxyAPI/v6@latest\n\nOr download from: https://github.com/router-for-me/CLIProxyAPIPlus/releases") + } + + authDir = ExpandHome(authDir) + if err := os.MkdirAll(authDir, 0700); err != nil { + return fmt.Errorf("create auth dir: %w", err) + } + + fmt.Printf("Logging in to %s...\n", provider) + fmt.Printf("Using binary: %s\n", binary) + + // Build a minimal config for the login command + configPath, err := writeMinimalConfig(authDir) + if err != nil { + return fmt.Errorf("write temp config: %w", err) + } + defer os.Remove(configPath) + + cmd := exec.Command(binary, flag, "-config", configPath) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("login command failed: %w", err) + } + + fmt.Printf("\n✓ Login complete. Tokens saved to %s\n", authDir) + return nil +} + +// RunLogout removes auth tokens for a provider +func RunLogout(provider, authDir string) error { + tokenType, ok := providerTokenTypes[provider] + if !ok { + return fmt.Errorf("unknown provider %q", provider) + } + + authDir = ExpandHome(authDir) + tokens, err := scanTokenFiles(authDir) + if err != nil { + return fmt.Errorf("scan tokens: %w", err) + } + + removed := 0 + for _, t := range tokens { + if t.Type == tokenType { + path := filepath.Join(authDir, t.Filename) + if err := os.Remove(path); err != nil { + fmt.Fprintf(os.Stderr, "warning: failed to remove %s: %v\n", t.Filename, err) + } else { + removed++ + } + } + } + + if removed == 0 { + fmt.Printf("No tokens found for %s\n", provider) + } else { + fmt.Printf("✓ Removed %d token(s) for %s\n", removed, provider) + } + return nil +} + +// RunStatus displays auth status for all providers +func RunStatus(authDir string) { + authDir = ExpandHome(authDir) + tokens, err := 
scanTokenFiles(authDir) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to scan auth directory %s: %v\n", authDir, err) + fmt.Fprintf(os.Stderr, "Run 'amp-proxy login ' to authenticate.\n") + return + } + + // Group by provider type + byType := make(map[string][]tokenFile) + for _, t := range tokens { + byType[t.Type] = append(byType[t.Type], t) + } + + fmt.Printf("Auth directory: %s\n\n", authDir) + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "Provider\tStatus\tAccount\tExpires") + fmt.Fprintln(w, "────────\t──────\t───────\t───────") + + for _, p := range providerDefs { + tokens, ok := byType[p.tokenType] + if !ok || len(tokens) == 0 { + fmt.Fprintf(w, "%s\t✗ not authed\t-\t-\n", p.name) + continue + } + for _, t := range tokens { + status := "✓ active" + if t.Disabled { + status = "⊘ disabled" + } + if t.Expired != "" { + if exp, err := time.Parse(time.RFC3339, t.Expired); err == nil { + if exp.Before(time.Now()) { + status = "✗ expired" + } + } + } + email := t.Email + if email == "" { + email = "-" + } + expires := "-" + if t.Expired != "" { + if exp, err := time.Parse(time.RFC3339, t.Expired); err == nil { + expires = exp.Format("2006-01-02") + } + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", p.name, status, email, expires) + } + } + + w.Flush() +} + +// scanTokenFiles reads all JSON token files in the auth directory +func scanTokenFiles(authDir string) ([]tokenFile, error) { + entries, err := os.ReadDir(authDir) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + var tokens []tokenFile + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".json") { + continue + } + data, err := os.ReadFile(filepath.Join(authDir, e.Name())) + if err != nil { + continue + } + var t tokenFile + if err := json.Unmarshal(data, &t); err != nil { + continue + } + t.Filename = e.Name() + if t.Type != "" { + tokens = append(tokens, t) + } + } + return tokens, nil +} + +// findCLIProxyBinary 
// findCLIProxyBinary locates the cli-proxy-api-plus executable.
//
// It first consults PATH via exec.LookPath, then probes a few well-known
// install locations (~/go/bin, ~/.local/bin, /usr/local/bin). It returns the
// path of the first match, or "" when the binary cannot be found.
func findCLIProxyBinary() string {
	const name = "cli-proxy-api-plus"

	if p, err := exec.LookPath(name); err == nil {
		return p
	}

	var candidates []string
	// Only probe home-relative locations when the home directory is known.
	// Joining onto an empty home would yield paths relative to the current
	// working directory, and we must not pick up (and later execute) a
	// binary from wherever the user happens to be.
	if home, err := os.UserHomeDir(); err == nil && home != "" {
		candidates = append(candidates,
			filepath.Join(home, "go", "bin", name),
			filepath.Join(home, ".local", "bin", name),
		)
	}
	candidates = append(candidates, "/usr/local/bin/"+name)

	for _, p := range candidates {
		// Require a regular file (after following symlinks) so that a
		// directory with the binary's name is not mistaken for it.
		if info, err := os.Stat(p); err == nil && info.Mode().IsRegular() {
			return p
		}
	}

	return ""
}
+ +INSTALL_DIR="${AMP_PROXY_INSTALL_DIR:-$HOME/.local/bin}" +CLIPROXY_VERSION="${CLIPROXY_VERSION:-latest}" + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[0;33m' +BLUE='\033[0;34m' +BOLD='\033[1m' +RESET='\033[0m' + +info() { echo -e "${BLUE}==>${RESET} ${BOLD}$*${RESET}"; } +ok() { echo -e "${GREEN} ✓${RESET} $*"; } +warn() { echo -e "${YELLOW} !${RESET} $*"; } +fail() { echo -e "${RED} ✗${RESET} $*"; exit 1; } + +# --- Detect platform --- + +OS="$(uname -s | tr '[:upper:]' '[:lower:]')" +ARCH="$(uname -m)" + +case "$OS" in + darwin) ;; + linux) ;; + *) fail "Unsupported OS: $OS (macOS and Linux only)" ;; +esac + +case "$ARCH" in + x86_64) ARCH="amd64" ;; + aarch64) ARCH="arm64" ;; + arm64) ARCH="arm64" ;; + *) fail "Unsupported architecture: $ARCH" ;; +esac + +info "Platform: ${OS}/${ARCH}" + +# --- Check Go --- + +if ! command -v go &>/dev/null; then + fail "Go is required but not found. Install it: https://go.dev/dl/" +fi + +GO_VERSION="$(go version | grep -oE 'go[0-9]+\.[0-9]+' | head -1)" +info "Go: ${GO_VERSION}" + +# --- Build amp-proxy --- + +info "Building amp-proxy..." + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" + +if [ ! -f "$REPO_DIR/go.mod" ]; then + fail "Run this script from the amp-proxy repo (expected go.mod at $REPO_DIR/go.mod)" +fi + +VERSION="$(cd "$REPO_DIR" && git describe --tags --always --dirty 2>/dev/null || echo 'dev')" + +(cd "$REPO_DIR" && go build -ldflags "-s -w -X main.version=${VERSION}" -o "$REPO_DIR/bin/amp-proxy" .) +ok "Built amp-proxy ${VERSION}" + +# --- Download cli-proxy-api-plus --- + +info "Downloading cli-proxy-api-plus (needed for 'amp-proxy login')..." 
+ +if [ "$CLIPROXY_VERSION" = "latest" ]; then + CLIPROXY_TAG="$(curl -sL https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest | grep '"tag_name"' | head -1 | cut -d'"' -f4)" +else + CLIPROXY_TAG="$CLIPROXY_VERSION" +fi + +if [ -z "$CLIPROXY_TAG" ]; then + fail "Could not determine CLIProxyAPIPlus release version" +fi + +# Tag format: v6.9.4-2 → version 6.9.4-2 +CLIPROXY_VER="${CLIPROXY_TAG#v}" +TARBALL="CLIProxyAPIPlus_${CLIPROXY_VER}_${OS}_${ARCH}.tar.gz" +DOWNLOAD_URL="https://github.com/router-for-me/CLIProxyAPIPlus/releases/download/${CLIPROXY_TAG}/${TARBALL}" + +TMPDIR="$(mktemp -d)" +trap 'rm -rf "$TMPDIR"' EXIT + +if ! curl -sL --fail -o "$TMPDIR/$TARBALL" "$DOWNLOAD_URL"; then + fail "Download failed: $DOWNLOAD_URL" +fi + +tar -xzf "$TMPDIR/$TARBALL" -C "$TMPDIR" cli-proxy-api-plus +ok "Downloaded cli-proxy-api-plus ${CLIPROXY_TAG}" + +# --- Install --- + +info "Installing to ${INSTALL_DIR}..." +mkdir -p "$INSTALL_DIR" + +cp "$REPO_DIR/bin/amp-proxy" "$INSTALL_DIR/amp-proxy" +chmod +x "$INSTALL_DIR/amp-proxy" +ok "Installed amp-proxy" + +cp "$TMPDIR/cli-proxy-api-plus" "$INSTALL_DIR/cli-proxy-api-plus" +chmod +x "$INSTALL_DIR/cli-proxy-api-plus" +ok "Installed cli-proxy-api-plus" + +# --- Check PATH --- + +if ! echo "$PATH" | tr ':' '\n' | grep -qx "$INSTALL_DIR"; then + warn "${INSTALL_DIR} is not in your PATH" + echo "" + echo " Add it to your shell profile:" + echo "" + if [ -f "$HOME/.zshrc" ]; then + echo " echo 'export PATH=\"${INSTALL_DIR}:\$PATH\"' >> ~/.zshrc && source ~/.zshrc" + else + echo " echo 'export PATH=\"${INSTALL_DIR}:\$PATH\"' >> ~/.bashrc && source ~/.bashrc" + fi + echo "" +fi + +# --- Done --- + +echo "" +echo -e "${GREEN}${BOLD}Installation complete!${RESET}" +echo "" +echo " Next steps:" +echo "" +echo " 1. Authenticate with your provider(s):" +echo "" +echo -e " ${BOLD}amp-proxy login claude${RESET} # Claude Max / Pro" +echo -e " ${BOLD}amp-proxy login openai${RESET} # ChatGPT Plus / Pro" +echo "" +echo " 2. 
Start the proxy:" +echo "" +echo -e " ${BOLD}amp-proxy${RESET}" +echo "" +echo " 3. Point Amp at the proxy:" +echo "" +echo -e " ${BOLD}echo '{\"amp.url\": \"http://localhost:18317\"}' > ~/.config/amp/settings.json${RESET}" +echo "" +echo " Optional: enable web search (get a key at https://exa.ai):" +echo "" +echo -e " ${BOLD}EXA_API_KEY=your-key amp-proxy${RESET}" +echo "" +echo " Check auth status anytime:" +echo "" +echo -e " ${BOLD}amp-proxy status${RESET}" +echo "" diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..49a9d1e --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,81 @@ +# amp-proxy configuration +# This file is OPTIONAL. amp-proxy works with zero config for most users. +# Only include settings you want to change from defaults. + +# Server settings +# port: 18317 # default: 18317 +# addr: 127.0.0.1 # default: 127.0.0.1 (localhost only) +# ampcode_url: https://ampcode.com + +# Exa API for web_search and read_web_page tools (optional) +# Get your key at https://dashboard.exa.ai/api-keys +# exa_api_key: ${EXA_API_KEY} + +# Auth token directory (default: ~/.cli-proxy-api) +# This is where OAuth tokens are stored. Compatible with vibeproxy. +# auth_dir: ~/.cli-proxy-api + +# Provider routing strategy: "round-robin" (default) or "fill-first" +# routing: round-robin + +# Provider request retry count +# request_retry: 3 + +# Request timeout in minutes (for long streaming responses) +# request_timeout_mins: 10 + +# Enable/disable providers (all enabled by default) +# providers: +# claude: true +# openai: true +# gemini: false +# copilot: false + +# Model remapping (Google GenAI → supported providers) +# When Amp requests a Gemini model, amp-proxy translates the request +# to the target provider's wire format (full protocol translation). 
+# model_remaps: +# - from: gemini-3-flash-preview +# to: claude-sonnet-4-6 +# provider: anthropic +# +# - from: gemini-3-flash +# to: claude-sonnet-4-6 +# provider: anthropic +# +# - from: gemini-3-pro +# to: gpt-5.4 +# provider: openai +# +# - from: gemini-3-pro-image +# to: gpt-image-1 +# provider: openai + +# Fallback for unmapped models +# fallback: +# to: claude-sonnet-4-6 +# provider: anthropic + +# Strip cache_control fields from request bodies (default: true) +# Prevents 400 errors on some OAuth routes, but disables Anthropic prompt caching. +# Set to false to enable prompt caching (reduces token usage in long sessions). +# strip_cache_control: true + +# Direct API keys (skip OAuth, use your own API keys) +# anthropic_api_keys: +# - key: ${ANTHROPIC_API_KEY} +# priority: 10 +# +# openai_api_keys: +# - key: ${OPENAI_API_KEY} +# base_url: https://api.openai.com +# +# gemini_api_keys: +# - key: ${GEMINI_API_KEY} + +# OpenAI-compatible providers (e.g., OpenRouter, local models) +# openai_compatible: +# - name: openrouter +# base_url: https://openrouter.ai/api/v1 +# keys: +# - key: ${OPENROUTER_API_KEY} diff --git a/config.go b/config.go index 355100f..fc057a1 100644 --- a/config.go +++ b/config.go @@ -1,60 +1,221 @@ package main import ( - "net/http" + "fmt" + "os" + "path/filepath" + "regexp" "strings" + + "gopkg.in/yaml.v3" ) -// Rule represents a proxy rule with matching criteria -type Rule struct { - Name string - Match func(*http.Request) bool // Matcher function - Target string // Target URL, empty = reject - Priority int // Higher priority rules evaluated first -} - -// Config holds proxy server configuration -type Config struct { - ListenPort int - ListenAddr string - DefaultTarget string // Where non-matched requests go (ampcode.com) - VibeProxyTarget string // Where LLM provider requests go (vibeproxy) - ExaAPIKey string // Exa API key for web search/page reading (optional) - Rules []Rule -} - -// hasPathPrefix checks if the request path matches a 
prefix. -func hasPathPrefix(path, prefix string) bool { - return path == prefix || strings.HasPrefix(path, prefix+"/") -} - -// NewDefaultConfig returns default configuration. -// Only LLM provider/streaming paths go to vibeproxy. -// Everything else (auth, threads, settings, GitHub, etc.) -// goes directly to ampcode.com. -func NewDefaultConfig() *Config { - const ( - ampcode = "https://ampcode.com" - vibeproxy = "http://localhost:8317" - ) - - return &Config{ - ListenPort: 18317, - ListenAddr: "0.0.0.0", - DefaultTarget: ampcode, - VibeProxyTarget: vibeproxy, - Rules: []Rule{ - // LLM provider requests -> vibeproxy - { - Name: "provider-to-vibeproxy", - Match: func(r *http.Request) bool { - return hasPathPrefix(r.URL.Path, "/api/provider") || - hasPathPrefix(r.URL.Path, "/v1") || - hasPathPrefix(r.URL.Path, "/api/v1") - }, - Target: vibeproxy, - Priority: 100, - }, +// AppConfig holds all amp-proxy configuration +type AppConfig struct { + // Server settings + ListenPort int `yaml:"port"` + ListenAddr string `yaml:"addr"` + AmpcodeURL string `yaml:"ampcode_url"` + + // Exa API for web_search and read_web_page + ExaAPIKey string `yaml:"exa_api_key"` + + // Auth token directory + AuthDir string `yaml:"auth_dir"` + + // Provider routing + Routing string `yaml:"routing"` // "round-robin" or "fill-first" + + // Provider enable/disable + Providers map[string]bool `yaml:"providers"` + + // Model remapping (Google GenAI → supported providers) + ModelRemaps []ModelRemapConfig `yaml:"model_remaps"` + + // Fallback for unmapped models + Fallback *ModelRemapConfig `yaml:"fallback"` + + // CLIProxyAPIPlus-specific config overrides + RequestRetry int `yaml:"request_retry"` + RequestTimeoutMins int `yaml:"request_timeout_mins"` + MaxRetryCredentials int `yaml:"max_retry_credentials"` + + // Direct API keys (skip OAuth) + AnthropicAPIKeys []APIKeyConfig `yaml:"anthropic_api_keys"` + OpenAIAPIKeys []APIKeyConfig `yaml:"openai_api_keys"` + GeminiAPIKeys []APIKeyConfig 
`yaml:"gemini_api_keys"` + + // OpenAI-compatible providers + OpenAICompatible []OpenAICompatConfig `yaml:"openai_compatible"` + + // Strip cache_control fields from request bodies before forwarding. + // Prevents 400 errors on some OAuth routes, but disables Anthropic prompt caching. + // Set to false to preserve cache_control and enable prompt caching. + StripCacheControl *bool `yaml:"strip_cache_control"` +} + +// ModelRemapConfig defines how to remap an unsupported model +type ModelRemapConfig struct { + From string `yaml:"from"` + To string `yaml:"to"` + Provider string `yaml:"provider"` // "anthropic" or "openai" +} + +// APIKeyConfig holds a direct API key configuration +type APIKeyConfig struct { + Key string `yaml:"key"` + BaseURL string `yaml:"base_url"` + Priority int `yaml:"priority"` +} + +// OpenAICompatConfig holds an OpenAI-compatible provider config +type OpenAICompatConfig struct { + Name string `yaml:"name"` + BaseURL string `yaml:"base_url"` + Keys []APIKeyConfig `yaml:"keys"` +} + +// DefaultConfig returns configuration with sensible defaults +func DefaultConfig() *AppConfig { + return &AppConfig{ + ListenPort: 18317, + ListenAddr: "127.0.0.1", + AmpcodeURL: "https://ampcode.com", + AuthDir: "~/.cli-proxy-api", + Routing: "round-robin", + Providers: map[string]bool{ + "claude": true, + "openai": true, + "gemini": true, + "copilot": true, }, + ModelRemaps: []ModelRemapConfig{ + {From: "gemini-3-flash-preview", To: "claude-sonnet-4-6", Provider: "anthropic"}, + {From: "gemini-3-flash", To: "claude-sonnet-4-6", Provider: "anthropic"}, + {From: "gemini-3-pro", To: "gpt-5.4", Provider: "openai"}, + {From: "gemini-3-pro-image", To: "gpt-image-1", Provider: "openai"}, + }, + Fallback: &ModelRemapConfig{ + To: "claude-sonnet-4-6", + Provider: "anthropic", + }, + RequestRetry: 3, + RequestTimeoutMins: 10, + StripCacheControl: boolPtr(true), + } +} + +// LoadConfig loads configuration from the given path, auto-detects, or returns defaults +func 
// ExpandHome replaces a leading "~" path component with the current user's
// home directory.
//
// Both the bare form "~" and the prefixed form "~/rest" are expanded. Any
// other input, including "~user/..." forms and paths that merely contain a
// tilde, is returned unchanged, as is the input when the home directory
// cannot be determined.
func ExpandHome(path string) string {
	if path != "~" && !strings.HasPrefix(path, "~/") {
		return path
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return path
	}
	// path[1:] is "" or starts with "/"; Join cleans either form.
	return filepath.Join(home, path[1:])
}
+ // NOTE: There is a TOCTOU race between closing this listener and + // CLIProxyAPIPlus re-binding to the same port — another process could + // claim it in between. The CLIProxyAPIPlus SDK does not support + // accepting a pre-opened listener, so we accept this minor risk. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, fmt.Errorf("find free port: %w", err) + } + gw.port = listener.Addr().(*net.TCPAddr).Port + listener.Close() + + gw.targetURL = fmt.Sprintf("http://127.0.0.1:%d", gw.port) + + // Write a CLIProxyAPIPlus config file + configPath, err := gw.writeConfig(appCfg) + if err != nil { + return nil, fmt.Errorf("write cliproxy config: %w", err) + } + gw.configPath = configPath + + // Load the config through CLIProxyAPIPlus's loader + cfg, err := config.LoadConfig(configPath) + if err != nil { + return nil, fmt.Errorf("load cliproxy config: %w", err) + } + + // Override port to our ephemeral port + cfg.Port = gw.port + cfg.Host = "127.0.0.1" + + // Build the service + svc, err := cliproxy.NewBuilder(). + WithConfig(cfg). + WithConfigPath(configPath). 
+ Build() + if err != nil { + return nil, fmt.Errorf("build cliproxy service: %w", err) + } + gw.service = svc + + // Set up shared transport for both the reverse proxy and direct client + transport := &http.Transport{ + ResponseHeaderTimeout: 10 * time.Minute, + IdleConnTimeout: 90 * time.Second, + MaxIdleConnsPerHost: 20, + } + + // Reverse proxy for normal provider requests + target, _ := url.Parse(gw.targetURL) + gw.proxy = httputil.NewSingleHostReverseProxy(target) + gw.proxy.Transport = transport + gw.proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + slog.Error("provider gateway error", "path", r.URL.Path, "error", err) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadGateway) + fmt.Fprintf(w, `{"error":"provider_unavailable","message":"Provider backend error: %s"}`, err.Error()) + } + + // Direct client for remap path — no total timeout so SSE streams aren't killed + gw.client = &http.Client{Transport: transport} + + return gw, nil +} + +// Start starts the embedded CLIProxyAPIPlus service in a background goroutine +func (gw *ProviderGateway) Start(ctx context.Context) error { + errCh := make(chan error, 1) + readyCh := make(chan bool, 1) + + go func() { + slog.Info("provider gateway starting", "port", gw.port) + if err := gw.service.Run(ctx); err != nil && ctx.Err() == nil { + errCh <- err + } + close(errCh) + }() + + go func() { + readyCh <- gw.waitForReady(ctx, 15*time.Second) + }() + + select { + case err := <-errCh: + if err != nil { + return fmt.Errorf("provider gateway failed to start: %w", err) + } + return fmt.Errorf("provider gateway exited unexpectedly") + case ready := <-readyCh: + if !ready { + return fmt.Errorf("provider gateway did not become ready within timeout") + } + } + + gw.mu.Lock() + gw.ready = true + gw.mu.Unlock() + + slog.Info("provider gateway ready", "url", gw.targetURL) + return nil +} + +// waitForReady polls the internal server until it responds +func (gw 
*ProviderGateway) waitForReady(ctx context.Context, timeout time.Duration) bool { + deadline := time.After(timeout) + client := &http.Client{Timeout: 500 * time.Millisecond} + + for { + select { + case <-ctx.Done(): + return false + case <-deadline: + return false + default: + resp, err := client.Get(gw.targetURL + "/v1/models") + if err == nil { + resp.Body.Close() + if resp.StatusCode < 500 { + return true + } + } + time.Sleep(100 * time.Millisecond) + } + } +} + +// Do sends a request through the gateway's shared transport (for remap path). +// Unlike ph.httpClient, this has no total timeout so SSE streams work correctly. +func (gw *ProviderGateway) Do(req *http.Request) (*http.Response, error) { + return gw.client.Do(req) +} + +// IsReady returns whether the gateway is ready to accept requests +func (gw *ProviderGateway) IsReady() bool { + gw.mu.RLock() + defer gw.mu.RUnlock() + return gw.ready +} + +// ServeHTTP forwards the request to the embedded CLIProxyAPIPlus +func (gw *ProviderGateway) ServeHTTP(w http.ResponseWriter, r *http.Request) { + gw.mu.RLock() + ready := gw.ready + gw.mu.RUnlock() + + if !ready { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprint(w, `{"error":"provider_not_ready","message":"Provider gateway is starting up. 
Please retry."}`) + return + } + + gw.proxy.ServeHTTP(w, r) +} + +// Shutdown gracefully stops the embedded service and cleans up +func (gw *ProviderGateway) Shutdown(ctx context.Context) error { + if gw.configPath != "" { + if err := os.Remove(gw.configPath); err != nil && !os.IsNotExist(err) { + slog.Warn("failed to remove temp cliproxy config", "path", gw.configPath, "error", err) + } + } + if gw.service != nil { + return gw.service.Shutdown(ctx) + } + return nil +} + +type cliproxyConfig struct { + Port int `yaml:"port"` + Host string `yaml:"host"` + AuthDir string `yaml:"auth-dir"` + Debug bool `yaml:"debug"` + UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled"` + RequestRetry int `yaml:"request-retry"` + RequestTimeout string `yaml:"request-timeout,omitempty"` + MaxRetryCredentials int `yaml:"max-retry-credentials,omitempty"` + Routing *cliproxyRouting `yaml:"routing,omitempty"` + RemoteManagement cliproxyRemoteMgmt `yaml:"remote-management"` + Ampcode cliproxyAmpcode `yaml:"ampcode"` + QuotaExceeded cliproxyQuotaExceeded `yaml:"quota-exceeded"` + ClaudeAPIKey []cliproxyAPIKey `yaml:"claude-api-key,omitempty"` + CodexAPIKey []cliproxyAPIKey `yaml:"codex-api-key,omitempty"` + GeminiAPIKey []cliproxyAPIKey `yaml:"gemini-api-key,omitempty"` + OpenAICompatibility []cliproxyOpenAICompat `yaml:"openai-compatibility,omitempty"` + OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty"` +} + +type cliproxyRouting struct { + Strategy string `yaml:"strategy"` +} + +type cliproxyRemoteMgmt struct { + AllowRemote bool `yaml:"allow-remote"` + DisableControlPanel bool `yaml:"disable-control-panel"` +} + +type cliproxyAmpcode struct { + UpstreamURL string `yaml:"upstream-url"` + RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost"` +} + +type cliproxyQuotaExceeded struct { + SwitchProject bool `yaml:"switch-project"` + SwitchPreviewModel bool `yaml:"switch-preview-model"` +} + +type cliproxyAPIKey struct { + APIKey 
string `yaml:"api-key"` + BaseURL string `yaml:"base-url,omitempty"` + Priority int `yaml:"priority,omitempty"` +} + +type cliproxyOpenAICompat struct { + Name string `yaml:"name"` + BaseURL string `yaml:"base-url"` + APIKeyEntries []cliproxyAPIKey `yaml:"api-key-entries,omitempty"` +} + +// writeConfig generates a CLIProxyAPIPlus YAML config file from AppConfig +func (gw *ProviderGateway) writeConfig(appCfg *AppConfig) (string, error) { + authDir := ExpandHome(appCfg.AuthDir) + if err := os.MkdirAll(authDir, 0700); err != nil { + return "", fmt.Errorf("create auth dir: %w", err) + } + + cfg := cliproxyConfig{ + Port: gw.port, + Host: "127.0.0.1", + AuthDir: authDir, + RequestRetry: appCfg.RequestRetry, + RemoteManagement: cliproxyRemoteMgmt{ + DisableControlPanel: true, + }, + Ampcode: cliproxyAmpcode{ + UpstreamURL: appCfg.AmpcodeURL, + RestrictManagementToLocalhost: true, + }, + QuotaExceeded: cliproxyQuotaExceeded{ + SwitchProject: true, + SwitchPreviewModel: true, + }, + } + + if appCfg.RequestTimeoutMins > 0 { + cfg.RequestTimeout = fmt.Sprintf("%dm", appCfg.RequestTimeoutMins) + } + if appCfg.MaxRetryCredentials > 0 { + cfg.MaxRetryCredentials = appCfg.MaxRetryCredentials + } + if appCfg.Routing != "" { + cfg.Routing = &cliproxyRouting{Strategy: appCfg.Routing} + } + + for _, k := range appCfg.AnthropicAPIKeys { + cfg.ClaudeAPIKey = append(cfg.ClaudeAPIKey, cliproxyAPIKey{ + APIKey: k.Key, Priority: k.Priority, + }) + } + for _, k := range appCfg.OpenAIAPIKeys { + cfg.CodexAPIKey = append(cfg.CodexAPIKey, cliproxyAPIKey{ + APIKey: k.Key, BaseURL: k.BaseURL, Priority: k.Priority, + }) + } + for _, k := range appCfg.GeminiAPIKeys { + cfg.GeminiAPIKey = append(cfg.GeminiAPIKey, cliproxyAPIKey{ + APIKey: k.Key, Priority: k.Priority, + }) + } + + for _, compat := range appCfg.OpenAICompatible { + c := cliproxyOpenAICompat{Name: compat.Name, BaseURL: compat.BaseURL} + for _, k := range compat.Keys { + c.APIKeyEntries = append(c.APIKeyEntries, 
cliproxyAPIKey{APIKey: k.Key}) + } + cfg.OpenAICompatibility = append(cfg.OpenAICompatibility, c) + } + + excluded := make(map[string][]string) + for provider, enabled := range appCfg.Providers { + if !enabled { + if tokenType, ok := providerTokenTypes[provider]; ok { + excluded[tokenType] = []string{"*"} + } + } + } + if len(excluded) > 0 { + cfg.OAuthExcludedModels = excluded + } + + data, err := yaml.Marshal(&cfg) + if err != nil { + return "", fmt.Errorf("marshal cliproxy config: %w", err) + } + + // Write to config dir (survives /tmp cleanup, cleaned up on shutdown) + configDir, err := os.UserConfigDir() + if err != nil { + configDir = authDir + } else { + configDir = filepath.Join(configDir, "amp-proxy") + } + if err := os.MkdirAll(configDir, 0700); err != nil { + return "", fmt.Errorf("create config dir: %w", err) + } + configPath := filepath.Join(configDir, "cliproxy-internal.yaml") + if err := os.WriteFile(configPath, data, 0600); err != nil { + return "", err + } + + slog.Debug("wrote cliproxy config", "path", configPath) + return configPath, nil +} diff --git a/go.mod b/go.mod index 593d941..0f5169f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,52 @@ module github.com/johannhipp/amp-proxy -go 1.21 +go 1.26.0 + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect + github.com/andybalholm/brotli v1.0.6 // indirect + github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/dlclark/regexp2 v1.11.5 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/gin-gonic/gin v1.10.1 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.20.0 // indirect + 
github.com/goccy/go-json v0.10.2 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.17.4 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/refraction-networking/utls v1.8.2 // indirect + github.com/router-for-me/CLIProxyAPI/v6 v6.9.3 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/tiktoken-go/tokenizer v0.7.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/net v0.47.0 // indirect + golang.org/x/oauth2 v0.30.0 // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect + google.golang.org/protobuf v1.34.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6341a01 --- /dev/null +++ b/go.sum @@ -0,0 +1,118 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= 
+github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= +github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ= +github.com/gin-gonic/gin v1.10.1/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales 
v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/mattn/go-isatty 
v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= +github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= +github.com/router-for-me/CLIProxyAPI/v6 v6.9.3 h1:uZSwAZ5+HwG4C9ibLKycG4wQn9PXV2mfv/ApUPFoECM= +github.com/router-for-me/CLIProxyAPI/v6 v6.9.3/go.mod h1:7z31hDX1LapETY/Riaxei+UI8/iHuWL0Mrgokp5rqYs= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= +github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod 
h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw= +github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.12 
h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod 
h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/main.go b/main.go index 543ba81..a3eaa22 100644 --- a/main.go +++ b/main.go @@ -3,66 +3,218 @@ package main import ( "context" "flag" + "fmt" "log/slog" "os" "os/signal" + "path/filepath" "strconv" "syscall" ) +var version = "dev" + func main() { - logLevel := &slog.LevelVar{} - if os.Getenv("AMP_PROXY_DEBUG") != "" { - logLevel.Set(slog.LevelDebug) + args := os.Args[1:] + cmd := "serve" + if len(args) > 0 && args[0] != "" && args[0][0] != '-' { + cmd = args[0] + args = args[1:] } - slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel}))) - config := NewDefaultConfig() + switch cmd { + case "serve": + cmdServe(args) + case "login": + cmdLogin(args) + case "logout": + cmdLogout(args) + case "status": + cmdStatus(args) + case "version": + fmt.Printf("amp-proxy %s\n", version) + case "help", "-h", "--help": + printUsage() + default: + fmt.Fprintf(os.Stderr, "unknown command: %s\n\n", cmd) + printUsage() + os.Exit(1) + } +} + +func loadConfigFromArgs(name string, args []string) (*AppConfig, error) { + fs := flag.NewFlagSet(name, flag.ExitOnError) + configPath := fs.String("config", "", "Config file path") + fs.Parse(args) + if v := os.Getenv("AMP_PROXY_CONFIG"); v != "" && *configPath == "" { + *configPath = v + } + return LoadConfig(*configPath) +} - // Parse flags - listenPort := flag.Int("port", config.ListenPort, "Port to listen on") - listenAddr := flag.String("addr", config.ListenAddr, "Address to listen on") - ampcodeURL := flag.String("ampcode", 
config.DefaultTarget, "Ampcode URL (auth, threads, etc.)") - vibeproxyURL := flag.String("vibeproxy", config.VibeProxyTarget, "VibeProxy URL (LLM provider requests)") - flag.Parse() +func cmdServe(args []string) { + fs := flag.NewFlagSet("serve", flag.ExitOnError) + port := fs.Int("port", 18317, "Port to listen on") + addr := fs.String("addr", "127.0.0.1", "Address to listen on") + configPath := fs.String("config", "", "Config file path (default: auto-detect)") + ampcodeURL := fs.String("ampcode", "https://ampcode.com", "Ampcode URL") + debug := fs.Bool("debug", false, "Enable debug logging") + fs.Parse(args) - // Override from environment variables + // Env overrides if v := os.Getenv("LISTEN_PORT"); v != "" { if p, err := strconv.Atoi(v); err == nil { - *listenPort = p + *port = p } } if v := os.Getenv("LISTEN_ADDR"); v != "" { - *listenAddr = v + *addr = v } if v := os.Getenv("AMPCODE_URL"); v != "" { *ampcodeURL = v } - if v := os.Getenv("VIBEPROXY_URL"); v != "" { - *vibeproxyURL = v + if os.Getenv("AMP_PROXY_DEBUG") != "" { + *debug = true + } + if v := os.Getenv("AMP_PROXY_CONFIG"); v != "" && *configPath == "" { + *configPath = v } - config.ExaAPIKey = os.Getenv("EXA_API_KEY") + // Set up logging + logLevel := &slog.LevelVar{} + if *debug { + logLevel.Set(slog.LevelDebug) + } + slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel}))) - config.ListenPort = *listenPort - config.ListenAddr = *listenAddr - config.DefaultTarget = *ampcodeURL - config.VibeProxyTarget = *vibeproxyURL + // Load config + cfg, err := LoadConfig(*configPath) + if err != nil { + slog.Error("failed to load config", "error", err) + os.Exit(1) + } - // Update rule targets to use configured URLs - for i := range config.Rules { - switch config.Rules[i].Name { - case "provider-to-vibeproxy": - config.Rules[i].Target = config.VibeProxyTarget - } + // CLI flags override config + if isFlagSet(fs, "port") || os.Getenv("LISTEN_PORT") != "" { + 
cfg.ListenPort = *port + } + if isFlagSet(fs, "addr") || os.Getenv("LISTEN_ADDR") != "" { + cfg.ListenAddr = *addr + } + if isFlagSet(fs, "ampcode") || os.Getenv("AMPCODE_URL") != "" { + cfg.AmpcodeURL = *ampcodeURL + } + if v := os.Getenv("EXA_API_KEY"); v != "" { + cfg.ExaAPIKey = v } ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - handler := NewProxyHandler(config) + handler := NewProxyHandler(cfg) if err := handler.Start(ctx); err != nil { slog.Error("failed to start proxy", "error", err) os.Exit(1) } } + +func cmdLogin(args []string) { + if len(args) == 0 { + fmt.Fprintf(os.Stderr, "usage: amp-proxy login \n") + fmt.Fprintf(os.Stderr, "providers: claude, openai, gemini, copilot, qwen, antigravity\n") + os.Exit(1) + } + provider := args[0] + + cfg, err := loadConfigFromArgs("login", args[1:]) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err) + os.Exit(1) + } + + if err := RunLogin(provider, cfg.AuthDir); err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } +} + +func cmdLogout(args []string) { + if len(args) == 0 { + fmt.Fprintf(os.Stderr, "usage: amp-proxy logout \n") + os.Exit(1) + } + provider := args[0] + + cfg, err := loadConfigFromArgs("logout", args[1:]) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err) + os.Exit(1) + } + + if err := RunLogout(provider, cfg.AuthDir); err != nil { + fmt.Fprintf(os.Stderr, "logout failed: %v\n", err) + os.Exit(1) + } +} + +func cmdStatus(args []string) { + cfg, err := loadConfigFromArgs("status", args) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err) + os.Exit(1) + } + + RunStatus(cfg.AuthDir) +} + +func printUsage() { + fmt.Fprintf(os.Stderr, `amp-proxy - Smart proxy for Amp CLI with built-in provider auth + +Usage: + amp-proxy [command] [flags] + +Commands: + serve Start the proxy server (default) + login Authenticate with a provider (claude, openai, 
gemini, copilot, qwen) + logout Remove saved auth for a provider + status Show auth status for all providers + version Print version info + help Show this help + +Serve Flags: + --port Listen port (default: 18317, env: LISTEN_PORT) + --addr Listen address (default: 127.0.0.1, env: LISTEN_ADDR) + --config Config file path (env: AMP_PROXY_CONFIG) + --ampcode Ampcode URL (default: https://ampcode.com, env: AMPCODE_URL) + --debug Enable debug logging (env: AMP_PROXY_DEBUG) + +Config File: + Optional. Auto-detected at: + %s/config.yaml + Override with --config or AMP_PROXY_CONFIG env var. + +Examples: + amp-proxy Start with defaults + amp-proxy login claude Authenticate with Claude + amp-proxy serve --port 9000 Start on custom port + amp-proxy status Show provider auth status +`, defaultConfigDir()) +} + +func defaultConfigDir() string { + dir, err := os.UserConfigDir() + if err != nil { + return "~/.config" + } + return filepath.Join(dir, "amp-proxy") +} + +func isFlagSet(fs *flag.FlagSet, name string) bool { + found := false + fs.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} diff --git a/proxy.go b/proxy.go index 8938999..d03d1e4 100644 --- a/proxy.go +++ b/proxy.go @@ -12,7 +12,6 @@ import ( "net/http" "net/http/httputil" "net/url" - "sort" "strings" "sync/atomic" "time" @@ -20,13 +19,14 @@ import ( // ProxyHandler handles HTTP request routing type ProxyHandler struct { - config *Config - proxies map[string]*httputil.ReverseProxy + config *AppConfig + gateway *ProviderGateway + ampProxy *httputil.ReverseProxy httpClient *http.Client requestID atomic.Uint64 metrics struct { totalRequests atomic.Uint64 - vibeproxyReqs atomic.Uint64 + providerReqs atomic.Uint64 ampcodeReqs atomic.Uint64 remapReqs atomic.Uint64 exaReqs atomic.Uint64 @@ -35,58 +35,40 @@ type ProxyHandler struct { } // NewProxyHandler creates a new proxy handler -func NewProxyHandler(config *Config) *ProxyHandler { +func NewProxyHandler(config *AppConfig) 
*ProxyHandler { handler := &ProxyHandler{ - config: config, - proxies: make(map[string]*httputil.ReverseProxy), + config: config, httpClient: &http.Client{ - Timeout: 5 * time.Minute, // Long timeout for streaming LLM responses + Timeout: 5 * time.Minute, }, } - // Collect all unique targets - targets := make(map[string]bool) - targets[config.DefaultTarget] = true - targets[config.VibeProxyTarget] = true - for _, rule := range config.Rules { - if rule.Target != "" { - targets[rule.Target] = true - } + // Set up ampcode reverse proxy + ampcodeURL, err := url.Parse(config.AmpcodeURL) + if err != nil { + slog.Error("invalid ampcode URL", "url", config.AmpcodeURL, "error", err) + ampcodeURL, _ = url.Parse("https://ampcode.com") } - for target := range targets { - targetURL, err := url.Parse(target) - if err != nil { - slog.Error("invalid target URL", "target", target, "error", err) - continue - } - proxy := httputil.NewSingleHostReverseProxy(targetURL) - - // For HTTPS targets, set the Host header to the target host - // so TLS and virtual hosting work correctly. 
- originalDirector := proxy.Director - proxy.Director = func(req *http.Request) { - originalDirector(req) - req.Host = targetURL.Host - } - - // Use a transport that supports HTTPS - proxy.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{}, - ResponseHeaderTimeout: 5 * time.Minute, // Don't wait forever for upstream - IdleConnTimeout: 90 * time.Second, - MaxIdleConnsPerHost: 10, - } - - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - slog.Error("proxy error", "method", r.Method, "path", r.URL.Path, "target", target, "error", err) - handler.metrics.errors.Add(1) - w.WriteHeader(http.StatusBadGateway) - fmt.Fprintf(w, `{"error":"proxy_error","message":"%s"}`, err.Error()) - } - - handler.proxies[target] = proxy + ampProxy := httputil.NewSingleHostReverseProxy(ampcodeURL) + originalDirector := ampProxy.Director + ampProxy.Director = func(req *http.Request) { + originalDirector(req) + req.Host = ampcodeURL.Host + } + ampProxy.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{}, + ResponseHeaderTimeout: 5 * time.Minute, + IdleConnTimeout: 90 * time.Second, + MaxIdleConnsPerHost: 10, } + ampProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + slog.Error("ampcode proxy error", "method", r.Method, "path", r.URL.Path, "error", err) + handler.metrics.errors.Add(1) + w.WriteHeader(http.StatusBadGateway) + fmt.Fprintf(w, `{"error":"proxy_error","message":"%s"}`, err.Error()) + } + handler.ampProxy = ampProxy return handler } @@ -96,7 +78,7 @@ type responseRecorder struct { http.ResponseWriter statusCode int bytes int64 - errBody []byte // capture body for non-2xx responses + errBody []byte } func (rr *responseRecorder) WriteHeader(code int) { @@ -124,7 +106,7 @@ func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { reqID := ph.requestID.Add(1) start := time.Now() - // Health check — don't log + // Health check if r.URL.Path == "/healthz" { w.Header().Set("Content-Type", 
"application/json") w.WriteHeader(http.StatusOK) @@ -132,13 +114,13 @@ func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Metrics — don't log + // Metrics if r.URL.Path == "/metrics" { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]uint64{ "total_requests": ph.metrics.totalRequests.Load(), - "vibeproxy_requests": ph.metrics.vibeproxyReqs.Load(), + "provider_requests": ph.metrics.providerReqs.Load(), "ampcode_requests": ph.metrics.ampcodeReqs.Load(), "remap_requests": ph.metrics.remapReqs.Load(), "exa_requests": ph.metrics.exaReqs.Load(), @@ -148,24 +130,18 @@ func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } ph.metrics.totalRequests.Add(1) - - // Log full request details ph.logRequest(reqID, r) - // Browser auth paths: redirect to ampcode.com instead of reverse-proxying. - // OAuth flows set state/nonce cookies on the origin domain. Reverse-proxying - // keeps the browser on localhost:18317 but ampcode's callback goes to - // ampcode.com, causing a domain mismatch. A 302 redirect sends the browser - // to ampcode.com directly so the entire OAuth round-trip stays on one domain. 
+ // Auth redirects — send browser directly to ampcode.com if hasPathPrefix(r.URL.Path, "/auth") || hasPathPrefix(r.URL.Path, "/api/auth") { - redirectURL := ph.config.DefaultTarget + r.URL.RequestURI() + redirectURL := ph.config.AmpcodeURL + r.URL.RequestURI() slog.Info("redirect", "reqID", reqID, "from", r.URL.RequestURI(), "to", redirectURL) ph.metrics.ampcodeReqs.Add(1) http.Redirect(w, r, redirectURL, http.StatusFound) return } - // Fake free-tier status so server-side tools (web_search, read_web_page) aren't blocked + // Fake free-tier status if r.URL.Path == "/api/internal" && r.URL.RawQuery == "getUserFreeTierStatus" { slog.Info("fake getUserFreeTierStatus", "reqID", reqID) w.Header().Set("Content-Type", "application/json") @@ -174,77 +150,158 @@ func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Stub server-side tool calls (webSearch2, extractWebPageContent) + // Tool stubs if r.URL.Path == "/api/internal" && (r.URL.RawQuery == "webSearch2" || r.URL.RawQuery == "extractWebPageContent") { ph.handleToolStub(w, r, reqID) return } - // Model remapping: intercept unsupported provider requests and translate them + // Model remapping: intercept unsupported provider requests and translate if provider, ok := isUnsupportedProviderRequest(r); ok { if model, streaming, isGoogle := parseGoogleProviderRequest(r); isGoogle { - mapping, isExplicit := findMapping(model) - slog.Info("remap", "reqID", reqID, "from", provider+"/"+model, "to", mapping.TargetProvider+"/"+mapping.TargetModel) + remap, isExplicit := ph.config.FindModelRemap(model) + slog.Info("remap", "reqID", reqID, "from", provider+"/"+model, "to", remap.Provider+"/"+remap.To) ph.metrics.remapReqs.Add(1) - ph.handleRemappedRequest(w, r, reqID, model, streaming, mapping, isExplicit) + ph.handleRemappedRequest(w, r, reqID, model, streaming, remap, isExplicit) return } - // Non-Google unsupported provider — log warning, fall through to default - slog.Warn("unsupported 
provider, forwarding to vibeproxy", "reqID", reqID, "provider", provider) + slog.Warn("unsupported provider, forwarding to gateway", "reqID", reqID, "provider", provider) } - target := ph.findTarget(r) + // Provider requests → embedded CLIProxyAPIPlus + if isProviderRequest(r) { + if ph.gateway == nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprint(w, `{"error":"provider_not_ready","message":"Provider gateway is starting up. Please retry."}`) + return + } + ph.metrics.providerReqs.Add(1) - if target == "" { - w.WriteHeader(http.StatusForbidden) - fmt.Fprintf(w, `{"error": "forbidden", "message": "request blocked by proxy rules"}`) - slog.Warn("blocked", "reqID", reqID, "method", r.Method, "path", r.URL.Path) - ph.metrics.errors.Add(1) - return - } + // Strip unsupported OpenAI fields before forwarding + if strings.Contains(r.URL.Path, "/openai/") { + if err := stripOpenAIUnsupportedFields(r); err != nil { + slog.Warn("failed to strip openai fields", "reqID", reqID, "error", err) + } + } - proxy, ok := ph.proxies[target] - if !ok { - w.WriteHeader(http.StatusBadGateway) - fmt.Fprintf(w, `{"error": "bad_gateway", "message": "target not available"}`) - slog.Error("no proxy configured", "reqID", reqID, "target", target) - ph.metrics.errors.Add(1) - return - } + // Strip cache_control from request bodies (prevents 400 via OAuth route, + // but disables Anthropic prompt caching — set strip_cache_control: false to preserve) + if ph.config.StripCacheControl == nil || *ph.config.StripCacheControl { + if r.Method == "POST" && r.Body != nil { + stripCacheControl(r) + } + } - label := "AMPCODE" - if target == ph.config.VibeProxyTarget { - label = "VIBEPROXY" - ph.metrics.vibeproxyReqs.Add(1) - } else { - ph.metrics.ampcodeReqs.Add(1) - } + slog.Info("route", "reqID", reqID, "label", "PROVIDER", "method", r.Method, "path", r.URL.Path) - // Find which rule matched - ruleName := ph.findMatchedRule(r) - 
slog.Info("route", "reqID", reqID, "label", label, "method", r.Method, "path", r.URL.Path, "target", target, "rule", ruleName) + // Wrap response writer to fix tool name casing in LLM responses. + // Amp's agent modes check tool names case-sensitively (e.g. "Bash" not "bash"), + // but some models return lowercase names. + rewriter := newToolNameRewriter(w) + rec := &responseRecorder{ResponseWriter: rewriter, statusCode: 200} + ph.gateway.ServeHTTP(rec, r) - // Strip unsupported fields from OpenAI requests before forwarding to vibeproxy - if label == "VIBEPROXY" && strings.Contains(r.URL.Path, "/openai/") { - if err := stripOpenAIUnsupportedFields(r); err != nil { - slog.Warn("failed to strip openai fields", "reqID", reqID, "error", err) + elapsed := time.Since(start) + logAttrs := []any{"reqID", reqID, "label", "PROVIDER", "status", rec.statusCode, "statusText", http.StatusText(rec.statusCode), "bytes", rec.bytes, "elapsed", elapsed.Round(time.Millisecond)} + if len(rec.errBody) > 0 { + logAttrs = append(logAttrs, "body", string(rec.errBody)) } + slog.Info("response", logAttrs...) + return } - // Wrap response writer to capture status/size + // Everything else → ampcode.com + ph.metrics.ampcodeReqs.Add(1) + slog.Info("route", "reqID", reqID, "label", "AMPCODE", "method", r.Method, "path", r.URL.Path) + rec := &responseRecorder{ResponseWriter: w, statusCode: 200} - proxy.ServeHTTP(rec, r) + ph.ampProxy.ServeHTTP(rec, r) elapsed := time.Since(start) - logAttrs := []any{"reqID", reqID, "label", label, "status", rec.statusCode, "statusText", http.StatusText(rec.statusCode), "bytes", rec.bytes, "elapsed", elapsed.Round(time.Millisecond)} + logAttrs := []any{"reqID", reqID, "label", "AMPCODE", "status", rec.statusCode, "statusText", http.StatusText(rec.statusCode), "bytes", rec.bytes, "elapsed", elapsed.Round(time.Millisecond)} if len(rec.errBody) > 0 { logAttrs = append(logAttrs, "body", string(rec.errBody)) } slog.Info("response", logAttrs...) 
} -// stripOpenAIUnsupportedFields removes fields from OpenAI request bodies that -// vibeproxy doesn't support (e.g. stream_options). +// isProviderRequest checks if this is a provider/LLM request +func isProviderRequest(r *http.Request) bool { + return hasPathPrefix(r.URL.Path, "/api/provider") || + hasPathPrefix(r.URL.Path, "/v1") || + hasPathPrefix(r.URL.Path, "/api/v1") +} + +// hasPathPrefix checks if the request path matches a prefix +func hasPathPrefix(path, prefix string) bool { + return path == prefix || strings.HasPrefix(path, prefix+"/") +} + +// stripCacheControl removes cache_control keys from request JSON bodies +// This prevents 400 errors when requests go through the OAuth route +func stripCacheControl(r *http.Request) { + if r.Body == nil { + return + } + ct := r.Header.Get("Content-Type") + if !strings.Contains(ct, "json") { + return + } + + body, err := io.ReadAll(r.Body) + r.Body.Close() + if err != nil { + r.Body = io.NopCloser(bytes.NewReader(body)) + return + } + + var m map[string]interface{} + if err := json.Unmarshal(body, &m); err != nil { + r.Body = io.NopCloser(bytes.NewReader(body)) + return + } + + changed := removeCacheControl(m) + if !changed { + r.Body = io.NopCloser(bytes.NewReader(body)) + return + } + + modified, err := json.Marshal(m) + if err != nil { + r.Body = io.NopCloser(bytes.NewReader(body)) + return + } + r.Body = io.NopCloser(bytes.NewReader(modified)) + r.ContentLength = int64(len(modified)) +} + +// removeCacheControl recursively removes cache_control keys from a JSON structure +func removeCacheControl(v interface{}) bool { + changed := false + switch val := v.(type) { + case map[string]interface{}: + if _, ok := val["cache_control"]; ok { + delete(val, "cache_control") + changed = true + } + for _, child := range val { + if removeCacheControl(child) { + changed = true + } + } + case []interface{}: + for _, item := range val { + if removeCacheControl(item) { + changed = true + } + } + } + return changed +} + +// 
stripOpenAIUnsupportedFields removes unsupported fields from OpenAI request bodies func stripOpenAIUnsupportedFields(r *http.Request) error { body, err := io.ReadAll(r.Body) r.Body.Close() @@ -253,7 +310,6 @@ func stripOpenAIUnsupportedFields(r *http.Request) error { } var m map[string]any if err := json.Unmarshal(body, &m); err != nil { - // Not JSON, put it back unchanged r.Body = io.NopCloser(bytes.NewReader(body)) return nil } @@ -278,7 +334,7 @@ func stripOpenAIUnsupportedFields(r *http.Request) error { return nil } -// handleToolStub intercepts server-side tool calls and routes them to Exa or returns stubs +// handleToolStub intercepts server-side tool calls func (ph *ProxyHandler) handleToolStub(w http.ResponseWriter, r *http.Request, reqID uint64) { method := r.URL.RawQuery @@ -326,9 +382,7 @@ func (ph *ProxyHandler) handleExaSearch(w http.ResponseWriter, reqID uint64, obj "type": "auto", "num_results": maxResults, "contents": map[string]any{ - "text": map[string]any{ - "max_characters": 10000, - }, + "text": map[string]any{"max_characters": 10000}, }, } @@ -340,7 +394,6 @@ func (ph *ProxyHandler) handleExaSearch(w http.ResponseWriter, reqID uint64, obj return } - // Format results for the CLI's expected schema md := ph.formatSearchResults(result) slog.Info("exa search returned", "reqID", reqID, "resultCount", len(md)) ph.writeToolResult(w, map[string]any{ @@ -354,9 +407,7 @@ func (ph *ProxyHandler) handleExaContents(w http.ResponseWriter, reqID uint64, p exaReq := map[string]any{ "urls": []string{pageURL}, - "text": map[string]any{ - "max_characters": 20000, - }, + "text": map[string]any{"max_characters": 20000}, } result, err := ph.callExa("/contents", exaReq) @@ -369,9 +420,7 @@ func (ph *ProxyHandler) handleExaContents(w http.ResponseWriter, reqID uint64, p md := ph.formatContentsResults(result) slog.Info("exa contents returned", "reqID", reqID, "chars", len(md)) - ph.writeToolResult(w, map[string]any{ - "excerpts": []string{md}, - }) + 
ph.writeToolResult(w, map[string]any{"excerpts": []string{md}}) } func (ph *ProxyHandler) callExa(endpoint string, payload map[string]any) (map[string]any, error) { @@ -424,16 +473,10 @@ func (ph *ProxyHandler) formatSearchResults(result map[string]any) []map[string] title, _ := item["title"].(string) url, _ := item["url"].(string) text, _ := item["text"].(string) - if len(text) > 2000 { text = text[:2000] + "..." } - - out = append(out, map[string]string{ - "title": title, - "url": url, - "text": text, - }) + out = append(out, map[string]string{"title": title, "url": url, "text": text}) } return out } @@ -471,34 +514,25 @@ func (ph *ProxyHandler) formatContentsResults(result map[string]any) string { func (ph *ProxyHandler) writeToolResult(w http.ResponseWriter, result any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]any{ - "ok": true, - "result": result, - }) + json.NewEncoder(w).Encode(map[string]any{"ok": true, "result": result}) } func (ph *ProxyHandler) writeToolError(w http.ResponseWriter, code string, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]any{ - "ok": false, - "error": map[string]string{"code": code, "message": message}, - }) + json.NewEncoder(w).Encode(map[string]any{"ok": false, "error": map[string]string{"code": code, "message": message}}) } // logRequest logs detailed request information func (ph *ProxyHandler) logRequest(reqID uint64, r *http.Request) { slog.Info("request", "reqID", reqID, "method", r.Method, "path", r.URL.Path+formatQuery(r.URL.RawQuery)) - // Log all headers for name, values := range r.Header { for _, v := range values { - // Truncate long values (e.g. auth tokens) but show enough to identify them display := v if len(display) > 120 { display = display[:60] + "..." 
+ display[len(display)-20:] } - // Mask authorization tokens but show the scheme lower := strings.ToLower(name) if lower == "authorization" || lower == "cookie" { if idx := strings.Index(display, " "); idx > 0 { @@ -512,14 +546,10 @@ func (ph *ProxyHandler) logRequest(reqID uint64, r *http.Request) { } } - // Log content length / transfer encoding if r.ContentLength > 0 { slog.Debug("body", "reqID", reqID, "bytes", r.ContentLength, "contentType", r.Header.Get("Content-Type")) - } else if r.ContentLength == -1 && r.Body != nil { - slog.Debug("body", "reqID", reqID, "transfer", "chunked/unknown", "contentType", r.Header.Get("Content-Type")) } - // Log body preview for JSON requests (useful for seeing model names, etc.) if r.Body != nil && r.ContentLength > 0 && r.ContentLength <= 64*1024 && strings.Contains(r.Header.Get("Content-Type"), "json") { body, err := io.ReadAll(r.Body) @@ -529,71 +559,47 @@ func (ph *ProxyHandler) logRequest(reqID uint64, r *http.Request) { preview = preview[:500] + fmt.Sprintf("... 
(%d bytes total)", len(body)) } slog.Debug("json body", "reqID", reqID, "preview", preview) - // Replace the body so the proxy can still read it r.Body = io.NopCloser(strings.NewReader(string(body))) } } } -// findMatchedRule returns the name of the rule that matched, or "default" -func (ph *ProxyHandler) findMatchedRule(r *http.Request) string { - rules := make([]Rule, len(ph.config.Rules)) - copy(rules, ph.config.Rules) - sort.Slice(rules, func(i, j int) bool { - return rules[i].Priority > rules[j].Priority - }) - for _, rule := range rules { - if rule.Match(r) { - return rule.Name - } +// Start starts the proxy server +func (ph *ProxyHandler) Start(ctx context.Context) error { + // Start the embedded provider gateway + gw, err := NewProviderGateway(ph.config) + if err != nil { + return fmt.Errorf("create provider gateway: %w", err) } - return "default" -} - -// findTarget determines the target URL for a request -func (ph *ProxyHandler) findTarget(r *http.Request) string { - rules := make([]Rule, len(ph.config.Rules)) - copy(rules, ph.config.Rules) - sort.Slice(rules, func(i, j int) bool { - return rules[i].Priority > rules[j].Priority - }) + ph.gateway = gw - for _, rule := range rules { - if rule.Match(r) { - return rule.Target - } + if err := gw.Start(ctx); err != nil { + return fmt.Errorf("start provider gateway: %w", err) } - return ph.config.DefaultTarget -} - -// Start starts the proxy server and blocks until ctx is cancelled, then shuts down gracefully. 
-func (ph *ProxyHandler) Start(ctx context.Context) error { addr := net.JoinHostPort(ph.config.ListenAddr, fmt.Sprintf("%d", ph.config.ListenPort)) slog.Info("amp-proxy starting", "listen", addr, - "vibeproxy", ph.config.VibeProxyTarget, - "ampcode", ph.config.DefaultTarget, + "ampcode", ph.config.AmpcodeURL, + "provider_gateway", gw.targetURL, + "auth_dir", ExpandHome(ph.config.AuthDir), "exa", ph.config.ExaAPIKey != "", ) - slog.Info("routing rules (highest priority first)") - rules := make([]Rule, len(ph.config.Rules)) - copy(rules, ph.config.Rules) - sort.Slice(rules, func(i, j int) bool { - return rules[i].Priority > rules[j].Priority - }) - for _, rule := range rules { - dest := rule.Target - if dest == "" { - dest = "(blocked)" - } - dest = strings.Replace(dest, "https://", "", 1) - dest = strings.Replace(dest, "http://", "", 1) - slog.Info("rule", "priority", rule.Priority, "name", rule.Name, "target", dest) + slog.Info("routing rules") + slog.Info("rule", "priority", "100", "name", "provider", "match", "/api/provider/*, /v1/*, /api/v1/*", "target", "provider-gateway") + slog.Info("rule", "priority", "90", "name", "auth", "match", "/auth/*, /api/auth/*", "target", "302→ampcode.com") + slog.Info("rule", "priority", "80", "name", "tools", "match", "/api/internal?{web*,get*}", "target", "exa/stub") + slog.Info("rule", "priority", "-", "name", "default", "target", "ampcode.com") + + // Print model remaps + for _, m := range ph.config.ModelRemaps { + slog.Info("model remap", "from", m.From, "to", m.Provider+"/"+m.To) + } + if ph.config.Fallback != nil { + slog.Info("model remap", "from", "(fallback)", "to", ph.config.Fallback.Provider+"/"+ph.config.Fallback.To) } - slog.Info("rule", "priority", "-", "name", "(default)", "target", strings.Replace(ph.config.DefaultTarget, "https://", "", 1)) slog.Info("waiting for requests") @@ -615,8 +621,17 @@ func (ph *ProxyHandler) Start(ctx context.Context) error { return err case <-ctx.Done(): slog.Info("shutting down...") + 
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + + // Shutdown provider gateway + if ph.gateway != nil { + if err := ph.gateway.Shutdown(shutdownCtx); err != nil { + slog.Error("provider gateway shutdown error", "error", err) + } + } + if err := srv.Shutdown(shutdownCtx); err != nil { return fmt.Errorf("shutdown: %w", err) } diff --git a/proxy_test.go b/proxy_test.go index 7e94eb1..139a5fa 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -1,63 +1,19 @@ package main import ( + "encoding/json" "net/http" "net/http/httptest" "testing" ) -func newTestConfig(vibeproxyURL, ampcodeURL string) *Config { - config := NewDefaultConfig() - config.DefaultTarget = ampcodeURL - config.VibeProxyTarget = vibeproxyURL - for i := range config.Rules { - switch config.Rules[i].Name { - case "provider-to-vibeproxy": - config.Rules[i].Target = vibeproxyURL - } - } - return config -} - -func TestProviderRequestsGoToVibeProxy(t *testing.T) { - vibeproxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("vibeproxy")) - })) - defer vibeproxy.Close() - - ampcode := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("ampcode")) - })) - defer ampcode.Close() - - config := newTestConfig(vibeproxy.URL, ampcode.URL) - handler := NewProxyHandler(config) - - paths := []string{ - "/api/provider/anthropic", - "/api/provider/openai/v1", - "/api/provider/google", - "/v1/messages", - "/api/v1/chat/completions", - } - - for _, path := range paths { - req := httptest.NewRequest("POST", "http://localhost"+path, nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - - if w.Body.String() != "vibeproxy" { - t.Errorf("Path %s: expected vibeproxy, got %s (status %d)", path, w.Body.String(), w.Code) - } - } -} - func TestAuthRequestsRedirectToAmpcode(t *testing.T) { - config := 
NewDefaultConfig() - config.DefaultTarget = "https://ampcode.com" - handler := NewProxyHandler(config) + cfg := DefaultConfig() + cfg.AmpcodeURL = "https://ampcode.com" + handler := &ProxyHandler{ + config: cfg, + httpClient: &http.Client{}, + } tests := []struct { path string @@ -83,109 +39,168 @@ func TestAuthRequestsRedirectToAmpcode(t *testing.T) { } } -func TestNonAuthDefaultRequestsGoToAmpcode(t *testing.T) { - vibeproxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("vibeproxy")) - })) - defer vibeproxy.Close() +func TestHealthCheck(t *testing.T) { + cfg := DefaultConfig() + handler := &ProxyHandler{ + config: cfg, + httpClient: &http.Client{}, + } - ampcode := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("ampcode")) - })) - defer ampcode.Close() + req := httptest.NewRequest("GET", "http://localhost/healthz", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) - config := newTestConfig(vibeproxy.URL, ampcode.URL) - handler := NewProxyHandler(config) + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + if w.Body.String() != `{"status":"ok"}` { + t.Errorf("unexpected body: %s", w.Body.String()) + } +} - paths := []string{ - "/api/internal/github-auth-status", - "/api/internal/github-proxy/repos", +func TestGetUserFreeTierStatus(t *testing.T) { + cfg := DefaultConfig() + handler := &ProxyHandler{ + config: cfg, + httpClient: &http.Client{}, } - for _, path := range paths { - req := httptest.NewRequest("GET", "http://localhost"+path, nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) + req := httptest.NewRequest("GET", "http://localhost/api/internal?getUserFreeTierStatus", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) - if w.Body.String() != "ampcode" { - t.Errorf("Path %s: expected ampcode, got %s (status %d)", path, 
w.Body.String(), w.Code) - } + if w.Code != http.StatusOK { + t.Errorf("expected 200, got %d", w.Code) + } + if w.Body.String() != `{"ok":true,"result":{"canUseAmpFree":true,"isDailyGrantEnabled":true}}` { + t.Errorf("unexpected body: %s", w.Body.String()) } } -func TestDefaultRequestsGoToAmpcode(t *testing.T) { - vibeproxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("vibeproxy")) - })) - defer vibeproxy.Close() - - ampcode := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("ampcode")) - })) - defer ampcode.Close() - - config := newTestConfig(vibeproxy.URL, ampcode.URL) - handler := NewProxyHandler(config) - - // Everything that's not an LLM provider path goes to ampcode - paths := []string{ - "/api/threads", - "/api/threads/sync", - "/api/session", - "/api/sessions", - "/api/telemetry", - "/api/events", - "/api/attachments", - "/api/durable-thread-workers", - "/api/internal", - "/threads/T-some-thread-id", - "/settings", - } - - for _, path := range paths { - req := httptest.NewRequest("GET", "http://localhost"+path, nil) - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) +func TestIsProviderRequest(t *testing.T) { + tests := []struct { + path string + expect bool + }{ + {"/api/provider/anthropic", true}, + {"/api/provider/openai/v1", true}, + {"/api/provider/google", true}, + {"/v1/messages", true}, + {"/api/v1/chat/completions", true}, + {"/api/threads", false}, + {"/auth/cli-login", false}, + {"/api/internal", false}, + {"/settings", false}, + } - if w.Body.String() != "ampcode" { - t.Errorf("Path %s: expected ampcode, got %s (status %d)", path, w.Body.String(), w.Code) + for _, tt := range tests { + req := httptest.NewRequest("GET", "http://localhost"+tt.path, nil) + got := isProviderRequest(req) + if got != tt.expect { + t.Errorf("isProviderRequest(%s): expected %v, got %v", 
tt.path, tt.expect, got) } } } -func TestFindTarget(t *testing.T) { - config := NewDefaultConfig() - handler := NewProxyHandler(config) - +func TestStripCacheControl(t *testing.T) { tests := []struct { - path string - expect string + name string + input string + expected string + changed bool }{ - // LLM paths -> vibeproxy - {"/api/provider/anthropic", config.VibeProxyTarget}, - {"/api/provider/openai/v1", config.VibeProxyTarget}, - {"/v1/messages", config.VibeProxyTarget}, - {"/api/v1/chat/completions", config.VibeProxyTarget}, - // Everything else -> ampcode - {"/auth/cli-login", config.DefaultTarget}, - {"/api/auth/sign-in", config.DefaultTarget}, - {"/api/internal/github-auth-status", config.DefaultTarget}, - {"/api/threads", config.DefaultTarget}, - {"/api/session", config.DefaultTarget}, - {"/threads/T-abc-123", config.DefaultTarget}, - {"/settings", config.DefaultTarget}, + { + name: "removes top-level cache_control", + input: `{"model":"test","cache_control":{"type":"ephemeral"},"max_tokens":100}`, + expected: `{"max_tokens":100,"model":"test"}`, + changed: true, + }, + { + name: "removes nested cache_control", + input: `{"messages":[{"role":"user","content":"hi","cache_control":{"type":"ephemeral"}}]}`, + expected: `{"messages":[{"content":"hi","role":"user"}]}`, + changed: true, + }, + { + name: "no change when no cache_control", + input: `{"model":"test","max_tokens":100}`, + expected: `{"model":"test","max_tokens":100}`, + changed: false, + }, } for _, tt := range tests { - req := httptest.NewRequest("GET", "http://localhost"+tt.path, nil) - target := handler.findTarget(req) + t.Run(tt.name, func(t *testing.T) { + var m map[string]interface{} + if err := json.Unmarshal([]byte(tt.input), &m); err != nil { + t.Fatal(err) + } + changed := removeCacheControl(m) + if changed != tt.changed { + t.Errorf("removeCacheControl changed=%v, expected %v", changed, tt.changed) + } + }) + } +} - if target != tt.expect { - t.Errorf("Path %s: expected %s, got %s", 
tt.path, tt.expect, target) +func TestConfigDefaults(t *testing.T) { + cfg := DefaultConfig() + + if cfg.ListenPort != 18317 { + t.Errorf("expected default port 18317, got %d", cfg.ListenPort) + } + if cfg.ListenAddr != "127.0.0.1" { + t.Errorf("expected default addr 127.0.0.1, got %s", cfg.ListenAddr) + } + if cfg.AmpcodeURL != "https://ampcode.com" { + t.Errorf("expected default ampcode URL, got %s", cfg.AmpcodeURL) + } + if cfg.AuthDir != "~/.cli-proxy-api" { + t.Errorf("expected default auth dir ~/.cli-proxy-api, got %s", cfg.AuthDir) + } + if len(cfg.ModelRemaps) != 4 { + t.Errorf("expected 4 default model remaps, got %d", len(cfg.ModelRemaps)) + } +} + +func TestFindModelRemap(t *testing.T) { + cfg := DefaultConfig() + + tests := []struct { + model string + wantTo string + wantProv string + wantExact bool + }{ + {"gemini-3-flash-preview", "claude-sonnet-4-6", "anthropic", true}, + {"gemini-3-pro", "gpt-5.4", "openai", true}, + {"unknown-model", "claude-sonnet-4-6", "anthropic", false}, + } + + for _, tt := range tests { + remap, exact := cfg.FindModelRemap(tt.model) + if remap.To != tt.wantTo { + t.Errorf("FindModelRemap(%s): expected to=%s, got %s", tt.model, tt.wantTo, remap.To) } + if remap.Provider != tt.wantProv { + t.Errorf("FindModelRemap(%s): expected provider=%s, got %s", tt.model, tt.wantProv, remap.Provider) + } + if exact != tt.wantExact { + t.Errorf("FindModelRemap(%s): expected exact=%v, got %v", tt.model, tt.wantExact, exact) + } + } +} + +func TestExpandEnvVars(t *testing.T) { + t.Setenv("TEST_VAR", "hello") + result := expandEnvVars("key: ${TEST_VAR}") + if result != "key: hello" { + t.Errorf("expected 'key: hello', got %q", result) + } + + // Unset var should remain as-is + result = expandEnvVars("key: ${NONEXISTENT_VAR}") + if result != "key: ${NONEXISTENT_VAR}" { + t.Errorf("expected unchanged, got %q", result) } } diff --git a/remap.go b/remap.go index c14f809..3adec1c 100644 --- a/remap.go +++ b/remap.go @@ -13,31 +13,6 @@ import ( 
"time" ) -// ModelMapping defines how to remap an unsupported model to a supported one -type ModelMapping struct { - SourcePattern string // glob-style match on source model name - TargetModel string // target model identifier - TargetProvider string // "anthropic" or "openai" - TargetPath string // API path to use on vibeproxy -} - -var modelMappings = []ModelMapping{ - // Gemini flash variants -> Claude Sonnet 4.6 (latest) - {SourcePattern: "gemini-3-flash-preview", TargetModel: "claude-sonnet-4-6", TargetProvider: "anthropic", TargetPath: "/api/provider/anthropic/v1/messages"}, - {SourcePattern: "gemini-3-flash", TargetModel: "claude-sonnet-4-6", TargetProvider: "anthropic", TargetPath: "/api/provider/anthropic/v1/messages"}, - // Gemini pro -> GPT 5.4 - {SourcePattern: "gemini-3-pro", TargetModel: "gpt-5.4", TargetProvider: "openai", TargetPath: "/api/provider/openai/v1/chat/completions"}, - // Gemini pro image -> GPT image - {SourcePattern: "gemini-3-pro-image", TargetModel: "gpt-image-1", TargetProvider: "openai", TargetPath: "/api/provider/openai/v1/chat/completions"}, -} - -// Fallback for any model we don't explicitly map -var fallbackMapping = ModelMapping{ - TargetModel: "claude-sonnet-4-6", - TargetProvider: "anthropic", - TargetPath: "/api/provider/anthropic/v1/messages", -} - // googleModelRegex extracts the model name from Google provider paths var googleModelRegex = regexp.MustCompile(`/api/provider/google/.+/models/([^/:]+):(generateContent|streamGenerateContent)`) @@ -58,7 +33,6 @@ func isUnsupportedProviderRequest(r *http.Request) (provider string, ok bool) { if !strings.HasPrefix(r.URL.Path, "/api/provider/") { return "", false } - // Extract provider name from /api/provider/{provider}/... 
rest := strings.TrimPrefix(r.URL.Path, "/api/provider/") parts := strings.SplitN(rest, "/", 2) if len(parts) == 0 { @@ -71,22 +45,19 @@ func isUnsupportedProviderRequest(r *http.Request) (provider string, ok bool) { return provider, true } -// findMapping returns the mapping for a given model, or the fallback -func findMapping(model string) (ModelMapping, bool) { - for _, m := range modelMappings { - if m.SourcePattern == model { - return m, true - } - } - return fallbackMapping, false -} - // handleRemappedRequest handles a request that needs model remapping -func (ph *ProxyHandler) handleRemappedRequest(w http.ResponseWriter, r *http.Request, reqID uint64, model string, streaming bool, mapping ModelMapping, isExplicit bool) { +func (ph *ProxyHandler) handleRemappedRequest(w http.ResponseWriter, r *http.Request, reqID uint64, model string, streaming bool, remap ModelRemapConfig, isExplicit bool) { start := time.Now() + if ph.gateway == nil || !ph.gateway.IsReady() { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprint(w, `{"error":"provider_not_ready","message":"Provider gateway is starting up. 
Please retry."}`) + return + } + if !isExplicit { - slog.Warn("unmapped model, using fallback", "reqID", reqID, "model", model, "targetProvider", mapping.TargetProvider, "targetModel", mapping.TargetModel) + slog.Warn("unmapped model, using fallback", "reqID", reqID, "model", model, "targetProvider", remap.Provider, "targetModel", remap.To) } // Read the original request body @@ -109,11 +80,11 @@ func (ph *ProxyHandler) handleRemappedRequest(w http.ResponseWriter, r *http.Req // Translate based on target provider var translatedBody []byte - switch mapping.TargetProvider { + switch remap.Provider { case "anthropic": - translatedBody, err = translateGoogleToAnthropic(googleReq, mapping.TargetModel, streaming) + translatedBody, err = translateGoogleToAnthropic(googleReq, remap.To, streaming) case "openai": - translatedBody, err = translateGoogleToOpenAI(googleReq, mapping.TargetModel, streaming) + translatedBody, err = translateGoogleToOpenAI(googleReq, remap.To, streaming) default: http.Error(w, `{"error":"unsupported target provider"}`, http.StatusInternalServerError) return @@ -124,15 +95,16 @@ func (ph *ProxyHandler) handleRemappedRequest(w http.ResponseWriter, r *http.Req ph.metrics.errors.Add(1) return } - if err := validateTranslatedToolResults(mapping.TargetProvider, translatedBody); err != nil { + if err := validateTranslatedToolResults(remap.Provider, translatedBody); err != nil { http.Error(w, fmt.Sprintf(`{"error":"translated request validation failed: %s"}`, err.Error()), http.StatusInternalServerError) - slog.Error("translated request validation failed", "reqID", reqID, "provider", mapping.TargetProvider, "error", err) + slog.Error("translated request validation failed", "reqID", reqID, "provider", remap.Provider, "error", err) ph.metrics.errors.Add(1) return } - // Build the outbound request to vibeproxy - targetURL := ph.config.VibeProxyTarget + mapping.TargetPath + // Build the outbound request — route through the embedded provider gateway + 
targetPath := TargetPathForProvider(remap.Provider) + targetURL := ph.gateway.targetURL + targetPath outReq, err := http.NewRequestWithContext(r.Context(), "POST", targetURL, bytes.NewReader(translatedBody)) if err != nil { http.Error(w, `{"error":"failed to create request"}`, http.StatusInternalServerError) @@ -147,7 +119,6 @@ func (ph *ProxyHandler) handleRemappedRequest(w http.ResponseWriter, r *http.Req for _, cookie := range r.Cookies() { outReq.AddCookie(cookie) } - // Forward Amp-specific headers for name, values := range r.Header { if strings.HasPrefix(name, "X-Amp-") { for _, v := range values { @@ -157,23 +128,24 @@ func (ph *ProxyHandler) handleRemappedRequest(w http.ResponseWriter, r *http.Req } // Set Anthropic-specific headers - if mapping.TargetProvider == "anthropic" { + if remap.Provider == "anthropic" { outReq.Header.Set("Anthropic-Version", "2023-06-01") } - // Log the translated body for debugging if len(translatedBody) <= 2000 { slog.Debug("remap translated body", "reqID", reqID, "body", string(translatedBody)) } else { slog.Debug("remap translated body", "reqID", reqID, "body", string(translatedBody[:1000])+"...", "totalBytes", len(translatedBody)) } - slog.Info("remap request", "reqID", reqID, "targetURL", targetURL, "model", mapping.TargetModel, "stream", streaming) + slog.Info("remap request", "reqID", reqID, "targetURL", targetURL, "model", remap.To, "stream", streaming) // Make the request - resp, err := ph.httpClient.Do(outReq) + resp, err := ph.gateway.Do(outReq) if err != nil { - http.Error(w, fmt.Sprintf(`{"error":"upstream request failed: %s"}`, err.Error()), http.StatusBadGateway) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadGateway) + fmt.Fprintf(w, `{"error":"provider_unavailable","message":"Provider backend error: %s"}`, err.Error()) slog.Error("upstream request failed", "reqID", reqID, "error", err) ph.metrics.errors.Add(1) return @@ -181,9 +153,9 @@ func (ph *ProxyHandler) 
handleRemappedRequest(w http.ResponseWriter, r *http.Req defer resp.Body.Close() if !streaming { - ph.handleNonStreamingResponse(w, resp, reqID, mapping, start) + ph.handleNonStreamingResponse(w, resp, reqID, remap, start) } else { - ph.handleStreamingResponse(w, resp, reqID, mapping, start) + ph.handleStreamingResponse(w, resp, reqID, remap, start) } } @@ -279,14 +251,13 @@ func validateOpenAIToolResults(translatedReq map[string]interface{}) error { return nil } -func (ph *ProxyHandler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, reqID uint64, mapping ModelMapping, start time.Time) { +func (ph *ProxyHandler) handleNonStreamingResponse(w http.ResponseWriter, resp *http.Response, reqID uint64, remap ModelRemapConfig, start time.Time) { respBody, err := io.ReadAll(resp.Body) if err != nil { http.Error(w, `{"error":"failed to read upstream response"}`, http.StatusBadGateway) return } - // If upstream returned an error, log it and pass it through if resp.StatusCode >= 400 { slog.Error("upstream error", "reqID", reqID, "status", resp.StatusCode, "body", string(respBody)) ph.metrics.errors.Add(1) @@ -303,7 +274,6 @@ func (ph *ProxyHandler) handleNonStreamingResponse(w http.ResponseWriter, resp * var upstreamResp map[string]interface{} if err := json.Unmarshal(respBody, &upstreamResp); err != nil { - // Can't parse — just forward raw w.Header().Set("Content-Type", "application/json") w.WriteHeader(resp.StatusCode) w.Write(respBody) @@ -311,14 +281,13 @@ func (ph *ProxyHandler) handleNonStreamingResponse(w http.ResponseWriter, resp * } var translated []byte - switch mapping.TargetProvider { + switch remap.Provider { case "anthropic": translated, err = translateAnthropicToGoogle(upstreamResp) case "openai": translated, err = translateOpenAIToGoogle(upstreamResp) } if err != nil { - // Fallback: forward raw response slog.Warn("response translation failed, forwarding raw", "reqID", reqID, "error", err) w.Header().Set("Content-Type", 
"application/json") w.WriteHeader(resp.StatusCode) @@ -332,7 +301,7 @@ func (ph *ProxyHandler) handleNonStreamingResponse(w http.ResponseWriter, resp * slog.Info("response remap", "reqID", reqID, "status", 200, "statusText", "OK", "bytes", len(translated), "elapsed", time.Since(start).Round(time.Millisecond)) } -func (ph *ProxyHandler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, reqID uint64, mapping ModelMapping, start time.Time) { +func (ph *ProxyHandler) handleStreamingResponse(w http.ResponseWriter, resp *http.Response, reqID uint64, remap ModelRemapConfig, start time.Time) { if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(resp.Body) w.WriteHeader(resp.StatusCode) @@ -357,7 +326,6 @@ func (ph *ProxyHandler) handleStreamingResponse(w http.ResponseWriter, resp *htt for scanner.Scan() { line := scanner.Text() - // Track event type for Anthropic SSE (event: xxx) if strings.HasPrefix(line, "event: ") { eventType = strings.TrimPrefix(line, "event: ") continue @@ -368,7 +336,6 @@ func (ph *ProxyHandler) handleStreamingResponse(w http.ResponseWriter, resp *htt } data := strings.TrimPrefix(line, "data: ") - // OpenAI stream terminator if data == "[DONE]" { break } @@ -380,7 +347,7 @@ func (ph *ProxyHandler) handleStreamingResponse(w http.ResponseWriter, resp *htt var googleChunk []byte var err error - switch mapping.TargetProvider { + switch remap.Provider { case "anthropic": googleChunk, err = translateAnthropicStreamChunk(chunk, eventType) case "openai": diff --git a/tool_name_rewriter.go b/tool_name_rewriter.go new file mode 100644 index 0000000..6636a15 --- /dev/null +++ b/tool_name_rewriter.go @@ -0,0 +1,79 @@ +package main + +import ( + "bytes" + "log/slog" + "net/http" +) + +// ampToolNames maps lowercase tool names to the casing Amp expects. +// Amp's agent modes define includeTools with specific casing (e.g. "Bash", not "bash"). 
+// The isToolAllowed check is case-sensitive, so if the LLM returns a tool name +// with different casing, it gets rejected with "tool X is not allowed for Y mode". +var ampToolNames = map[string]string{ + "bash": "Bash", + "read": "Read", + "grep": "Grep", + "task": "Task", +} + +// toolNameReplacements are precomputed byte patterns for efficient replacement +// in SSE stream chunks. Tool names appear in JSON as "name":"toolname" in both +// Anthropic and OpenAI streaming formats. +var toolNameReplacements []struct{ old, new []byte } + +func init() { + for wrong, correct := range ampToolNames { + // Match "name":"bash" patterns (Anthropic SSE content_block_start) + toolNameReplacements = append(toolNameReplacements, + struct{ old, new []byte }{ + old: []byte(`"name":"` + wrong + `"`), + new: []byte(`"name":"` + correct + `"`), + }, + // Also match with space after colon: "name": "bash" + struct{ old, new []byte }{ + old: []byte(`"name": "` + wrong + `"`), + new: []byte(`"name": "` + correct + `"`), + }, + ) + } +} + +// toolNameRewriter wraps an http.ResponseWriter and rewrites tool names +// in the response stream to match Amp's expected casing. +type toolNameRewriter struct { + http.ResponseWriter + rewrites int +} + +func newToolNameRewriter(w http.ResponseWriter) *toolNameRewriter { + return &toolNameRewriter{ResponseWriter: w} +} + +func (rw *toolNameRewriter) Write(b []byte) (int, error) { + out := b + for _, r := range toolNameReplacements { + if bytes.Contains(out, r.old) { + if &out[0] == &b[0] { + // First modification — copy to avoid mutating the original buffer + out = make([]byte, len(b)) + copy(out, b) + } + out = bytes.ReplaceAll(out, r.old, r.new) + rw.rewrites++ + } + } + if rw.rewrites > 0 && &out[0] != &b[0] { + slog.Debug("tool name rewrite", "rewrites", rw.rewrites) + } + // Write the (possibly modified) data, but report original length + // to keep the caller's byte accounting correct. 
+ _, err := rw.ResponseWriter.Write(out) + return len(b), err +} + +func (rw *toolNameRewriter) Flush() { + if f, ok := rw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} diff --git a/tool_name_rewriter_test.go b/tool_name_rewriter_test.go new file mode 100644 index 0000000..ee82bdd --- /dev/null +++ b/tool_name_rewriter_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "net/http/httptest" + "testing" +) + +func TestToolNameRewriter(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "rewrites bash to Bash", + input: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"bash","input":{}}}`, + want: `data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"Bash","input":{}}}`, + }, + { + name: "rewrites with space after colon", + input: `{"name": "bash", "input": {}}`, + want: `{"name": "Bash", "input": {}}`, + }, + { + name: "rewrites read to Read", + input: `{"name":"read"}`, + want: `{"name":"Read"}`, + }, + { + name: "rewrites grep to Grep", + input: `{"name":"grep"}`, + want: `{"name":"Grep"}`, + }, + { + name: "rewrites task to Task", + input: `{"name":"task"}`, + want: `{"name":"Task"}`, + }, + { + name: "does not rewrite already correct casing", + input: `{"name":"Bash"}`, + want: `{"name":"Bash"}`, + }, + { + name: "does not rewrite other tools", + input: `{"name":"edit_file"}`, + want: `{"name":"edit_file"}`, + }, + { + name: "does not rewrite bash in non-name context", + input: `{"cmd":"bash -c echo hello"}`, + want: `{"cmd":"bash -c echo hello"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + rw := newToolNameRewriter(rec) + + n, err := rw.Write([]byte(tt.input)) + if err != nil { + t.Fatalf("Write error: %v", err) + } + if n != len(tt.input) { + t.Errorf("Write returned %d, want %d", n, len(tt.input)) + } + + got := rec.Body.String() + if got != tt.want { + t.Errorf("got %q\nwant %q", 
got, tt.want) + } + }) + } +}