diff --git a/.claude/rules/02-quality-gates.md b/.claude/rules/02-quality-gates.md index 5e71c4f8f..f8a9ae42c 100644 --- a/.claude/rules/02-quality-gates.md +++ b/.claude/rules/02-quality-gates.md @@ -211,6 +211,7 @@ String containment tests on structured output create false confidence. The test 3. **Existing workflows MUST be confirmed working.** Navigate the dashboard, projects, settings. Verify no regressions — pages load, data displays, navigation works, no new console errors. 4. **New feature/fix MUST be verified on staging.** The specific changes in the PR must work correctly on the live staging environment. 5. **Evidence MUST be reported.** Include screenshots, API responses, or Playwright observations in the PR. +6. **For browser-consumed streams (SSE / WebSocket), verification MUST use a real browser, not `curl`.** `curl` can confirm bytes arrive on the wire; only a browser confirms the client actually dispatches them to its handler. See `.claude/rules/13-staging-verification.md` for the full reasoning and the post-mortem that motivated this rule (`docs/notes/2026-04-19-trial-sse-named-events-postmortem.md`). ### No Exceptions diff --git a/.claude/rules/10-e2e-verification.md b/.claude/rules/10-e2e-verification.md index f74b28b25..d638b7e95 100644 --- a/.claude/rules/10-e2e-verification.md +++ b/.claude/rules/10-e2e-verification.md @@ -23,6 +23,7 @@ Before marking a feature complete: - [ ] The test asserts the **final outcome**, not just that intermediate steps succeeded - [ ] If the test uses mocks at system boundaries, the mock asserts the **exact payload** the real system would receive - [ ] Any untestable gaps are documented with manual verification steps +- [ ] **Port-of-pattern coverage** — when porting a multi-step pattern (VM boot, credential rotation, agent session lifecycle) from an existing consumer to a new one, the new consumer's tests MUST mock each cross-boundary target and assert **every step of the pattern fired** with the correct payload. 
A test that asserts "step 1 fired" but not "step 3 fired" does not prove the port is complete. See `docs/notes/2026-04-19-trial-orchestrator-agent-boot-postmortem.md` for the class of bug this prevents. ## Data Flow Tracing (Mandatory for Multi-Component Features) diff --git a/.claude/rules/13-staging-verification.md b/.claude/rules/13-staging-verification.md index 162aff63c..10fb976c9 100644 --- a/.claude/rules/13-staging-verification.md +++ b/.claude/rules/13-staging-verification.md @@ -172,6 +172,13 @@ Match the verification to what the PR actually changes: - Verifying no console errors - "The code changes look correct" - "Unit tests pass" +- **For browser-consumed streams (SSE / WebSocket): using `curl` to verify that + bytes arrive on the wire.** Curl confirms the *byte stream*; only a real + browser confirms *dispatch* to the client-side handler (`EventSource.onmessage` + or `WebSocket.onmessage`). A server can emit perfectly valid SSE that the + browser parses and then silently drops because nothing is listening for the + specific event name. See + `docs/notes/2026-04-19-trial-sse-named-events-postmortem.md`. These are baseline regression checks. They do NOT verify that the specific fix or feature works on the live environment. diff --git a/.claude/skills/env-reference/SKILL.md b/.claude/skills/env-reference/SKILL.md index 39a25352e..83545ecb8 100644 --- a/.claude/skills/env-reference/SKILL.md +++ b/.claude/skills/env-reference/SKILL.md @@ -96,6 +96,21 @@ See `apps/api/.env.example` for the full list. Key variables: ### Credential Routes Rate Limits - `RATE_LIMIT_CREDENTIAL_UPDATE` — Applied to both user-scoped (`PUT /api/credentials/agent`) and project-scoped (`PUT /api/projects/:id/credentials`) credential write endpoints (MEDIUM #7 fix) +### Trial Onboarding (`/try` flow) + +See `docs/guides/trial-configuration.md` for the full table with meanings and defaults. 
Summary: + +- `TRIAL_CLAIM_TOKEN_SECRET` — Worker secret; HMAC key for trial cookies (auto-provisioned by Pulumi) +- `TRIAL_MONTHLY_CAP`, `TRIAL_WORKSPACE_TTL_MS`, `TRIAL_DATA_RETENTION_HOURS` — Global cap + lifetimes +- `TRIAL_ANONYMOUS_USER_ID`, `TRIAL_ANONYMOUS_INSTALLATION_ID` — Sentinel rows for pre-claim ownership +- `TRIAL_AGENT_TYPE_STAGING`, `TRIAL_AGENT_TYPE_PRODUCTION`, `TRIAL_DEFAULT_WORKSPACE_PROFILE` — Agent + profile selection +- `TRIALS_ENABLED_KV_KEY`, `TRIAL_KILL_SWITCH_CACHE_MS` — Kill switch +- `TRIAL_ORCHESTRATOR_OVERALL_TIMEOUT_MS`, `TRIAL_ORCHESTRATOR_STEP_MAX_RETRIES`, `TRIAL_ORCHESTRATOR_RETRY_BASE_DELAY_MS`, `TRIAL_ORCHESTRATOR_RETRY_MAX_DELAY_MS` — Orchestrator retry budget +- `TRIAL_ORCHESTRATOR_NODE_READY_TIMEOUT_MS`, `TRIAL_ORCHESTRATOR_AGENT_READY_TIMEOUT_MS`, `TRIAL_ORCHESTRATOR_WORKSPACE_READY_TIMEOUT_MS`, `TRIAL_ORCHESTRATOR_WORKSPACE_READY_POLL_INTERVAL_MS` — Step-level timeouts +- `TRIAL_VM_SIZE`, `TRIAL_VM_LOCATION` — VM overrides for trial workspaces +- `TRIAL_GITHUB_TIMEOUT_MS` — Per-request timeout for the default-branch probe (`fetchDefaultBranch`); falls back to `main` on timeout/404/error +- `TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS`, `TRIAL_KNOWLEDGE_MAX_EVENTS` — Fast-path knowledge probe tunables + ## VM Agent Environment Variables ### Container/User diff --git a/.github/workflows/deploy-reusable.yml b/.github/workflows/deploy-reusable.yml index 7dbc0510b..9b9951f0c 100644 --- a/.github/workflows/deploy-reusable.yml +++ b/.github/workflows/deploy-reusable.yml @@ -404,6 +404,22 @@ jobs: PULUMI_CONFIG_PASSPHRASE: ${{ secrets.PULUMI_CONFIG_PASSPHRASE }} CLOUDFLARE_API_TOKEN: ${{ secrets.CF_API_TOKEN }} + - name: Enable Trials on Staging (KV kill-switch) + # Trials are gated by a KV flag that fails closed. Staging always enables them + # automatically so the /try flow is live after every deploy. Production stays + # opt-in (operator flips the flag manually when ready to accept live traffic). 
+ if: ${{ inputs.dry_run != true && inputs.environment == 'staging' }} + working-directory: apps/api + env: + CLOUDFLARE_API_TOKEN: ${{ secrets.CF_API_TOKEN }} + CLOUDFLARE_ACCOUNT_ID: ${{ secrets.CF_ACCOUNT_ID }} + run: | + echo "Enabling trials kill-switch on staging (KV trials:enabled=true)" + # `wrangler kv key put` writes to remote by default; --remote is not a valid flag here. + pnpm exec wrangler kv key put "trials:enabled" "true" \ + --binding KV \ + --env staging + - name: Configure Worker Secrets if: ${{ inputs.dry_run != true }} run: bash scripts/deploy/configure-secrets.sh ${{ inputs.environment }} @@ -432,6 +448,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} SMOKE_TEST_AUTH_ENABLED: ${{ secrets.SMOKE_TEST_AUTH_ENABLED }} + ANTHROPIC_API_KEY_TRIAL: ${{ secrets.ANTHROPIC_API_KEY_TRIAL }} SEGMENT_WRITE_KEY: ${{ secrets.SEGMENT_WRITE_KEY }} GA4_API_SECRET: ${{ secrets.GA4_API_SECRET }} GA4_MEASUREMENT_ID: ${{ secrets.GA4_MEASUREMENT_ID }} diff --git a/CLAUDE.md b/CLAUDE.md index 0152ccb6e..38b252eac 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -221,6 +221,10 @@ Domains chain together: competitive research feeds marketing and business strate - Cloudflare D1 (credentials table with AES-GCM encrypted tokens) (028-provider-infrastructure) ## Recent Changes +- ai-proxy-gateway: AI inference proxy routes LLM requests through Cloudflare AI Gateway — `POST /ai/v1/chat/completions` accepts OpenAI-format requests, transparently routes to Workers AI (@cf/* models) or Anthropic (claude-* models) with format translation (`ai-anthropic-translate.ts`); per-user RPM rate limiting + daily token budget via KV; admin model picker at `/admin/ai-proxy`; AI usage analytics dashboard at `/admin/analytics/ai-usage` aggregates AI Gateway logs by model, day, cost; configurable via AI_PROXY_ENABLED, AI_PROXY_DEFAULT_MODEL, AI_GATEWAY_ID, AI_PROXY_ALLOWED_MODELS, AI_PROXY_RATE_LIMIT_RPM, 
AI_PROXY_RATE_LIMIT_WINDOW_SECONDS, AI_PROXY_MAX_INPUT_TOKENS_PER_REQUEST, AI_USAGE_PAGE_SIZE, AI_USAGE_MAX_PAGES +- trial-agent-boot: TrialOrchestrator `discovery_agent_start` step now runs the full 5-step idempotent VM boot (registers agent session via `createAgentSessionOnNode`, mints MCP token with trialId as synthetic taskId, `startAgentSessionOnNode` with discovery prompt + MCP server URL, drives ACP session `pending → assigned → running`; idempotency flags `mcpToken`, `agentSessionCreatedOnVm`, `agentStartedOnVm`, `acpAssignedOnVm`, `acpRunningOnVm` on DO state let crash/retry resume without double-booking); new `fetchDefaultBranch()` probes GitHub `/repos/:owner/:repo` with AbortController-bounded fetch and threads the real default branch through `projects.defaultBranch` + workspace `git clone --branch` (master-default repos like `octocat/Hello-World` now work); configurable via TRIAL_GITHUB_TIMEOUT_MS (default: 5000); new capability test `apps/api/tests/unit/durable-objects/trial-orchestrator-agent-boot.test.ts` asserts every cross-boundary call fires with correct payload; rule 10 updated with port-of-pattern coverage requirement. See `docs/notes/2026-04-19-trial-orchestrator-agent-boot-postmortem.md`. +- trial-sse-events-fix: Fixed "zero trial.* events on staging" — `formatSse()` in `apps/api/src/routes/trial/events.ts` previously emitted named SSE frames (`event: trial.knowledge\ndata: {...}`), but the frontend subscribes via `source.onmessage` which only fires for the default (unnamed) event; frames arrived on the wire (curl saw them) but browser EventSource silently dropped them. Now emits unnamed `data: {JSON}\n\n` frames; the `TrialEvent` payload's own `type` discriminator preserves dispatch info. Also fixed `eventsUrl` in `apps/api/src/routes/trial/create.ts` response shape mismatch (`/api/trial/events?trialId=X` → `/api/trial/:trialId/events`). 
New capability test `apps/api/tests/workers/trial-event-bus-sse.test.ts` asserts no `event:` line + JSON round-trip across the TrialEventBus DO → SSE endpoint boundary; unit tests updated to assert new unnamed-frame contract and exact `eventsUrl` shape (no substring matches on URL contracts). Rule 13 updated to ban curl-only verification of browser-consumed SSE/WebSocket streams — curl confirms bytes, browsers confirm dispatch. See `docs/notes/2026-04-19-trial-sse-named-events-postmortem.md`. +- trial-orchestrator-wire-up: TrialOrchestrator Durable Object + GitHub-API knowledge fast-path — `POST /api/trial/create` now fire-and-forget dispatches two concurrent `c.executionCtx.waitUntil` tasks: (1) `env.TRIAL_ORCHESTRATOR.idFromName(trialId)` DO state machine (alarm-driven, steps: project_creation → node_provisioning → workspace_creation → workspace_ready → agent_session → completed; idempotent `start()`; terminal guard on completed/failed; overall-timeout emits `trial.error`); (2) `emitGithubKnowledgeEvents()` probe hits unauthenticated `/repos/:o/:n`, `/repos/:o/:n/languages`, `/repos/:o/:n/readme` in parallel with AbortController-bounded fetches, emits up to `TRIAL_KNOWLEDGE_MAX_EVENTS` `trial.knowledge` events (description, primary language, stars, topics, license, language breakdown by bytes, README first paragraph), swallows all errors; `apps/api/src/services/trial/bridge.ts` bridges ACP session transitions (`running` → `trial.ready`, `failed` → `trial.error`) and MCP tool calls (`add_knowledge` → `trial.knowledge`, `create_idea` → `trial.idea`) into the SSE stream via `readTrialByProject()` KV lookup (no-op on non-trial projects); new sentinel `TRIAL_ANONYMOUS_INSTALLATION_ID` row in `github_installations` so trial projects satisfy the FK; configurable via TRIAL_ORCHESTRATOR_OVERALL_TIMEOUT_MS (default: 300000), TRIAL_ORCHESTRATOR_STEP_MAX_RETRIES (default: 5), TRIAL_ORCHESTRATOR_RETRY_BASE_DELAY_MS (default: 1000), TRIAL_ORCHESTRATOR_RETRY_MAX_DELAY_MS 
(default: 60000), TRIAL_ORCHESTRATOR_NODE_READY_TIMEOUT_MS (default: 180000), TRIAL_ORCHESTRATOR_AGENT_READY_TIMEOUT_MS (default: 60000), TRIAL_ORCHESTRATOR_WORKSPACE_READY_TIMEOUT_MS (default: 180000), TRIAL_ORCHESTRATOR_WORKSPACE_READY_POLL_INTERVAL_MS (default: 5000), TRIAL_VM_SIZE (default: DEFAULT_VM_SIZE), TRIAL_VM_LOCATION (default: DEFAULT_VM_LOCATION), TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS (default: 5000), TRIAL_KNOWLEDGE_MAX_EVENTS (default: 10) - project-credential-overrides: Per-project agent credential overrides — `credentials.project_id` column (migration 0042, nullable FK to `projects.id ON DELETE CASCADE`) with two partial unique indexes (`WHERE project_id IS NULL` for user-scoped, `WHERE project_id IS NOT NULL` for project-scoped); `getDecryptedAgentKey(db, userId, agentType, key, projectId?)` resolves project → user → platform in order; workspace runtime callback forwards `workspace.projectId`; `CodexRefreshLock` DO preserves scope on OAuth token rotation; new `/api/projects/:id/credentials` routes (GET/PUT/DELETE) guarded by `requireOwnedProject` (404 on cross-user); `ProjectAgentsSection` on Project Settings combines credential override and model/permission override per agent using `AgentKeyCard` (scope='project') with inheritance hints; cross-user writes rejected at query layer AND ownership check; `autoActivate` only affects project-scoped rows (user-scoped untouched) - project-knowledge-graph: Per-project knowledge graph for persistent agent memory — `knowledge_entities`, `knowledge_observations`, `knowledge_relations` tables + FTS5 virtual table in ProjectData DO SQLite (migration 016); entity-observation-relation model with confidence scoring and recency weighting; 11 MCP tools (`add_knowledge`, `update_knowledge`, `remove_knowledge`, `get_knowledge`, `search_knowledge`, `get_project_knowledge`, `get_relevant_knowledge`, `relate_knowledge`, `get_related`, `confirm_knowledge`, `flag_contradiction`) in 
`apps/api/src/routes/mcp/knowledge-tools.ts`; auto-retrieval of relevant knowledge in `get_instructions` MCP tool; REST API at `/api/projects/:projectId/knowledge/*` for UI CRUD; Knowledge Browser page at `/projects/:id/knowledge` with entity list, search, type filters, detail panel; configurable via KNOWLEDGE_AUTO_RETRIEVE_LIMIT (default: 20), KNOWLEDGE_MAX_ENTITIES_PER_PROJECT (default: 500), KNOWLEDGE_MAX_OBSERVATIONS_PER_ENTITY (default: 100), KNOWLEDGE_SEARCH_LIMIT (default: 20), KNOWLEDGE_SEARCH_MAX_LIMIT (default: 100), KNOWLEDGE_LIST_PAGE_SIZE (default: 50), KNOWLEDGE_LIST_MAX_PAGE_SIZE (default: 200), KNOWLEDGE_OBSERVATION_MAX_LENGTH (default: 1000) - dispatch-task-config-parity: Full task execution config parity for `dispatch_task` MCP tool — extended schema accepts optional `agentProfileId`, `taskMode` (task/conversation), `agentType`, `workspaceProfile` (default/lightweight), `provider` (hetzner/scaleway/gcp), `vmLocation`; config precedence matches normal submit path: explicit field → agent profile → project default → platform default; `resolveAgentProfile()` from `agent-profiles.ts` resolves profiles by ID or name with built-in seeding; profile-derived values (`model`, `permissionMode`, `systemPromptAppend`) passed through to `startTaskRunnerDO()`; `agentProfileHint` and `taskMode` persisted in task INSERT for observability; location validated against resolved provider; `maxTurns`/`timeoutMinutes` excluded — not enforced by runtime (documented in task file) diff --git a/apps/api/.env.example b/apps/api/.env.example index 9157909b7..5e729bf64 100644 --- a/apps/api/.env.example +++ b/apps/api/.env.example @@ -46,6 +46,58 @@ BASE_DOMAIN=workspaces.example.com # MAX_SMOKE_TOKENS_PER_USER=10 # MAX_SMOKE_TOKEN_NAME_LENGTH=100 +# ======================================== +# Trial Onboarding +# ======================================== +# See docs/guides/trial-configuration.md for the full tunable list. 
+# TRIAL_MONTHLY_CAP=1500 +# TRIAL_WORKSPACE_TTL_MS=1200000 # 20 min +# TRIAL_DATA_RETENTION_HOURS=168 # 7 days +# TRIAL_ANONYMOUS_USER_ID=system_anonymous_trials +# TRIAL_ANONYMOUS_INSTALLATION_ID=system_anonymous_trials_installation +# TRIAL_AGENT_TYPE_STAGING=opencode +# TRIAL_AGENT_TYPE_PRODUCTION=claude-code +# TRIAL_DEFAULT_WORKSPACE_PROFILE=lightweight +# TRIAL_VM_SIZE=cax11 +# TRIAL_VM_LOCATION=fsn1 +# TRIALS_ENABLED_KV_KEY=trials:enabled +# TRIAL_REPO_MAX_KB=102400 +# TRIAL_COUNTER_KEEP_MONTHS=6 +# TRIAL_KILL_SWITCH_CACHE_MS=30000 +# TRIAL_GITHUB_TIMEOUT_MS=10000 +# TRIAL_WAITLIST_PURGE_DAYS=90 +# TRIAL_MODEL=@cf/qwen/qwen3-30b-a3b-fp8 # Model for trial conversations +# TRIAL_LLM_PROVIDER=workers-ai # "workers-ai" or "anthropic" +# ANTHROPIC_API_KEY_TRIAL= # Required only when TRIAL_LLM_PROVIDER=anthropic +# ENVIRONMENT=staging # "staging" or "production" — set by deploy pipeline +# TRIAL_SSE_HEARTBEAT_MS=15000 +# TRIAL_SSE_POLL_TIMEOUT_MS=15000 +# TRIAL_SSE_MAX_DURATION_MS=1800000 +# TRIAL_CRON_ROLLOVER_CRON="0 5 1 * *" +# TRIAL_CRON_WAITLIST_CLEANUP="0 4 * * *" +# TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS=10000 +# TRIAL_KNOWLEDGE_MAX_EVENTS=50 +# TRIAL_ORCHESTRATOR_NODE_READY_TIMEOUT_MS=180000 +# TRIAL_ORCHESTRATOR_WORKSPACE_READY_TIMEOUT_MS=180000 +# TRIAL_ORCHESTRATOR_AGENT_READY_TIMEOUT_MS=60000 +# TRIAL_ORCHESTRATOR_STEP_MAX_RETRIES=5 +# TRIAL_ORCHESTRATOR_RETRY_BASE_DELAY_MS=1000 +# RATE_LIMIT_TRIAL_CREATE=10 +# RATE_LIMIT_TRIAL_SSE=30 + +# ======================================== +# AI Inference Proxy (Cloudflare AI Gateway) +# ======================================== +# AI_PROXY_ENABLED=true +# AI_PROXY_DEFAULT_MODEL=@cf/qwen/qwen3-30b-a3b-fp8 +# AI_GATEWAY_ID=sam +# AI_PROXY_ALLOWED_MODELS= # Comma-separated; defaults in shared/constants +# AI_PROXY_RATE_LIMIT_RPM=60 +# AI_PROXY_RATE_LIMIT_WINDOW_SECONDS=60 +# AI_PROXY_MAX_INPUT_TOKENS_PER_REQUEST=4096 +# AI_USAGE_PAGE_SIZE=50 +# AI_USAGE_MAX_PAGES=20 + # Pages project names (for Worker proxy routing) # PAGES_PROJECT_NAME=sam-web-prod # 
WWW_PAGES_PROJECT_NAME=sam-www @@ -381,8 +433,8 @@ BASE_DOMAIN=workspaces.example.com # AI Inference Proxy (Workers AI gateway for trial/zero-config users) # AI_PROXY_ENABLED=true # Kill switch: set "false" to disable (default: enabled) -# AI_PROXY_DEFAULT_MODEL=@cf/meta/llama-4-scout-17b-16e-instruct # Default Workers AI model -# AI_PROXY_ALLOWED_MODELS=@cf/meta/llama-4-scout-17b-16e-instruct,@cf/qwen/qwen3-30b-a3b-fp8,@cf/google/gemma-3-12b-it +# AI_PROXY_DEFAULT_MODEL=@cf/meta/llama-4-scout-17b-16e-instruct # Default model (override via admin UI or env var) +# AI_PROXY_ALLOWED_MODELS=@cf/meta/llama-4-scout-17b-16e-instruct,claude-haiku-4-5-20251001,@cf/qwen/qwen3-30b-a3b-fp8,@cf/google/gemma-3-12b-it # AI_PROXY_DAILY_INPUT_TOKEN_LIMIT=500000 # Per-user daily input token cap # AI_PROXY_DAILY_OUTPUT_TOKEN_LIMIT=200000 # Per-user daily output token cap # AI_PROXY_MAX_INPUT_TOKENS_PER_REQUEST=32000 # Max input tokens per single request diff --git a/apps/api/src/db/migrations/0043_trial_foundation.sql b/apps/api/src/db/migrations/0043_trial_foundation.sql new file mode 100644 index 000000000..adba87046 --- /dev/null +++ b/apps/api/src/db/migrations/0043_trial_foundation.sql @@ -0,0 +1,48 @@ +-- Trial onboarding foundation (spec: trial-onboarding-mvp). +-- +-- Seeds a sentinel "system" user that owns anonymous trial projects until a visitor +-- claims them via GitHub OAuth. Using a real row in `users` avoids NULL-handling +-- everywhere that FKs `projects.user_id`. +-- +-- Also creates the waitlist table backing POST /api/trial/waitlist (users who were +-- blocked by the monthly cap). + +-- ============================================================================= +-- System user seed +-- ============================================================================= +-- NOTE: id `system_anonymous_trials` is the sentinel constant exported from +-- @simple-agent-manager/shared as TRIAL_ANONYMOUS_USER_ID. Never log in +-- as this user. 
`status = 'system'` keeps it out of the 'active' admin +-- list filter; the unfiltered admin query will still surface it and will +-- be filtered in Wave 1+ UI polish. +INSERT INTO users (id, email, email_verified, role, status) +VALUES ( + 'system_anonymous_trials', + 'anonymous-trials@simple-agent-manager.internal', + 0, + 'user', + 'system' +) +ON CONFLICT(id) DO NOTHING; + +-- ============================================================================= +-- Waitlist (cap-exceeded signups) +-- ============================================================================= +-- One row per (email, resetDate). resetDate is the ISO date of the month boundary +-- when the user's queued slot becomes eligible again (YYYY-MM-01, UTC). Once the +-- notification is sent, notified_at is stamped. +CREATE TABLE IF NOT EXISTS trial_waitlist ( + id TEXT PRIMARY KEY NOT NULL, + email TEXT NOT NULL, + submitted_at INTEGER NOT NULL, -- epoch ms + reset_date TEXT NOT NULL, -- 'YYYY-MM-01' (UTC) + notified_at INTEGER -- epoch ms, nullable +); + +-- Dedupe: one queued entry per email per reset window. +CREATE UNIQUE INDEX IF NOT EXISTS idx_trial_waitlist_email_reset + ON trial_waitlist (email, reset_date); + +-- Scan-by-reset-window lookups for the monthly notifier cron. +CREATE INDEX IF NOT EXISTS idx_trial_waitlist_reset_notify + ON trial_waitlist (reset_date, notified_at); diff --git a/apps/api/src/db/migrations/0044_trials_table.sql b/apps/api/src/db/migrations/0044_trials_table.sql new file mode 100644 index 000000000..a020c5679 --- /dev/null +++ b/apps/api/src/db/migrations/0044_trials_table.sql @@ -0,0 +1,60 @@ +-- Trial records (spec: trial-onboarding-mvp, Wave 1 Track A). +-- +-- One row per anonymous trial. The row is created by POST /api/trial/create +-- BEFORE the project is provisioned — projectId is populated by the SSE +-- orchestrator (Track B) once the project row exists in `projects`. 
+-- +-- Rows are never physically deleted during normal operation (status transitions +-- to 'expired' after TRIAL_WORKSPACE_TTL_MS and to 'claimed' when the visitor +-- completes GitHub OAuth). A daily retention cron deletes rows older than +-- TRIAL_DATA_RETENTION_HOURS in Wave 2. + +-- ============================================================================= +-- trials +-- ============================================================================= +-- Status state machine: +-- pending -> ready | failed | expired +-- ready -> claimed | expired +-- failed -> (terminal) +-- expired -> (terminal; reaped by retention cron) +-- claimed -> (terminal; project is now owned by the claimant user) +CREATE TABLE IF NOT EXISTS trials ( + id TEXT PRIMARY KEY NOT NULL, + -- The anonymous visitor fingerprint (UUID), matches the cookie payload. + -- Different trials for the same visitor share the same fingerprint. + fingerprint TEXT NOT NULL, + -- Canonicalised https://github.com/owner/repo (no .git, no trailing /). + repo_url TEXT NOT NULL, + -- GitHub repo identity, captured at create time for fast status queries + -- and claim validation. + repo_owner TEXT NOT NULL, + repo_name TEXT NOT NULL, + -- 'YYYY-MM' (UTC). Matches TrialCounter keyspace — used for decrement + -- on failure and for monthly rollover audit. + month_key TEXT NOT NULL, + -- pending | ready | failed | expired | claimed + status TEXT NOT NULL DEFAULT 'pending', + -- Populated by the SSE track once the project row exists in `projects`. + -- Nullable: trial can fail before project creation. + project_id TEXT, + -- Populated by POST /api/trial/claim after GitHub OAuth completes. + claimed_by_user_id TEXT, + created_at INTEGER NOT NULL, -- epoch ms + expires_at INTEGER NOT NULL, -- epoch ms (createdAt + TRIAL_WORKSPACE_TTL_MS) + claimed_at INTEGER, -- epoch ms, nullable + -- Free-text error code persisted when status='failed' (e.g. 'repo_not_found'). 
+ error_code TEXT, + error_message TEXT +); + +-- Fast status lookup by fingerprint (e.g. re-opening the trial page). +CREATE INDEX IF NOT EXISTS idx_trials_fingerprint + ON trials (fingerprint, created_at); + +-- Used by the stale-trial expiry cron. +CREATE INDEX IF NOT EXISTS idx_trials_status_expiry + ON trials (status, expires_at); + +-- Used by the monthly rollover audit. +CREATE INDEX IF NOT EXISTS idx_trials_month_key_status + ON trials (month_key, status); diff --git a/apps/api/src/db/migrations/0045_trial_sentinel_installation.sql b/apps/api/src/db/migrations/0045_trial_sentinel_installation.sql new file mode 100644 index 000000000..dc1cf81f0 --- /dev/null +++ b/apps/api/src/db/migrations/0045_trial_sentinel_installation.sql @@ -0,0 +1,27 @@ +-- Trial sentinel GitHub installation (spec: trial-onboarding-mvp). +-- +-- `projects.installation_id` is NOT NULL and FKs into `github_installations` +-- with ON DELETE CASCADE. Anonymous trial projects own no real GitHub App +-- installation, so we seed a sentinel row that lets the FK pass while making +-- its "system-ness" explicit via the account_type / account_name values. +-- +-- The sentinel id is a well-known constant referenced by the TrialOrchestrator +-- DO (see apps/api/src/durable-objects/trial-orchestrator/). The id is also +-- override-able via TRIAL_ANONYMOUS_INSTALLATION_ID env var — if operators +-- choose to override, they are responsible for seeding the target row. 
+ +INSERT INTO github_installations ( + id, + user_id, + installation_id, + account_type, + account_name +) +VALUES ( + 'system_anonymous_trials_installation', + 'system_anonymous_trials', + '0', + 'User', + 'anonymous-trials' +) +ON CONFLICT(id) DO NOTHING; diff --git a/apps/api/src/db/migrations/0046_trial_projects_unique_index_exclude_sentinel.sql b/apps/api/src/db/migrations/0046_trial_projects_unique_index_exclude_sentinel.sql new file mode 100644 index 000000000..32edf807a --- /dev/null +++ b/apps/api/src/db/migrations/0046_trial_projects_unique_index_exclude_sentinel.sql @@ -0,0 +1,24 @@ +-- Exclude the trial-sentinel owner from the (user_id, installation_id, repository) +-- uniqueness invariant so multiple trials can explore the same public repo. +-- +-- Context: trial-onboarding-mvp spec provisions one projects row per trial, all +-- owned by the sentinel user `system_anonymous_trials` and the sentinel +-- installation `system_anonymous_trials_installation` (migrations 0043 + 0045). +-- The original `idx_projects_user_installation_repository` unique index was +-- designed for real users (one user + one GitHub installation + one repo = one +-- project). Applied to the sentinel, it permits at most ONE lifetime trial per +-- repository — the second trial on `octocat/Hello-World` hits a UNIQUE +-- constraint failure inside handleProjectCreation and the TrialOrchestrator DO +-- gives up after 6 alarm retries (user sees "Creating your project 10%" repeat +-- then a generic step_failed error). +-- +-- Fix: keep the uniqueness invariant for real users (defensive against dup +-- project rows) but carve out the sentinel user explicitly. Trial rows are +-- still isolated by `projectId` (enforced in all trial-aware query paths — +-- see helpers.ts:resolveAnonymousUserId security note). 
+ +DROP INDEX IF EXISTS idx_projects_user_installation_repository; + +CREATE UNIQUE INDEX idx_projects_user_installation_repository + ON projects (user_id, installation_id, repository) + WHERE user_id != 'system_anonymous_trials'; diff --git a/apps/api/src/db/schema.ts b/apps/api/src/db/schema.ts index bd3aac82c..2f1bbf02f 100644 --- a/apps/api/src/db/schema.ts +++ b/apps/api/src/db/schema.ts @@ -295,11 +295,15 @@ export const projects = sqliteTable( table.userId, table.normalizedName ), - userInstallationRepoUnique: uniqueIndex('idx_projects_user_installation_repository').on( - table.userId, - table.installationId, - table.repository - ), + // Trial-onboarding sentinel owner is excluded from the uniqueness invariant: + // every anonymous trial inserts a row with the same (sentinel user, sentinel + // installation), so "one project per user+installation+repo" cannot hold for + // the sentinel. Isolation for trial rows is enforced by `projectId` scoping + // (see helpers.ts:resolveAnonymousUserId). The index still enforces + // uniqueness for real users. See migration 0046. + userInstallationRepoUnique: uniqueIndex('idx_projects_user_installation_repository') + .on(table.userId, table.installationId, table.repository) + .where(sql`user_id != 'system_anonymous_trials'`), userGithubRepoIdUnique: uniqueIndex('idx_projects_user_github_repo_id') .on(table.userId, table.githubRepoId) .where(sql`github_repo_id IS NOT NULL`), @@ -1303,3 +1307,92 @@ export const userQuotas = sqliteTable('user_quotas', { }); export type UserQuotaRow = typeof userQuotas.$inferSelect; + +// ============================================================================= +// Trial Onboarding — Waitlist +// ============================================================================= + +/** + * Cap-exceeded trial signups. One row per (email, resetDate). The monthly + * notifier cron marks `notifiedAt` when the reset window opens and an email + * is sent. See docs/guides/trial-configuration.md. 
+ */ +export const trialWaitlist = sqliteTable( + 'trial_waitlist', + { + id: text('id').primaryKey(), + email: text('email').notNull(), + submittedAt: integer('submitted_at').notNull(), // epoch ms + resetDate: text('reset_date').notNull(), // 'YYYY-MM-01' UTC + notifiedAt: integer('notified_at'), // epoch ms, nullable + }, + (table) => ({ + emailResetIdx: uniqueIndex('idx_trial_waitlist_email_reset').on( + table.email, + table.resetDate + ), + resetNotifyIdx: index('idx_trial_waitlist_reset_notify').on( + table.resetDate, + table.notifiedAt + ), + }) +); + +export type TrialWaitlistRow = typeof trialWaitlist.$inferSelect; +export type NewTrialWaitlistRow = typeof trialWaitlist.$inferInsert; + +// ============================================================================= +// Trial Onboarding — Trial records +// ============================================================================= + +/** + * Anonymous trial records. One row per trial lifecycle, created on + * POST /api/trial/create before the project row is provisioned. The + * orchestrator populates `projectId` once provisioning completes. + * + * Status state machine: + * pending -> ready | failed | expired + * ready -> claimed | expired + * failed -> (terminal) + * expired -> (terminal; reaped by retention cron) + * claimed -> (terminal; project.user_id now points to the claimant) + * + * `monthKey` mirrors the TrialCounter DO keyspace ('YYYY-MM' UTC) and is + * used for decrement-on-failure and for the monthly rollover audit. 
+ */ +export const trials = sqliteTable( + 'trials', + { + id: text('id').primaryKey(), + fingerprint: text('fingerprint').notNull(), + repoUrl: text('repo_url').notNull(), + repoOwner: text('repo_owner').notNull(), + repoName: text('repo_name').notNull(), + monthKey: text('month_key').notNull(), + status: text('status').notNull().default('pending'), + projectId: text('project_id'), + claimedByUserId: text('claimed_by_user_id'), + createdAt: integer('created_at').notNull(), // epoch ms + expiresAt: integer('expires_at').notNull(), // epoch ms + claimedAt: integer('claimed_at'), // epoch ms, nullable + errorCode: text('error_code'), + errorMessage: text('error_message'), + }, + (table) => ({ + fingerprintIdx: index('idx_trials_fingerprint').on( + table.fingerprint, + table.createdAt + ), + statusExpiryIdx: index('idx_trials_status_expiry').on( + table.status, + table.expiresAt + ), + monthKeyStatusIdx: index('idx_trials_month_key_status').on( + table.monthKey, + table.status + ), + }) +); + +export type TrialRow = typeof trials.$inferSelect; +export type NewTrialRow = typeof trials.$inferInsert; diff --git a/apps/api/src/durable-objects/project-data/index.ts b/apps/api/src/durable-objects/project-data/index.ts index 19666229a..2e9263c4c 100644 --- a/apps/api/src/durable-objects/project-data/index.ts +++ b/apps/api/src/durable-objects/project-data/index.ts @@ -237,6 +237,32 @@ export class ProjectData extends DurableObject { const projectId = this.getProjectId(); const result = acpSessions.transitionAcpSession(this.sql, sessionId, toStatus, opts, projectId); if (toStatus === 'assigned' || toStatus === 'running') await this.scheduleHeartbeatAlarm(); + + // Trial bridge — fan `running`/`failed` transitions out as trial.ready / + // trial.error SSE events. Non-trial projects short-circuit inside the + // helper after a single KV lookup, so overhead on normal traffic is minimal. + // Fire-and-forget; wrapped in its own try/catch inside the helper. 
+ // + // The local ProjectData `Env` type is a narrow subset (D1 + config knobs), + // but at runtime Cloudflare injects every binding declared in wrangler.toml + // — including KV and TRIAL_EVENT_BUS that the bridge needs. Cast through + // unknown so the DO type stays minimal without leaking worker-scope bindings. + try { + if (projectId) { + const { bridgeAcpSessionTransition } = await import('../../services/trial/bridge'); + const workerEnv = this.env as unknown as import('../../env').Env; + await bridgeAcpSessionTransition(workerEnv, projectId, toStatus, { + errorMessage: opts.errorMessage ?? null, + }); + } + } catch (err) { + log.warn('project_data.trial_bridge_dispatch_failed', { + projectId, + toStatus, + error: err instanceof Error ? err.message : String(err), + }); + } + return result.session; } diff --git a/apps/api/src/durable-objects/trial-counter.ts b/apps/api/src/durable-objects/trial-counter.ts new file mode 100644 index 000000000..eae5ecac7 --- /dev/null +++ b/apps/api/src/durable-objects/trial-counter.ts @@ -0,0 +1,228 @@ +/** + * TrialCounter — global singleton Durable Object for enforcing the monthly + * trial cap. + * + * Keyed by the literal string `global` (one instance per deployment). Callers + * compute the current month key (`YYYY-MM` UTC) and ask the DO to increment + * or decrement the counter atomically. + * + * Storage: embedded SQLite (see wrangler.toml migration tag v7 — + * `new_sqlite_classes = ["TrialCounter"]`). Atomicity is guaranteed by + * `this.ctx.storage.transactionSync` — all reads and writes within the + * callback commit together. 
+ * + * Accessed via: + * const stub = env.TRIAL_COUNTER.get(env.TRIAL_COUNTER.idFromName('global')); + * const result = await stub.increment(monthKey, cap); + */ +import { DurableObject } from 'cloudflare:workers'; + +import type { Env } from '../env'; + +function jsonResponse(body: unknown, status: number): Response { + return new Response(JSON.stringify(body), { + status, + headers: { 'content-type': 'application/json' }, + }); +} + +export interface TrialCounterIncrementResult { + /** true when the increment succeeded (count <= cap after increment) */ + ok: boolean; + /** count after the operation (or current count when cap was hit) */ + count: number; +} + +/** + * Public shape for the Wave-1 fetch API — identical semantics to + * {@link TrialCounterIncrementResult} but uses `allowed` to match the + * HTTP contract documented for the trial create endpoint. + */ +export interface TrialCounterTryIncrementResult { + allowed: boolean; + count: number; +} + +export interface TrialCounterState { + monthKey: string; + count: number; +} + +export class TrialCounter extends DurableObject { + private readonly sql: SqlStorage; + + constructor(ctx: DurableObjectState, env: Env) { + super(ctx, env); + this.sql = ctx.storage.sql; + ctx.blockConcurrencyWhile(async () => { + this.sql.exec( + `CREATE TABLE IF NOT EXISTS trial_counter ( + month_key TEXT PRIMARY KEY NOT NULL, + count INTEGER NOT NULL DEFAULT 0 + )` + ); + }); + } + + /** + * Atomically increment the counter for `monthKey` by 1 and return the new + * value. If the increment would exceed `cap`, the counter is NOT incremented + * and `{ ok: false, count }` is returned. `cap <= 0` disables the limit. 
+ */ + async increment( + monthKey: string, + cap: number + ): Promise { + return this.ctx.storage.transactionSync(() => { + const current = this.readCount(monthKey); + if (cap > 0 && current >= cap) { + return { ok: false, count: current }; + } + const next = current + 1; + this.sql.exec( + `INSERT INTO trial_counter (month_key, count) VALUES (?, ?) + ON CONFLICT(month_key) DO UPDATE SET count = excluded.count`, + monthKey, + next + ); + return { ok: true, count: next }; + }); + } + + /** + * Decrement the counter for `monthKey`. Used when a trial fails AFTER slot + * allocation (so the slot isn't "burned" by a failed provision). Clamps at 0. + */ + async decrement(monthKey: string): Promise { + return this.ctx.storage.transactionSync(() => { + const current = this.readCount(monthKey); + const next = Math.max(0, current - 1); + this.sql.exec( + `INSERT INTO trial_counter (month_key, count) VALUES (?, ?) + ON CONFLICT(month_key) DO UPDATE SET count = excluded.count`, + monthKey, + next + ); + return next; + }); + } + + /** + * Thin alias around {@link increment} that returns the + * `{ allowed, count }` shape used by the Wave-1 HTTP contract. Kept + * separate so call sites that speak the HTTP vocabulary don't have to + * translate `ok` -> `allowed` themselves. + */ + async tryIncrement( + monthKey: string, + cap: number + ): Promise { + const result = await this.increment(monthKey, cap); + return { allowed: result.ok, count: result.count }; + } + + /** + * Delete counter rows whose month keys predate `keepMonthKey` (inclusive + * of `keepMonthKey` is kept). Returns the number of rows deleted. Used + * by the monthly rollover audit cron to bound the DO's SQLite growth. + */ + async prune(keepMonthKey: string): Promise { + return this.ctx.storage.transactionSync(() => { + const before = this.sql + .exec<{ c: number }>( + 'SELECT COUNT(*) AS c FROM trial_counter WHERE month_key < ?', + keepMonthKey + ) + .toArray()[0]?.c ?? 
0; + this.sql.exec( + 'DELETE FROM trial_counter WHERE month_key < ?', + keepMonthKey + ); + return before; + }); + } + + /** Read-only accessor for the cap/available display in the UI. */ + async get(monthKey: string): Promise { + const count = this.readCount(monthKey); + return { monthKey, count }; + } + + /** + * Minimal HTTP surface for the DO. Supports: + * GET /state?monthKey=YYYY-MM -> { monthKey, count } + * POST /tryIncrement -> { allowed, count } + * body: { monthKey: string, cap: number } + * POST /decrement -> { count } + * body: { monthKey: string } + * POST /prune -> { deleted } + * body: { keepMonthKey: string } + * + * The fetch handler is useful for Wave-1 call sites that cannot use + * DO RPC directly (e.g. cross-service debug tooling); routes that run + * in the same Worker bundle SHOULD use the typed RPC methods above. + */ + async fetch(request: Request): Promise { + const url = new URL(request.url); + try { + if (request.method === 'GET' && url.pathname === '/state') { + const monthKey = url.searchParams.get('monthKey'); + if (!monthKey) return jsonResponse({ error: 'monthKey required' }, 400); + const state = await this.get(monthKey); + return jsonResponse(state, 200); + } + + if (request.method === 'POST' && url.pathname === '/tryIncrement') { + const body = (await request.json()) as { + monthKey?: string; + cap?: number; + }; + if (!body.monthKey || typeof body.cap !== 'number') { + return jsonResponse({ error: 'monthKey + cap required' }, 400); + } + const result = await this.tryIncrement(body.monthKey, body.cap); + return jsonResponse(result, 200); + } + + if (request.method === 'POST' && url.pathname === '/decrement') { + const body = (await request.json()) as { monthKey?: string }; + if (!body.monthKey) return jsonResponse({ error: 'monthKey required' }, 400); + const count = await this.decrement(body.monthKey); + return jsonResponse({ count }, 200); + } + + if (request.method === 'POST' && url.pathname === '/prune') { + const body = 
(await request.json()) as { keepMonthKey?: string }; + if (!body.keepMonthKey) { + return jsonResponse({ error: 'keepMonthKey required' }, 400); + } + const deleted = await this.prune(body.keepMonthKey); + return jsonResponse({ deleted }, 200); + } + + return jsonResponse({ error: 'not_found' }, 404); + } catch (err) { + return jsonResponse( + { + error: 'internal', + message: err instanceof Error ? err.message : String(err), + }, + 500 + ); + } + } + + // ------------------------------------------------------------------------- + // Internal helpers + // ------------------------------------------------------------------------- + + private readCount(monthKey: string): number { + const row = this.sql + .exec<{ count: number }>( + 'SELECT count FROM trial_counter WHERE month_key = ? LIMIT 1', + monthKey + ) + .toArray()[0]; + return row?.count ?? 0; + } +} diff --git a/apps/api/src/durable-objects/trial-event-bus.ts b/apps/api/src/durable-objects/trial-event-bus.ts new file mode 100644 index 000000000..84b2fa226 --- /dev/null +++ b/apps/api/src/durable-objects/trial-event-bus.ts @@ -0,0 +1,215 @@ +/** + * TrialEventBus — per-trial Durable Object that buffers `TrialEvent`s and + * supports long-poll subscription for the SSE endpoint. + * + * Keyed by trialId (`env.TRIAL_EVENT_BUS.idFromName(trialId)`). One instance + * per active trial. + * + * Storage model: events are held IN-MEMORY only — the trial lifetime is short + * (<= TRIAL_WORKSPACE_TTL_MS, default 20 min) and the SSE endpoint's long-poll + * loop reconnects on DO eviction. Persisting every token would multiply DO + * storage writes without any recovery benefit for disposable anonymous trials. 
+ * + * Methods (invoked over HTTP from the Worker): + * POST /append body: TrialEvent -> { cursor } append a single event + * GET /poll ?cursor=&timeoutMs= -> { events, cursor, closed } + * POST /close mark the trial stream as finished + * + * Close semantics: once `/close` is called (after trial.error, or when the + * discovery agent session ends), subsequent `/poll` calls return + * `closed: true` and the SSE endpoint finishes the response stream. + * `trial.ready` is NOT terminal — the discovery agent continues producing + * knowledge and idea events after the workspace is provisioned. + */ + +import type { TrialEvent } from '@simple-agent-manager/shared'; +import { DurableObject } from 'cloudflare:workers'; + +import type { Env } from '../env'; +import { log } from '../lib/logger'; + +interface BufferedEvent { + cursor: number; + event: TrialEvent; +} + +interface Waiter { + resolve: () => void; + timer: ReturnType; +} + +const DEFAULT_POLL_TIMEOUT_MS = 15_000; +const MAX_POLL_TIMEOUT_MS = 60_000; +const MAX_BUFFERED_EVENTS = 500; + +export class TrialEventBus extends DurableObject { + private buffer: BufferedEvent[] = []; + private nextCursor = 1; + private closed = false; + private closedLoaded = false; + private waiters: Set = new Set(); + + constructor(ctx: DurableObjectState, env: Env) { + super(ctx, env); + } + + /** + * Lazily load the `closed` flag from DO storage. In-memory state is lost on + * eviction; persisting this single flag ensures SSE consumers see `closed: true` + * even after the DO is re-instantiated (the event buffer is intentionally NOT + * persisted — short-lived trial streams don't need full replay). 
+ */ + private async ensureClosedLoaded(): Promise { + if (this.closedLoaded) return; + const stored = await this.ctx.storage.get('closed'); + if (stored === true) this.closed = true; + this.closedLoaded = true; + } + + async fetch(req: Request): Promise { + const url = new URL(req.url); + const path = url.pathname; + + if (req.method === 'POST' && path === '/append') { + return this.handleAppend(req); + } + if (req.method === 'GET' && path === '/poll') { + return this.handlePoll(url); + } + if (req.method === 'POST' && path === '/close') { + return this.handleClose(); + } + return new Response('not_found', { status: 404 }); + } + + // ------------------------------------------------------------------------- + // Append + // ------------------------------------------------------------------------- + + private async handleAppend(req: Request): Promise { + await this.ensureClosedLoaded(); + log.info('trial_event_bus.handleAppend.enter', { + closed: this.closed, + bufferLen: this.buffer.length, + }); + if (this.closed) { + log.warn('trial_event_bus.handleAppend.rejected_closed', {}); + return new Response( + JSON.stringify({ error: 'closed' }), + { status: 409, headers: { 'content-type': 'application/json' } } + ); + } + let event: TrialEvent; + try { + event = (await req.json()) as TrialEvent; + } catch { + return new Response('bad_json', { status: 400 }); + } + const cursor = this.nextCursor++; + this.buffer.push({ cursor, event }); + log.info('trial_event_bus.handleAppend.stored', { + cursor, + type: event.type, + waiters: this.waiters.size, + }); + if (this.buffer.length > MAX_BUFFERED_EVENTS) { + // Drop oldest events — SSE consumers MUST keep up. + this.buffer.splice(0, this.buffer.length - MAX_BUFFERED_EVENTS); + } + + // Auto-close only on hard-terminal events. 
`trial.ready` is a milestone + // (workspace provisioned) but the discovery agent continues producing + // `trial.knowledge` and `trial.idea` events afterward — closing here + // would reject those late-arriving events with 409. + if (event.type === 'trial.error') { + this.closed = true; + await this.ctx.storage.put('closed', true); + } + + this.wakeWaiters(); + return Response.json({ cursor }); + } + + // ------------------------------------------------------------------------- + // Poll + // ------------------------------------------------------------------------- + + private async handlePoll(url: URL): Promise { + await this.ensureClosedLoaded(); + const cursor = Number.parseInt(url.searchParams.get('cursor') ?? '0', 10) || 0; + const requestedTimeout = Number.parseInt( + url.searchParams.get('timeoutMs') ?? String(DEFAULT_POLL_TIMEOUT_MS), + 10 + ); + const timeoutMs = Math.min( + Math.max(100, Number.isFinite(requestedTimeout) ? requestedTimeout : DEFAULT_POLL_TIMEOUT_MS), + MAX_POLL_TIMEOUT_MS + ); + + // Immediate return if new events already exist or stream is closed. + const existing = this.bufferAfter(cursor); + if (existing.length > 0 || this.closed) { + return this.buildPollResponse(existing); + } + + // Otherwise wait up to timeoutMs for an append() call or stream closure. + await this.wait(timeoutMs); + const late = this.bufferAfter(cursor); + return this.buildPollResponse(late); + } + + private bufferAfter(cursor: number): BufferedEvent[] { + if (cursor <= 0) return this.buffer.slice(); + // linear scan is fine — buffer is bounded by MAX_BUFFERED_EVENTS. + return this.buffer.filter((e) => e.cursor > cursor); + } + + private buildPollResponse(events: BufferedEvent[]): Response { + const lastCursor = events.length > 0 + ? events[events.length - 1]!.cursor + : (this.buffer.length > 0 ? 
this.buffer[this.buffer.length - 1]!.cursor : 0); + return Response.json({ + events: events.map((e) => ({ cursor: e.cursor, event: e.event })), + cursor: lastCursor, + closed: this.closed, + }); + } + + // ------------------------------------------------------------------------- + // Close + // ------------------------------------------------------------------------- + + private async handleClose(): Promise { + this.closed = true; + await this.ctx.storage.put('closed', true); + this.wakeWaiters(); + return Response.json({ closed: true }); + } + + // ------------------------------------------------------------------------- + // Wait / wake + // ------------------------------------------------------------------------- + + private wait(timeoutMs: number): Promise { + return new Promise((resolve) => { + const timer = setTimeout(() => { + this.waiters.delete(waiter); + resolve(); + }, timeoutMs); + const waiter: Waiter = { + resolve: () => { + clearTimeout(timer); + resolve(); + }, + timer, + }; + this.waiters.add(waiter); + }); + } + + private wakeWaiters(): void { + const all = Array.from(this.waiters); + this.waiters.clear(); + for (const w of all) w.resolve(); + } +} diff --git a/apps/api/src/durable-objects/trial-orchestrator/helpers.ts b/apps/api/src/durable-objects/trial-orchestrator/helpers.ts new file mode 100644 index 000000000..fabfe6c03 --- /dev/null +++ b/apps/api/src/durable-objects/trial-orchestrator/helpers.ts @@ -0,0 +1,69 @@ +/** + * TrialOrchestrator helper passthroughs. + * + * Backoff/transient-error detection is identical to the TaskRunner pattern, so + * we re-export those helpers rather than duplicating. Anything trial-specific + * (event emission helpers, project-row creation) lives in steps.ts. 
+ */ +export { + computeBackoffMs, + isTransientError, + parseEnvInt, +} from '../task-runner/helpers'; + +import type { TrialEvent } from '@simple-agent-manager/shared'; +import { TRIAL_ANONYMOUS_USER_ID } from '@simple-agent-manager/shared'; + +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { emitTrialEvent } from '../../services/trial/trial-runner'; + +/** + * Sentinel userId used for every anonymous trial project/workspace. + * + * Security note: ALL in-flight trials share the same owning userId. This is + * intentional (we have no real user until `/claim` succeeds), but it means any + * per-user query scope (e.g., `WHERE userId = ?`) does NOT isolate one trial + * from another. Downstream code that acts on trial-owned rows MUST additionally + * scope by `projectId` (or trialId) to prevent one trial's logic from observing + * or mutating another trial's state. Platform-level IDOR enforcement lives in + * `requireOwnedProject` (`apps/api/src/middleware/require-owned-project.ts`) + * and is not bypassed by trial code — but trial-internal loops (e.g., + * `handleNodeSelection` scanning all sentinel-owned nodes) rely on the D1 + * query filtering by projectId/nodeId where applicable. + */ +export function resolveAnonymousUserId(env: Env): string { + return env.TRIAL_ANONYMOUS_USER_ID ?? TRIAL_ANONYMOUS_USER_ID; +} + +/** + * Sentinel GitHub installation id that trial projects FK into. Operators can + * override via `TRIAL_ANONYMOUS_INSTALLATION_ID`; the default value matches + * the row seeded by migration `0045_trial_sentinel_installation.sql`. + */ +export const DEFAULT_TRIAL_ANONYMOUS_INSTALLATION_ID = + 'system_anonymous_trials_installation'; + +export function resolveAnonymousInstallationId(env: Env): string { + return env.TRIAL_ANONYMOUS_INSTALLATION_ID ?? DEFAULT_TRIAL_ANONYMOUS_INSTALLATION_ID; +} + +/** + * Fire-and-forget trial event emit. 
Wraps `emitTrialEvent` in a catch so a + * failing event bus never blocks the state machine — alarms drive progress. + */ +export async function safeEmitTrialEvent( + env: Env, + trialId: string, + event: TrialEvent +): Promise { + try { + await emitTrialEvent(env, trialId, event); + } catch (err) { + log.warn('trial_orchestrator.emit_event_failed', { + trialId, + type: event.type, + error: err instanceof Error ? err.message : String(err), + }); + } +} diff --git a/apps/api/src/durable-objects/trial-orchestrator/index.ts b/apps/api/src/durable-objects/trial-orchestrator/index.ts new file mode 100644 index 000000000..c81e92751 --- /dev/null +++ b/apps/api/src/durable-objects/trial-orchestrator/index.ts @@ -0,0 +1,402 @@ +/** + * TrialOrchestrator Durable Object — alarm-driven trial provisioning. + * + * Mirrors the TaskRunner DO pattern: one DO instance per trialId, driven by + * `ctx.storage.setAlarm()`, with each step handler idempotent and independent. + * Replaces the earlier fire-and-forget `waitUntil(provisionTrial())` approach + * that couldn't survive Worker restarts or partial failures. + * + * Keyed by trialId: + * env.TRIAL_ORCHESTRATOR.idFromName(trialId) + * + * Step flow (see steps.ts for handlers): + * project_creation → node_selection + * → (healthy existing node) → workspace_creation + * → (no healthy node) → node_provisioning → node_agent_ready → workspace_creation + * → workspace_ready → discovery_agent_start → running + * + * Terminal-for-orchestrator at `running`: the ACP bridge (wired into + * ProjectData DO's `transitionAcpSession`) fires `trial.ready` once the + * discovery agent produces its first assistant turn. + * + * On transient failure: exponential backoff up to configurable max retries. + * On permanent failure: emit trial.error + mark DO completed. + * On overall timeout: emit trial.error with a timeout code. + * + * See: tasks/active/2026-04-19-trial-orchestrator-wire-up.md for full spec. 
+ */ +import { + DEFAULT_TRIAL_ORCHESTRATOR_AGENT_READY_TIMEOUT_MS, + DEFAULT_TRIAL_ORCHESTRATOR_HEARTBEAT_SKEW_MS, + DEFAULT_TRIAL_ORCHESTRATOR_NODE_READY_TIMEOUT_MS, + DEFAULT_TRIAL_ORCHESTRATOR_OVERALL_TIMEOUT_MS, + DEFAULT_TRIAL_ORCHESTRATOR_RETRY_BASE_DELAY_MS, + DEFAULT_TRIAL_ORCHESTRATOR_RETRY_MAX_DELAY_MS, + DEFAULT_TRIAL_ORCHESTRATOR_STEP_MAX_RETRIES, + DEFAULT_TRIAL_ORCHESTRATOR_WORKSPACE_READY_POLL_INTERVAL_MS, + DEFAULT_TRIAL_ORCHESTRATOR_WORKSPACE_READY_TIMEOUT_MS, +} from '@simple-agent-manager/shared'; +import { DurableObject } from 'cloudflare:workers'; + +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { computeBackoffMs, isTransientError, parseEnvInt, safeEmitTrialEvent } from './helpers'; +import { + handleDiscoveryAgentStart, + handleNodeAgentReady, + handleNodeProvisioning, + handleNodeSelection, + handleProjectCreation, + handleRunning, + handleWorkspaceCreation, + handleWorkspaceReady, +} from './steps'; +import type { + StartTrialInput, + TrialOrchestratorContext, + TrialOrchestratorState, + TrialOrchestratorStep, +} from './types'; + +// Re-export public types for consumers +export type { StartTrialInput, TrialOrchestratorState } from './types'; + +export class TrialOrchestrator extends DurableObject { + // ========================================================================= + // Public RPCs (called from Worker routes) + // ========================================================================= + + /** + * Start a new trial orchestration. Called once from POST /api/trial/create + * via `c.executionCtx.waitUntil(stub.start(...))`. Idempotent — if the DO is + * already initialized the call is a no-op. 
+ */ + async start(input: StartTrialInput): Promise { + log.info('trial_orchestrator_do.start.enter', { trialId: input.trialId }); + const existing = await this.getState(); + if (existing) { + log.warn('trial_orchestrator_do.start.already_initialized', { + trialId: input.trialId, + currentStep: existing.currentStep, + }); + return; + } + + const now = Date.now(); + const state: TrialOrchestratorState = { + version: 1, + trialId: input.trialId, + repoUrl: input.repoUrl, + repoOwner: input.repoOwner, + repoName: input.repoName, + currentStep: 'project_creation', + projectId: null, + nodeId: null, + autoProvisionedNode: false, + workspaceId: null, + chatSessionId: null, + acpSessionId: null, + defaultBranch: null, + mcpToken: null, + agentSessionCreatedOnVm: false, + agentStartedOnVm: false, + acpAssignedOnVm: false, + acpRunningOnVm: false, + retryCount: 0, + createdAt: now, + lastStepAt: now, + nodeAgentReadyStartedAt: null, + workspaceReadyStartedAt: null, + completed: false, + failureReason: null, + }; + + await this.ctx.storage.put('state', state); + log.info('trial_orchestrator_do.start.state_put', { trialId: input.trialId }); + await this.ctx.storage.setAlarm(now); + log.info('trial_orchestrator_do.start.alarm_set', { trialId: input.trialId }); + + // Emit `trial.started` so the SSE stream's first real event confirms the + // orchestrator picked up the trial. The frontend uses this to transition + // out of the "Warming up..." empty state. + await safeEmitTrialEvent(this.env, input.trialId, { + type: 'trial.started', + trialId: input.trialId, + projectId: '', // project not yet created; filled in on `trial.ready` + repoUrl: input.repoUrl, + startedAt: now, + }); + log.info('trial_orchestrator_do.start.trial_started_emitted', { + trialId: input.trialId, + }); + + log.info('trial_orchestrator_do.started', { + trialId: input.trialId, + repoUrl: input.repoUrl, + }); + } + + /** + * Public status accessor for debugging / tests. 
The DO now stores an + * `mcpToken` (minted for the discovery agent in `handleDiscoveryAgentStart`) + * which is a live bearer credential — it must be redacted before leaving + * the DO boundary so any debug endpoint that surfaces status cannot + * inadvertently leak the token. + */ + async getStatus(): Promise { + const state = await this.getState(); + if (!state) return null; + return { + ...state, + mcpToken: state.mcpToken ? '[redacted]' : null, + }; + } + + // ========================================================================= + // Alarm handler — step dispatch + // ========================================================================= + + async alarm(): Promise { + const state = await this.getState(); + log.info('trial_orchestrator_do.alarm.enter', { + trialId: state?.trialId ?? null, + step: state?.currentStep ?? null, + completed: state?.completed ?? null, + }); + if (!state || state.completed) return; + + // Hard overall-timeout guard — independent of per-step retries. 
+ const overallTimeoutMs = this.getOverallTimeoutMs(); + if (Date.now() - state.createdAt > overallTimeoutMs) { + await this.failTrial(state, `Trial timed out after ${overallTimeoutMs}ms`, 'timeout'); + return; + } + + const rc = this.buildContext(); + const stepStartMs = Date.now(); + + try { + switch (state.currentStep) { + case 'project_creation': + await handleProjectCreation(state, rc); + break; + case 'node_selection': + await handleNodeSelection(state, rc); + break; + case 'node_provisioning': + await handleNodeProvisioning(state, rc); + break; + case 'node_agent_ready': + await handleNodeAgentReady(state, rc); + break; + case 'workspace_creation': + await handleWorkspaceCreation(state, rc); + break; + case 'workspace_ready': + await handleWorkspaceReady(state, rc); + break; + case 'discovery_agent_start': + await handleDiscoveryAgentStart(state, rc); + break; + case 'running': + await handleRunning(state, rc); + return; + case 'succeeded': + case 'failed': + return; + default: + log.error('trial_orchestrator_do.unknown_step', { + trialId: state.trialId, + step: state.currentStep, + }); + await this.failTrial(state, `Unknown step: ${state.currentStep}`, 'internal'); + return; + } + } catch (err) { + const errorMessage = err instanceof Error ? 
err.message : String(err); + const durationMs = Date.now() - stepStartMs; + + log.error('trial_orchestrator_do.step_error', { + trialId: state.trialId, + step: state.currentStep, + retryCount: state.retryCount, + errorMessage, + durationMs, + }); + + if (isTransientError(err) && state.retryCount < this.getMaxRetries()) { + state.retryCount++; + await this.ctx.storage.put('state', state); + const backoff = computeBackoffMs( + state.retryCount, + this.getRetryBaseDelayMs(), + this.getRetryMaxDelayMs(), + ); + await this.ctx.storage.setAlarm(Date.now() + backoff); + log.info('trial_orchestrator_do.step_retry_scheduled', { + trialId: state.trialId, + step: state.currentStep, + retryCount: state.retryCount, + backoffMs: backoff, + }); + } else { + await this.failTrial(state, errorMessage, 'step_failed'); + } + } + } + + // ========================================================================= + // Terminal failure helper + // ========================================================================= + + private async failTrial( + state: TrialOrchestratorState, + reason: string, + errorCode: string, + ): Promise { + // Revoke MCP token BEFORE marking state as failed so a leaked token from a + // botched/timed-out trial cannot continue hitting MCP endpoints for the + // remainder of its 4-hour TTL (DEFAULT_MCP_TOKEN_TTL_SECONDS). Mirrors the + // TaskRunner failure-path cleanup at + // `task-runner/state-machine.ts:265-275`. Best-effort — log-and-continue + // so a KV hiccup doesn't block the failure emission path. + if (state.mcpToken) { + try { + const { revokeMcpToken } = await import('../../services/mcp-token'); + await revokeMcpToken(this.env.KV, state.mcpToken); + } catch (err) { + log.warn('trial_orchestrator_do.mcp_token_revoke_failed', { + trialId: state.trialId, + error: err instanceof Error ? 
err.message : String(err), + }); + } + state.mcpToken = null; + } + + state.currentStep = 'failed'; + state.completed = true; + state.failureReason = reason; + await this.ctx.storage.put('state', state); + + log.error('trial_orchestrator_do.failed', { + trialId: state.trialId, + projectId: state.projectId, + workspaceId: state.workspaceId, + reason, + errorCode, + }); + + // `trial.error` is part of the terminal SSE contract — always emit, even + // if the event bus is closed (the bus itself handles the 409). + await safeEmitTrialEvent(this.env, state.trialId, { + type: 'trial.error', + // TrialErrorCode is a shared enum; cast through unknown to preserve + // the string provided by callers without forcing an import cycle. + error: errorCode as never, + message: reason, + at: Date.now(), + }); + } + + // ========================================================================= + // Context builder + // ========================================================================= + + private buildContext(): TrialOrchestratorContext { + return { + env: this.env, + ctx: this.ctx, + advanceToStep: async (state: TrialOrchestratorState, nextStep: TrialOrchestratorStep) => { + state.currentStep = nextStep; + state.retryCount = 0; + state.lastStepAt = Date.now(); + await this.ctx.storage.put('state', state); + await this.ctx.storage.setAlarm(Date.now()); + }, + getOverallTimeoutMs: () => this.getOverallTimeoutMs(), + getRetryBaseDelayMs: () => this.getRetryBaseDelayMs(), + getRetryMaxDelayMs: () => this.getRetryMaxDelayMs(), + getMaxRetries: () => this.getMaxRetries(), + getWorkspaceReadyTimeoutMs: () => this.getWorkspaceReadyTimeoutMs(), + getWorkspaceReadyPollIntervalMs: () => this.getWorkspaceReadyPollIntervalMs(), + getNodeReadyTimeoutMs: () => this.getNodeReadyTimeoutMs(), + getAgentReadyTimeoutMs: () => this.getAgentReadyTimeoutMs(), + getHeartbeatSkewMs: () => this.getHeartbeatSkewMs(), + }; + } + + // 
========================================================================= + // State accessor + // ========================================================================= + + private async getState(): Promise { + return (await this.ctx.storage.get('state')) ?? null; + } + + // ========================================================================= + // Tunable knobs — env var override with DEFAULT_* fallback (Principle XI). + // ========================================================================= + + private getOverallTimeoutMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_OVERALL_TIMEOUT_MS, + DEFAULT_TRIAL_ORCHESTRATOR_OVERALL_TIMEOUT_MS, + ); + } + + private getRetryBaseDelayMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_RETRY_BASE_DELAY_MS, + DEFAULT_TRIAL_ORCHESTRATOR_RETRY_BASE_DELAY_MS, + ); + } + + private getRetryMaxDelayMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_RETRY_MAX_DELAY_MS, + DEFAULT_TRIAL_ORCHESTRATOR_RETRY_MAX_DELAY_MS, + ); + } + + private getMaxRetries(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_STEP_MAX_RETRIES, + DEFAULT_TRIAL_ORCHESTRATOR_STEP_MAX_RETRIES, + ); + } + + private getWorkspaceReadyTimeoutMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_WORKSPACE_READY_TIMEOUT_MS, + DEFAULT_TRIAL_ORCHESTRATOR_WORKSPACE_READY_TIMEOUT_MS, + ); + } + + private getWorkspaceReadyPollIntervalMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_WORKSPACE_READY_POLL_INTERVAL_MS, + DEFAULT_TRIAL_ORCHESTRATOR_WORKSPACE_READY_POLL_INTERVAL_MS, + ); + } + + private getNodeReadyTimeoutMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_NODE_READY_TIMEOUT_MS, + DEFAULT_TRIAL_ORCHESTRATOR_NODE_READY_TIMEOUT_MS, + ); + } + + private getAgentReadyTimeoutMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_AGENT_READY_TIMEOUT_MS, + DEFAULT_TRIAL_ORCHESTRATOR_AGENT_READY_TIMEOUT_MS, + ); + } + + private 
getHeartbeatSkewMs(): number { + return parseEnvInt( + this.env.TRIAL_ORCHESTRATOR_HEARTBEAT_SKEW_MS, + DEFAULT_TRIAL_ORCHESTRATOR_HEARTBEAT_SKEW_MS, + ); + } +} diff --git a/apps/api/src/durable-objects/trial-orchestrator/steps.ts b/apps/api/src/durable-objects/trial-orchestrator/steps.ts new file mode 100644 index 000000000..5d767e9db --- /dev/null +++ b/apps/api/src/durable-objects/trial-orchestrator/steps.ts @@ -0,0 +1,835 @@ +// FILE SIZE EXCEPTION: Trial orchestrator steps — sequential state machine steps that share common context; splitting would obscure the step ordering. See .claude/rules/18-file-size-limits.md +/** + * Step handlers for the TrialOrchestrator DO state machine. + * + * Mirrors TaskRunner's split (node-steps.ts + workspace-steps.ts) but consolidated + * because the trial flow is narrower: no task lifecycle, no attachments, no + * follow-ups. Each handler is idempotent and either advances to the next step + * or schedules another poll via the DO alarm. + * + * Flow: + * project_creation + * → node_selection + * → (warm/existing node healthy) → workspace_creation + * → (no healthy node) → node_provisioning → node_agent_ready → workspace_creation + * → workspace_creation → workspace_ready → discovery_agent_start → running + * + * Event emission: each handler fires a trial.progress event at the start so + * the SSE stream reflects what the orchestrator is doing right now. 
+ */ + +import { + DEFAULT_TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS, + DEFAULT_VM_LOCATION, + DEFAULT_VM_SIZE, +} from '@simple-agent-manager/shared'; +import { drizzle } from 'drizzle-orm/d1'; + +import * as schema from '../../db/schema'; +import { log } from '../../lib/logger'; +import { ulid } from '../../lib/ulid'; +import { signCallbackToken } from '../../services/jwt'; +import { getRuntimeLimits } from '../../services/limits'; +import { generateMcpToken, storeMcpToken } from '../../services/mcp-token'; +import { + createAgentSessionOnNode, + createWorkspaceOnNode, + startAgentSessionOnNode, +} from '../../services/node-agent'; +import { createNodeRecord, provisionNode } from '../../services/nodes'; +import * as projectDataService from '../../services/project-data'; +import { DISCOVERY_PROMPT } from '../../services/trial/discovery-prompt'; +import { startDiscoveryAgent } from '../../services/trial/trial-runner'; +import { readTrial, writeTrial } from '../../services/trial/trial-store'; +import { resolveUniqueWorkspaceDisplayName } from '../../services/workspace-names'; +import { verifyNodeAgentHealthy } from '../task-runner/node-steps'; +import { + parseEnvInt, + resolveAnonymousInstallationId, + resolveAnonymousUserId, + safeEmitTrialEvent, +} from './helpers'; +import type { + TrialOrchestratorContext, + TrialOrchestratorState, +} from './types'; + +/** Default trial workspace profile when TRIAL_DEFAULT_WORKSPACE_PROFILE is unset. */ +const DEFAULT_TRIAL_WORKSPACE_PROFILE = 'lightweight'; +/** Fallback branch when the GitHub default-branch probe fails or times out. */ +const TRIAL_FALLBACK_BRANCH = 'main'; + +// --------------------------------------------------------------------------- +// Helpers — trial KV upkeep +// --------------------------------------------------------------------------- + +/** + * Persist project/workspace ids back into the KV trial record. 
SSE `/events` + * and `/claim` both read from KV, so the orchestrator MUST mirror its progress + * there. Failures are logged but non-fatal — KV writes are best-effort. + */ +async function syncTrialRecord( + rc: TrialOrchestratorContext, + state: TrialOrchestratorState, + patch: { projectId?: string; workspaceId?: string | null } +): Promise { + try { + const record = await readTrial(rc.env, state.trialId); + if (!record) return; + const updated = { + ...record, + projectId: patch.projectId ?? record.projectId, + workspaceId: patch.workspaceId !== undefined ? patch.workspaceId : record.workspaceId, + }; + await writeTrial(rc.env, updated); + } catch (err) { + log.warn('trial_orchestrator.trial_record_sync_failed', { + trialId: state.trialId, + error: err instanceof Error ? err.message : String(err), + }); + } +} + +/** + * Probe the GitHub public API for the repo's default branch. Returns + * `TRIAL_FALLBACK_BRANCH` ('main') on any failure — timeout, 404, malformed + * JSON, etc. — so master-default repos work out of the box but a transient + * GitHub outage never breaks trial provisioning. + * + * Configurable via `TRIAL_GITHUB_TIMEOUT_MS` (default + * `DEFAULT_TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS`, 5s). 
+ */ +async function fetchDefaultBranch( + owner: string, + repo: string, + env: TrialOrchestratorContext['env'], +): Promise { + const timeoutMs = parseEnvInt( + env.TRIAL_GITHUB_TIMEOUT_MS, + DEFAULT_TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS, + ); + const controller = new AbortController(); + const timer = setTimeout(() => controller.abort(), timeoutMs); + try { + const res = await fetch(`https://api.github.com/repos/${owner}/${repo}`, { + headers: { + 'user-agent': 'sam-trial-onboarding', + accept: 'application/vnd.github+json', + }, + signal: controller.signal, + }); + if (!res.ok) { + log.warn('trial_orchestrator.default_branch_probe_http_error', { + owner, + repo, + status: res.status, + }); + return TRIAL_FALLBACK_BRANCH; + } + const body = (await res.json()) as { default_branch?: unknown }; + if (typeof body.default_branch === 'string' && body.default_branch.length > 0) { + return body.default_branch; + } + return TRIAL_FALLBACK_BRANCH; + } catch (err) { + log.warn('trial_orchestrator.default_branch_probe_failed', { + owner, + repo, + error: err instanceof Error ? err.message : String(err), + }); + return TRIAL_FALLBACK_BRANCH; + } finally { + clearTimeout(timer); + } +} + +/** + * Build the initial prompt for the discovery agent — the canonical discovery + * prompt prefixed with the repo slug. Trials have no task row in D1, so we + * deliberately do NOT ask the agent to call `get_instructions` (which requires + * a task row and would 500). The discovery prompt itself names the tools it + * needs (`add_knowledge`, `create_idea`). 
+ */ +function buildDiscoveryInitialPrompt(repoOwner: string, repoName: string): string { + const header = `Repository under exploration: \`${repoOwner}/${repoName}\`\n\n`; + return `${header}${DISCOVERY_PROMPT}`; +} + +// --------------------------------------------------------------------------- +// Step: project_creation +// --------------------------------------------------------------------------- + +export async function handleProjectCreation( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + await safeEmitTrialEvent(rc.env, state.trialId, { + type: 'trial.progress', + stage: 'creating_project', + progress: 0.1, + at: Date.now(), + }); + + // Idempotency — a retry after partial progress should pick up the existing row. + if (state.projectId) { + const existing = await rc.env.DATABASE.prepare( + `SELECT id FROM projects WHERE id = ?` + ).bind(state.projectId).first<{ id: string }>(); + if (existing) { + await rc.advanceToStep(state, 'node_selection'); + return; + } + // projectId recorded but row missing — fall through and recreate. + log.warn('trial_orchestrator.project_row_missing', { + trialId: state.trialId, + projectId: state.projectId, + }); + state.projectId = null; + } + + const projectId = ulid(); + const userId = resolveAnonymousUserId(rc.env); + const installationId = resolveAnonymousInstallationId(rc.env); + const now = new Date().toISOString(); + const db = drizzle(rc.env.DATABASE, { schema }); + + // Probe GitHub for the real default branch — falls back to 'main' on any + // failure so master-default repos (e.g. octocat/Hello-World) clone correctly + // without breaking main-default repos when GitHub is unreachable. + const defaultBranch = state.defaultBranch + ?? 
(await fetchDefaultBranch(state.repoOwner, state.repoName, rc.env)); + // Persist the resolved branch BEFORE the D1 insert so a crash between here + // and the project persist (line ~222) does not cause a retry to re-probe + // GitHub and potentially resolve a different value than what is about to + // land in the D1 projects row. + if (state.defaultBranch !== defaultBranch) { + state.defaultBranch = defaultBranch; + await rc.ctx.storage.put('state', state); + } + + // Normalize the repo name for uniqueness. Two trials on the same repo are + // expected — scope uniqueness by trialId so collisions are impossible. + const rawName = `${state.repoOwner}/${state.repoName}`; + const normalizedName = `trial-${state.trialId.toLowerCase()}`; + + await db.insert(schema.projects).values({ + id: projectId, + userId, + name: rawName, + normalizedName, + installationId, + repository: `${state.repoOwner}/${state.repoName}`, + defaultBranch, + description: 'Anonymous trial — repository exploration', + createdBy: userId, + createdAt: now, + updatedAt: now, + }); + + state.projectId = projectId; + await rc.ctx.storage.put('state', state); + + // Mirror into KV so /claim and /events can resolve the trial before any + // further steps complete. 
+ await syncTrialRecord(rc, state, { projectId }); + + log.info('trial_orchestrator.step.project_created', { + trialId: state.trialId, + projectId, + repo: rawName, + }); + + await rc.advanceToStep(state, 'node_selection'); +} + +// --------------------------------------------------------------------------- +// Step: node_selection +// --------------------------------------------------------------------------- + +export async function handleNodeSelection( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + await safeEmitTrialEvent(rc.env, state.trialId, { + type: 'trial.progress', + stage: 'finding_node', + progress: 0.2, + at: Date.now(), + }); + + // Idempotency: if we already picked a node, verify it's still healthy then + // skip ahead. If a preferred-node survived but isn't healthy anymore we + // fall back to provisioning (trials never pin a specific node). + if (state.nodeId) { + if (await verifyNodeAgentHealthy(state.nodeId, rc as unknown as import('../task-runner/types').TaskRunnerContext)) { + await rc.advanceToStep(state, 'workspace_creation'); + return; + } + log.warn('trial_orchestrator.recorded_node_unhealthy', { + trialId: state.trialId, + nodeId: state.nodeId, + }); + state.nodeId = null; + await rc.ctx.storage.put('state', state); + } + + // Find any running node owned by the sentinel user with capacity. The + // sentinel user never has user-level Hetzner credentials, so in practice + // trial fleets always auto-provision on platform credentials. We still try + // reuse first because warm/multi-trial scenarios benefit from it. + const userId = resolveAnonymousUserId(rc.env); + const existing = await rc.env.DATABASE.prepare( + `SELECT id FROM nodes + WHERE user_id = ? 
AND status = 'running' AND health_status = 'healthy' + LIMIT 1` + ).bind(userId).first<{ id: string }>(); + + if (existing?.id) { + if (await verifyNodeAgentHealthy(existing.id, rc as unknown as import('../task-runner/types').TaskRunnerContext)) { + state.nodeId = existing.id; + await rc.ctx.storage.put('state', state); + await rc.advanceToStep(state, 'workspace_creation'); + return; + } + } + + await rc.advanceToStep(state, 'node_provisioning'); +} + +// --------------------------------------------------------------------------- +// Step: node_provisioning +// --------------------------------------------------------------------------- + +export async function handleNodeProvisioning( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + await safeEmitTrialEvent(rc.env, state.trialId, { + type: 'trial.progress', + stage: 'provisioning_node', + progress: 0.3, + at: Date.now(), + }); + + // If we already created the node (retry), poll its status. + if (state.nodeId) { + const node = await rc.env.DATABASE.prepare( + `SELECT status, error_message FROM nodes WHERE id = ?` + ).bind(state.nodeId).first<{ status: string; error_message: string | null }>(); + + if (node?.status === 'running') { + await rc.advanceToStep(state, 'node_agent_ready'); + return; + } + if (node?.status === 'error' || node?.status === 'stopped') { + throw Object.assign( + new Error(node.error_message || 'Trial node provisioning failed'), + { permanent: true }, + ); + } + // Still creating — retry via backoff (caller's alarm loop handles the delay). + throw new Error('Node still provisioning — will retry'); + } + + const userId = resolveAnonymousUserId(rc.env); + const limits = getRuntimeLimits(rc.env); + + const vmSize = (rc.env.TRIAL_VM_SIZE as never) ?? DEFAULT_VM_SIZE; + const vmLocation = rc.env.TRIAL_VM_LOCATION ?? 
DEFAULT_VM_LOCATION; + + const createdNode = await createNodeRecord(rc.env, { + userId, + name: `trial-${state.trialId.slice(-8)}`, + vmSize, + vmLocation, + heartbeatStaleAfterSeconds: limits.nodeHeartbeatStaleSeconds, + }); + + state.nodeId = createdNode.id; + state.autoProvisionedNode = true; + await rc.ctx.storage.put('state', state); + + log.info('trial_orchestrator.step.node_provisioning_started', { + trialId: state.trialId, + nodeId: createdNode.id, + vmSize, + vmLocation, + }); + + // Kick provisioning. We omit taskContext — trial agents don't need the + // VM message reporter wiring that TaskRunner uses for chat persistence, + // because startDiscoveryAgent creates its own chat/ACP sessions later. + // + // Note: provisionNode() may return with the node in 'creating' status for + // async-IP providers (Scaleway) — the VM boots and gets its IP via the + // first heartbeat backfill. Do NOT synchronously require status='running' + // here; the next step (node_agent_ready) polls heartbeat freshness, which + // correctly waits for the VM to boot + heartbeat regardless of provider. 
+ await provisionNode(createdNode.id, rc.env); + + await rc.advanceToStep(state, 'node_agent_ready'); +} + +// --------------------------------------------------------------------------- +// Step: node_agent_ready +// --------------------------------------------------------------------------- + +export async function handleNodeAgentReady( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + if (!state.nodeId) { + throw Object.assign( + new Error('node_agent_ready entered without nodeId'), + { permanent: true }, + ); + } + + if (!state.nodeAgentReadyStartedAt) { + state.nodeAgentReadyStartedAt = Date.now(); + await rc.ctx.storage.put('state', state); + } + + const timeoutMs = rc.getNodeReadyTimeoutMs(); + const elapsed = Date.now() - state.nodeAgentReadyStartedAt; + if (elapsed > timeoutMs) { + throw Object.assign( + new Error(`Trial node agent not ready within ${timeoutMs}ms`), + { permanent: true }, + ); + } + + const node = await rc.env.DATABASE.prepare( + `SELECT health_status, last_heartbeat_at FROM nodes WHERE id = ?` + ).bind(state.nodeId).first<{ health_status: string | null; last_heartbeat_at: string | null }>(); + + if (node?.health_status === 'healthy' && node.last_heartbeat_at) { + const heartbeatTime = new Date(node.last_heartbeat_at).getTime(); + // Require a heartbeat after we started waiting — otherwise a stale warm + // node's heartbeat from a previous boot would satisfy this check. + const skewMs = rc.getHeartbeatSkewMs(); + if (heartbeatTime > state.nodeAgentReadyStartedAt - skewMs) { + await rc.advanceToStep(state, 'workspace_creation'); + return; + } + } + + // Not ready — requeue. We use the DO's existing alarm loop: throwing a + // transient error causes the orchestrator to schedule a backoff retry. 
+ throw new Error('Trial node agent not yet ready — will retry'); +} + +// --------------------------------------------------------------------------- +// Step: workspace_creation +// --------------------------------------------------------------------------- + +export async function handleWorkspaceCreation( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + await safeEmitTrialEvent(rc.env, state.trialId, { + type: 'trial.progress', + stage: 'creating_workspace', + progress: 0.5, + at: Date.now(), + }); + + if (!state.projectId || !state.nodeId) { + throw Object.assign( + new Error('workspace_creation requires projectId and nodeId'), + { permanent: true }, + ); + } + + // Idempotency — if workspace row already exists, just move on. + if (state.workspaceId) { + const existing = await rc.env.DATABASE.prepare( + `SELECT id FROM workspaces WHERE id = ?` + ).bind(state.workspaceId).first<{ id: string }>(); + if (existing) { + await rc.advanceToStep(state, 'workspace_ready'); + return; + } + // Row missing — fall through and recreate. + state.workspaceId = null; + } + + const userId = resolveAnonymousUserId(rc.env); + const installationId = resolveAnonymousInstallationId(rc.env); + const workspaceId = ulid(); + const repository = `${state.repoOwner}/${state.repoName}`; + const displayName = `trial-${state.repoName}`.slice(0, 60); + const db = drizzle(rc.env.DATABASE, { schema }); + const unique = await resolveUniqueWorkspaceDisplayName(db, state.nodeId, displayName); + const now = new Date().toISOString(); + + const profile = rc.env.TRIAL_DEFAULT_WORKSPACE_PROFILE ?? DEFAULT_TRIAL_WORKSPACE_PROFILE; + const vmSize = (rc.env.TRIAL_VM_SIZE as never) ?? DEFAULT_VM_SIZE; + const vmLocation = rc.env.TRIAL_VM_LOCATION ?? DEFAULT_VM_LOCATION; + const branch = state.defaultBranch ?? 
TRIAL_FALLBACK_BRANCH; + + await db.insert(schema.workspaces).values({ + id: workspaceId, + nodeId: state.nodeId, + projectId: state.projectId, + userId, + installationId, + name: displayName, + displayName: unique.displayName, + normalizedDisplayName: unique.normalizedDisplayName, + repository, + branch, + status: 'creating', + vmSize, + vmLocation, + workspaceProfile: profile, + createdAt: now, + updatedAt: now, + }); + + state.workspaceId = workspaceId; + await rc.ctx.storage.put('state', state); + + await syncTrialRecord(rc, state, { workspaceId }); + + const callbackToken = await signCallbackToken(workspaceId, rc.env); + await createWorkspaceOnNode(state.nodeId, rc.env, userId, { + workspaceId, + repository, + branch, + callbackToken, + lightweight: profile === 'lightweight', + }); + + log.info('trial_orchestrator.step.workspace_creating', { + trialId: state.trialId, + projectId: state.projectId, + workspaceId, + nodeId: state.nodeId, + }); + + await rc.advanceToStep(state, 'workspace_ready'); +} + +// --------------------------------------------------------------------------- +// Step: workspace_ready +// --------------------------------------------------------------------------- + +export async function handleWorkspaceReady( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + await safeEmitTrialEvent(rc.env, state.trialId, { + type: 'trial.progress', + stage: 'starting_agent', + progress: 0.7, + at: Date.now(), + }); + + if (!state.workspaceId) { + throw Object.assign( + new Error('workspace_ready without workspaceId'), + { permanent: true }, + ); + } + + if (!state.workspaceReadyStartedAt) { + state.workspaceReadyStartedAt = Date.now(); + await rc.ctx.storage.put('state', state); + } + + const ws = await rc.env.DATABASE.prepare( + `SELECT status, error_message FROM workspaces WHERE id = ?` + ).bind(state.workspaceId).first<{ status: string; error_message: string | null }>(); + + if (ws?.status === 'running' || ws?.status === 
'recovery') { + await rc.advanceToStep(state, 'discovery_agent_start'); + return; + } + if (ws?.status === 'error') { + throw Object.assign( + new Error(ws.error_message || 'Trial workspace creation failed'), + { permanent: true }, + ); + } + + const timeoutMs = rc.getWorkspaceReadyTimeoutMs(); + const elapsed = Date.now() - state.workspaceReadyStartedAt; + if (elapsed > timeoutMs) { + throw Object.assign( + new Error(`Trial workspace did not become ready within ${timeoutMs}ms`), + { permanent: true }, + ); + } + + const pollIntervalMs = rc.getWorkspaceReadyPollIntervalMs(); + const nextPollMs = Math.min(pollIntervalMs, Math.max(timeoutMs - elapsed, 1_000)); + await rc.ctx.storage.setAlarm(Date.now() + nextPollMs); +} + +// --------------------------------------------------------------------------- +// Step: discovery_agent_start +// --------------------------------------------------------------------------- + +export async function handleDiscoveryAgentStart( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + await safeEmitTrialEvent(rc.env, state.trialId, { + type: 'trial.progress', + stage: 'agent_booting', + progress: 0.9, + at: Date.now(), + }); + + if (!state.projectId || !state.workspaceId || !state.nodeId) { + throw Object.assign( + new Error('discovery_agent_start requires projectId, workspaceId, and nodeId'), + { permanent: true }, + ); + } + + const userId = resolveAnonymousUserId(rc.env); + const projectId = state.projectId; + const workspaceId = state.workspaceId; + const nodeId = state.nodeId; + + // Step 1: create the chat + ACP session rows in ProjectData DO (idempotent + // via state.chatSessionId / state.acpSessionId flags). 
+ let chatSessionId = state.chatSessionId; + let acpSessionId = state.acpSessionId; + let agentType: string | null = null; + if (!chatSessionId || !acpSessionId) { + const res = await startDiscoveryAgent(rc.env, { + projectId, + workspaceId, + sessionTopic: `${state.repoOwner}/${state.repoName}`, + }); + chatSessionId = res.chatSessionId; + acpSessionId = res.acpSessionId; + agentType = res.agentType; + state.chatSessionId = chatSessionId; + state.acpSessionId = acpSessionId; + await rc.ctx.storage.put('state', state); + + // Link session → workspace (mirrors TaskRunner's ensureSessionLinked). + try { + await rc.env.DATABASE.prepare( + `UPDATE workspaces SET chat_session_id = ?, updated_at = ? WHERE id = ?` + ).bind(chatSessionId, new Date().toISOString(), workspaceId).run(); + await projectDataService.linkSessionToWorkspace( + rc.env, + projectId, + chatSessionId, + workspaceId, + ); + } catch (err) { + log.warn('trial_orchestrator.session_link_failed', { + trialId: state.trialId, + projectId, + chatSessionId, + error: err instanceof Error ? err.message : String(err), + }); + } + + log.info('trial_orchestrator.step.discovery_agent_session_created', { + trialId: state.trialId, + projectId, + chatSessionId, + acpSessionId, + agentType: res.agentType, + model: res.model, + provider: res.provider, + }); + } + + // If we re-entered after crash with both IDs but no agentType in scope, + // re-resolve so startAgentSessionOnNode below has what it needs. + if (!agentType) { + const { resolveTrialRunnerConfig } = await import('../../services/trial/trial-runner'); + agentType = resolveTrialRunnerConfig(rc.env).agentType; + } + + if (!acpSessionId || !chatSessionId) { + throw Object.assign( + new Error('discovery_agent_start lost session ids after creation'), + { permanent: true }, + ); + } + const resolvedAcpSessionId: string = acpSessionId; + const resolvedChatSessionId: string = chatSessionId; + + // Step 2: register the agent-session record on the VM agent (idempotent). 
+ if (!state.agentSessionCreatedOnVm) { + await createAgentSessionOnNode( + nodeId, + workspaceId, + resolvedAcpSessionId, + `Trial: ${state.repoOwner}/${state.repoName}`.slice(0, 60), + rc.env, + userId, + resolvedChatSessionId, + projectId, + ); + state.agentSessionCreatedOnVm = true; + await rc.ctx.storage.put('state', state); + log.info('trial_orchestrator.step.agent_session_registered', { + trialId: state.trialId, + workspaceId, + acpSessionId, + nodeId, + }); + } + + // Step 3: mint + store the MCP token so the VM agent can call MCP tools + // (add_knowledge, create_idea) on behalf of this trial. Trials have no task + // row in D1, so we pass the trialId as the synthetic taskId — this keeps + // rate-limit keying (per-trial) and audit logging working without needing + // a schema change to McpTokenData. + if (!state.mcpToken) { + const token = generateMcpToken(); + await storeMcpToken( + rc.env.KV, + token, + { + taskId: state.trialId, + projectId, + userId, + workspaceId, + createdAt: new Date().toISOString(), + }, + rc.env, + ); + state.mcpToken = token; + await rc.ctx.storage.put('state', state); + log.info('trial_orchestrator.step.mcp_token_created', { + trialId: state.trialId, + projectId, + }); + } + + // Step 4: tell the VM agent to actually launch the subprocess with the + // discovery prompt as the initial turn. This is the missing piece that made + // previous trials hang at 90% — without this call the ACP session stayed + // in 'pending' forever. 
+ if (!state.agentStartedOnVm) { + if (!state.mcpToken) { + throw new Error('discovery_agent_start: mcpToken missing after step 3'); + } + const initialPrompt = buildDiscoveryInitialPrompt(state.repoOwner, state.repoName); + const mcpServerUrl = `https://api.${rc.env.BASE_DOMAIN}/mcp`; + await startAgentSessionOnNode( + nodeId, + workspaceId, + resolvedAcpSessionId, + agentType, + initialPrompt, + rc.env, + userId, + { url: mcpServerUrl, token: state.mcpToken }, + ); + state.agentStartedOnVm = true; + await rc.ctx.storage.put('state', state); + log.info('trial_orchestrator.step.agent_subprocess_started', { + trialId: state.trialId, + workspaceId, + acpSessionId, + agentType, + }); + } + + // Step 5: drive the ACP session state machine pending → assigned → running. + // The trial flow has no UI claim and no VM-agent callback that moves the + // session forward on its own, so the orchestrator owns these transitions. + // The `running` transition is what the ACP bridge listens for to fire + // `trial.ready` (see apps/api/src/services/trial/bridge.ts). + if (!state.acpAssignedOnVm) { + try { + await projectDataService.transitionAcpSession( + rc.env, + projectId, + resolvedAcpSessionId, + 'assigned', + { + actorType: 'system', + reason: 'trial_orchestrator.agent_subprocess_started', + workspaceId, + nodeId, + }, + ); + state.acpAssignedOnVm = true; + await rc.ctx.storage.put('state', state); + } catch (err) { + log.warn('trial_orchestrator.acp_assign_failed', { + trialId: state.trialId, + acpSessionId, + error: err instanceof Error ? 
err.message : String(err), + }); + throw err; + } + } + + if (!state.acpRunningOnVm) { + try { + await projectDataService.transitionAcpSession( + rc.env, + projectId, + resolvedAcpSessionId, + 'running', + { + actorType: 'system', + reason: 'trial_orchestrator.agent_subprocess_running', + workspaceId, + nodeId, + }, + ); + state.acpRunningOnVm = true; + await rc.ctx.storage.put('state', state); + } catch (err) { + log.warn('trial_orchestrator.acp_running_failed', { + trialId: state.trialId, + acpSessionId, + error: err instanceof Error ? err.message : String(err), + }); + throw err; + } + } + + await rc.advanceToStep(state, 'running'); +} + +// --------------------------------------------------------------------------- +// Step: running (terminal for orchestrator) +// --------------------------------------------------------------------------- + +/** + * Terminal-for-orchestrator. The ACP bridge (wired into ProjectData DO's + * `transitionAcpSession`) is what actually emits `trial.ready` once the + * discovery agent produces its first assistant turn. This handler simply + * marks the DO state as completed so no further alarms fire. + * + * Note on mcpToken lifecycle: we intentionally DO NOT revoke `state.mcpToken` + * here. The orchestrator terminating at `running` means provisioning is done, + * but the discovery agent is still active inside the workspace for up to + * TRIAL_WORKSPACE_TTL_MS (default 20 min) and continues to need the token for + * `add_knowledge` / `create_idea` MCP calls. Revocation on failure happens in + * `TrialOrchestrator.failTrial()`; on success, the token is bounded by its KV + * TTL (DEFAULT_MCP_TOKEN_TTL_SECONDS = 4h) and by the workspace teardown at + * the end of the trial window. A future follow-up could shorten the token TTL + * to match the trial window for defence-in-depth. 
+ */ +export async function handleRunning( + state: TrialOrchestratorState, + rc: TrialOrchestratorContext +): Promise { + state.completed = true; + await rc.ctx.storage.put('state', state); + log.info('trial_orchestrator.state.running', { + trialId: state.trialId, + projectId: state.projectId, + workspaceId: state.workspaceId, + }); +} + +// Re-export ulid for the DO class (keeps imports in index.ts tidy). +export { ulid }; diff --git a/apps/api/src/durable-objects/trial-orchestrator/types.ts b/apps/api/src/durable-objects/trial-orchestrator/types.ts new file mode 100644 index 000000000..12db0f0ab --- /dev/null +++ b/apps/api/src/durable-objects/trial-orchestrator/types.ts @@ -0,0 +1,106 @@ +/** + * Types for the TrialOrchestrator Durable Object. + * + * One DO instance per trialId, alarm-driven, mirrors TaskRunner's pattern. + * See `apps/api/src/durable-objects/trial-orchestrator/index.ts` for lifecycle. + */ +import type { Env } from '../../env'; + +/** + * State-machine steps for the trial provisioning flow. + * + * project_creation → node_selection → [node_provisioning → node_agent_ready] + * → workspace_creation → workspace_ready → discovery_agent_start + * → running (terminal — waits for ACP bridge to emit trial.ready) + * + * Terminal error steps: `failed`. + */ +export type TrialOrchestratorStep = + | 'project_creation' + | 'node_selection' + | 'node_provisioning' + | 'node_agent_ready' + | 'workspace_creation' + | 'workspace_ready' + | 'discovery_agent_start' + | 'running' + | 'succeeded' + | 'failed'; + +export interface TrialOrchestratorState { + version: 1; + trialId: string; + /** Public GitHub repo URL (canonical form — see parseGithubRepoUrl). */ + repoUrl: string; + repoOwner: string; + repoName: string; + currentStep: TrialOrchestratorStep; + // Resolved step outputs — persisted so retries are idempotent. 
+ projectId: string | null; + nodeId: string | null; + autoProvisionedNode: boolean; + workspaceId: string | null; + chatSessionId: string | null; + acpSessionId: string | null; + /** + * Repo default branch detected via the GitHub public API during project_creation. + * Falls back to 'main' when the probe fails or times out. Threaded into both the + * `projects.default_branch` row and the workspace-side `git clone --branch` call. + */ + defaultBranch: string | null; + /** + * MCP token generated for this trial's discovery agent so it can call MCP tools + * (`add_knowledge`, `create_idea`, etc.) via the api Worker. Stored in KV with + * a TTL by `storeMcpToken`. + */ + mcpToken: string | null; + /** Idempotency flag — VM agent's agent-session record has been created. */ + agentSessionCreatedOnVm: boolean; + /** Idempotency flag — VM agent was told to start the agent subprocess. */ + agentStartedOnVm: boolean; + /** Idempotency flag — ACP session has been transitioned pending → assigned. */ + acpAssignedOnVm: boolean; + /** Idempotency flag — ACP session has been transitioned assigned → running. */ + acpRunningOnVm: boolean; + // Timing / retry bookkeeping. + retryCount: number; + createdAt: number; + lastStepAt: number; + /** Filled when we start waiting for the VM-agent heartbeat. */ + nodeAgentReadyStartedAt: number | null; + /** Filled when we start waiting for workspace status=running in D1. */ + workspaceReadyStartedAt: number | null; + /** Terminal — no more alarms will fire. */ + completed: boolean; + /** Failure reason (if currentStep = 'failed'). Emitted in trial.error. */ + failureReason: string | null; +} + +export interface StartTrialInput { + trialId: string; + repoUrl: string; + repoOwner: string; + repoName: string; +} + +/** + * Context object passed to extracted step handler functions — same pattern as + * TaskRunner. Lets step handlers stay as plain functions without importing the + * DO class (which depends on `cloudflare:workers`). 
+ */ +export interface TrialOrchestratorContext { + env: Env; + ctx: DurableObjectState; + /** Advance to a new step: reset retries, persist, schedule alarm for immediate dispatch. */ + advanceToStep: (state: TrialOrchestratorState, nextStep: TrialOrchestratorStep) => Promise; + // Tunable knobs (all backed by env vars with DEFAULT_* fallbacks). + getOverallTimeoutMs: () => number; + getRetryBaseDelayMs: () => number; + getRetryMaxDelayMs: () => number; + getMaxRetries: () => number; + getWorkspaceReadyTimeoutMs: () => number; + getWorkspaceReadyPollIntervalMs: () => number; + getNodeReadyTimeoutMs: () => number; + getAgentReadyTimeoutMs: () => number; + getHeartbeatSkewMs: () => number; +} diff --git a/apps/api/src/env.ts b/apps/api/src/env.ts index 7834310e9..32ce01390 100644 --- a/apps/api/src/env.ts +++ b/apps/api/src/env.ts @@ -19,6 +19,9 @@ export interface Env { TASK_RUNNER: DurableObjectNamespace; NOTIFICATION: DurableObjectNamespace; CODEX_REFRESH_LOCK: DurableObjectNamespace; + TRIAL_COUNTER: DurableObjectNamespace; + TRIAL_EVENT_BUS: DurableObjectNamespace; + TRIAL_ORCHESTRATOR: DurableObjectNamespace; // Environment variables BASE_DOMAIN: string; VERSION: string; @@ -62,6 +65,7 @@ export interface Env { RATE_LIMIT_TERMINAL_TOKEN?: string; RATE_LIMIT_CREDENTIAL_UPDATE?: string; RATE_LIMIT_ANONYMOUS?: string; + RATE_LIMIT_TRIAL_CREATE?: string; RATE_LIMIT_IDENTITY_TOKEN?: string; RATE_LIMIT_IDENTITY_TOKEN_WINDOW_SECONDS?: string; /** @@ -431,9 +435,9 @@ export interface Env { TRIGGER_EXECUTION_LOG_RETENTION_DAYS?: string; // Days to retain completed/failed/skipped execution logs (default: 90) TRIGGER_EXECUTION_CLEANUP_ENABLED?: string; // Kill switch: "false" to disable cleanup sweep (default: enabled) TRIGGER_STALE_RECOVERY_BATCH_SIZE?: string; // Max stale executions to recover per sweep (default: 100) - // AI Inference Proxy (Cloudflare AI Gateway for trial users) + // AI Inference Proxy (Cloudflare AI Gateway — Workers AI + Anthropic) 
AI_PROXY_ENABLED?: string; // Kill switch: "false" to disable (default: enabled) - AI_PROXY_DEFAULT_MODEL?: string; // Default Workers AI model (default: @cf/meta/llama-4-scout-17b-16e-instruct) + AI_PROXY_DEFAULT_MODEL?: string; // Default model (default: claude-haiku-4-5-20251001) AI_PROXY_ALLOWED_MODELS?: string; // Comma-separated allowed models AI_PROXY_DAILY_INPUT_TOKEN_LIMIT?: string; // Per-user daily input token cap (default: 500000) AI_PROXY_DAILY_OUTPUT_TOKEN_LIMIT?: string; // Per-user daily output token cap (default: 200000) @@ -442,4 +446,53 @@ export interface Env { AI_PROXY_STREAM_TIMEOUT_MS?: string; // Max streaming duration in ms (default: 120000) AI_PROXY_RATE_LIMIT_WINDOW_SECONDS?: string; // Rate limit window in seconds (default: 60) AI_GATEWAY_ID?: string; // Cloudflare AI Gateway ID (default: sam) + AI_USAGE_PAGE_SIZE?: string; // AI Gateway logs page size for admin usage aggregation (default: 50; Cloudflare caps per_page at 50) + AI_USAGE_MAX_PAGES?: string; // Max pages to iterate for AI usage aggregation (default: 20) + // Trial Onboarding (zero-friction URL-to-workspace) + TRIAL_CLAIM_TOKEN_SECRET?: string; // Secret: HMAC key for sam_trial_claim / sam_trial_fingerprint cookies + TRIAL_MONTHLY_CAP?: string; // Global cap per calendar month (default: 1500) + TRIAL_WORKSPACE_TTL_MS?: string; // Trial workspace lifetime in ms (default: 1200000 = 20 min) + TRIAL_DATA_RETENTION_HOURS?: string; // Hours to retain trial project data post-expiry (default: 168 = 7d) + TRIAL_ANONYMOUS_USER_ID?: string; // Sentinel user id (default: system_anonymous_trials) + TRIAL_AGENT_TYPE_STAGING?: string; // Agent used for trials in staging (default: opencode) + TRIAL_AGENT_TYPE_PRODUCTION?: string; // Agent used for trials in production (default: claude-code) + TRIAL_DEFAULT_WORKSPACE_PROFILE?: string; // Workspace profile (default: lightweight) + TRIALS_ENABLED_KV_KEY?: string; // KV key read by kill-switch (default: trials:enabled) + TRIAL_KILL_SWITCH_CACHE_MS?: string; // Kill-switch
cache TTL in ms (default: 30000) + TRIAL_REPO_MAX_KB?: string; // Max GitHub repo size in KB (default: 512000 = 500 MB) + TRIAL_GITHUB_TIMEOUT_MS?: string; // Timeout for GitHub repo metadata probe (default: 5000) + TRIAL_COUNTER_KEEP_MONTHS?: string; // Months of counter rows to retain in DO (default: 3) + TRIAL_WAITLIST_PURGE_DAYS?: string; // Days after reset_date before notified waitlist rows are purged (default: 30) + TRIAL_CRON_ROLLOVER_CRON?: string; // Cron expression used by the monthly rollover audit (default: 0 5 1 * *) + TRIAL_CRON_WAITLIST_CLEANUP?: string; // Cron expression used by the daily waitlist cleanup (default: 0 4 * * *) + TRIAL_SSE_HEARTBEAT_MS?: string; // SSE comment heartbeat cadence (default: 15000) + TRIAL_SSE_POLL_TIMEOUT_MS?: string; // Long-poll timeout per DO fetch (default: 15000) + TRIAL_SSE_MAX_DURATION_MS?: string; // Hard cap on a single SSE connection (default: 1800000 = 30 min) + /** Deployment mode — "staging" | "production". Chooses trial agent + model. */ + ENVIRONMENT?: string; + /** Anthropic API key scoped for trial runs (production mode only). */ + ANTHROPIC_API_KEY_TRIAL?: string; + /** Override for default trial model (production mode default: claude-sonnet-4-6). */ + TRIAL_MODEL?: string; + /** Override for default trial LLM provider ("anthropic" | "workers-ai"). 
*/ + TRIAL_LLM_PROVIDER?: string; + // TrialOrchestrator DO (alarm-driven trial provisioning) + TRIAL_ORCHESTRATOR_OVERALL_TIMEOUT_MS?: string; + TRIAL_ORCHESTRATOR_STEP_MAX_RETRIES?: string; + TRIAL_ORCHESTRATOR_RETRY_BASE_DELAY_MS?: string; + TRIAL_ORCHESTRATOR_RETRY_MAX_DELAY_MS?: string; + TRIAL_ORCHESTRATOR_WORKSPACE_READY_TIMEOUT_MS?: string; + TRIAL_ORCHESTRATOR_WORKSPACE_READY_POLL_INTERVAL_MS?: string; + TRIAL_ORCHESTRATOR_AGENT_READY_TIMEOUT_MS?: string; + TRIAL_ORCHESTRATOR_NODE_READY_TIMEOUT_MS?: string; + TRIAL_ORCHESTRATOR_HEARTBEAT_SKEW_MS?: string; + // Fast-path GitHub knowledge probes (fired from POST /api/trial/create) + TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS?: string; + TRIAL_KNOWLEDGE_MAX_EVENTS?: string; + /** Sentinel GitHub installation id used for anonymous trial projects. */ + TRIAL_ANONYMOUS_INSTALLATION_ID?: string; + /** Trial VM size override (default: DEFAULT_VM_SIZE from shared). */ + TRIAL_VM_SIZE?: string; + /** Trial VM location override (default: DEFAULT_VM_LOCATION from shared). 
*/ + TRIAL_VM_LOCATION?: string; } diff --git a/apps/api/src/index.ts b/apps/api/src/index.ts index fbcd6577f..92343213b 100644 --- a/apps/api/src/index.ts +++ b/apps/api/src/index.ts @@ -5,6 +5,9 @@ export { NodeLifecycle } from './durable-objects/node-lifecycle'; export { NotificationService } from './durable-objects/notification'; export { ProjectData } from './durable-objects/project-data'; export { TaskRunner } from './durable-objects/task-runner'; +export { TrialCounter } from './durable-objects/trial-counter'; +export { TrialEventBus } from './durable-objects/trial-event-bus'; +export { TrialOrchestrator } from './durable-objects/trial-orchestrator'; export type { Env } from './env'; import { eq } from 'drizzle-orm'; @@ -22,6 +25,8 @@ import { AppError } from './middleware/error'; import { accountMapRoutes } from './routes/account-map'; import { activityRoutes } from './routes/activity'; import { adminRoutes } from './routes/admin'; +import { adminAIProxyRoutes } from './routes/admin-ai-proxy'; +import { adminAiUsageRoutes } from './routes/admin-ai-usage'; import { adminAnalyticsRoutes } from './routes/admin-analytics'; import { adminPlatformCredentialRoutes } from './routes/admin-platform-credentials'; import { adminQuotaRoutes } from './routes/admin-quotas'; @@ -58,6 +63,7 @@ import { tasksRoutes } from './routes/tasks'; import { terminalRoutes } from './routes/terminal'; import { transcribeRoutes } from './routes/transcribe'; import { trialRoutes } from './routes/trial'; +import { trialOnboardingRoutes } from './routes/trial/index'; import { triggersRoutes } from './routes/triggers'; import { ttsRoutes } from './routes/tts'; import { uiGovernanceRoutes } from './routes/ui-governance'; @@ -69,6 +75,9 @@ import { runCronTriggerSweep } from './scheduled/cron-triggers'; import { runNodeCleanupSweep } from './scheduled/node-cleanup'; import { runObservabilityPurge } from './scheduled/observability-purge'; import { recoverStuckTasks } from 
'./scheduled/stuck-tasks'; +import { runTrialExpireSweep } from './scheduled/trial-expire'; +import { runTrialRolloverAudit } from './scheduled/trial-rollover'; +import { runTrialWaitlistCleanup } from './scheduled/trial-waitlist-cleanup'; import { runTriggerExecutionCleanup } from './scheduled/trigger-execution-cleanup'; import { GcpApiError, sanitizeGcpError } from './services/gcp-errors'; import { signTerminalToken } from './services/jwt'; @@ -384,7 +393,9 @@ app.route('/api/projects/:projectId/knowledge', knowledgeRoutes); app.route('/api/projects', projectDeploymentRoutes); app.route('/api/deployment', gcpDeployCallbackRoute); app.route('/api/admin', adminRoutes); +app.route('/api/admin/ai-proxy', adminAIProxyRoutes); app.route('/api/admin/analytics', adminAnalyticsRoutes); +app.route('/api/admin/analytics/ai-usage', adminAiUsageRoutes); app.route('/api/admin/platform-credentials', adminPlatformCredentialRoutes); app.route('/api/admin/quotas', adminQuotaRoutes); app.route('/api/admin/usage', adminUsageRoutes); @@ -393,6 +404,7 @@ app.route('/api/account-map', accountMapRoutes); app.route('/api/dashboard', dashboardRoutes); app.route('/api/notifications', notificationRoutes); app.route('/api', trialRoutes); +app.route('/api/trial', trialOnboardingRoutes); app.route('/api/gcp', gcpRoutes); app.route('/ai/v1', aiProxyRoutes); app.route('/auth/google', googleAuthRoutes); @@ -429,20 +441,35 @@ export default { /** * Scheduled (cron) handler for background tasks. 
- * Two cron schedules: - * - Every 5 minutes: operational cleanup (provisioning, nodes, tasks, observability) + * Cron schedules: + * - Every 5 minutes: operational cleanup (provisioning, nodes, tasks, observability, trial expiry) * - Daily at 03:00 UTC: analytics event forwarding to external platforms + * - Daily at 04:00 UTC (configurable via TRIAL_CRON_WAITLIST_CLEANUP): trial waitlist purge + * - Monthly at 05:00 UTC on the 1st (configurable via TRIAL_CRON_ROLLOVER_CRON): trial counter rollover audit */ async scheduled( controller: ScheduledController, env: Env, ctx: ExecutionContext ): Promise<void> { + const rolloverCron = env.TRIAL_CRON_ROLLOVER_CRON ?? '0 5 1 * *'; + const waitlistCleanupCron = env.TRIAL_CRON_WAITLIST_CLEANUP ?? '0 4 * * *'; + const isDailyForward = controller.cron === '0 3 * * *'; + const isTrialRollover = controller.cron === rolloverCron; + const isTrialWaitlistCleanup = controller.cron === waitlistCleanupCron; + + const cronType = isDailyForward + ? 'daily-forward' + : isTrialRollover + ? 'trial-rollover' + : isTrialWaitlistCleanup + ? 'trial-waitlist-cleanup' + : 'sweep'; log.info('cron.started', { cron: controller.cron, - type: isDailyForward ? 'daily-forward' : 'sweep', + type: cronType, }); // Daily analytics forwarding (Phase 4) — use ctx.waitUntil to keep the @@ -463,6 +490,33 @@ export default { return; } + // Monthly trial counter rollover audit (prune old DO counter rows, verify month-key drift). + if (isTrialRollover) { + ctx.waitUntil((async () => { + const rollover = await runTrialRolloverAudit(env); + log.info('cron.completed', { + cron: controller.cron, + type: 'trial-rollover', + trialRolloverMonthKey: rollover.monthKey, + trialRolloverPruned: rollover.pruned, + }); + })()); + return; + } + + // Daily trial waitlist cleanup (purge notified-and-aged rows).
+ if (isTrialWaitlistCleanup) { + ctx.waitUntil((async () => { + const waitlist = await runTrialWaitlistCleanup(env); + log.info('cron.completed', { + cron: controller.cron, + type: 'trial-waitlist-cleanup', + trialWaitlistPurged: waitlist.purged, + }); + })()); + return; + } + // 5-minute operational sweep // Check for stuck provisioning workspaces const timedOut = await checkProvisioningTimeouts(env.DATABASE, env, env.OBSERVABILITY_DATABASE); @@ -489,6 +543,10 @@ export default { // Close orphaned compute_usage records const computeUsageClosed = await runComputeUsageCleanup(env); + // Expire stale pending/ready trial rows (cap slot is NOT refunded — it was + // consumed for the month). + const trialExpire = await runTrialExpireSweep(env); + log.info('cron.completed', { cron: controller.cron, type: 'sweep', @@ -517,6 +575,7 @@ export default { triggerExecRetentionPurged: triggerCleanup.retentionPurged, triggerExecCleanupErrors: triggerCleanup.errors, computeUsageOrphansClosed: computeUsageClosed, + trialExpired: trialExpire.expired, }); }, }; diff --git a/apps/api/src/middleware/rate-limit.ts b/apps/api/src/middleware/rate-limit.ts index 0c289e561..34f6ce1de 100644 --- a/apps/api/src/middleware/rate-limit.ts +++ b/apps/api/src/middleware/rate-limit.ts @@ -37,6 +37,11 @@ export const DEFAULT_RATE_LIMITS = { CLIENT_ERRORS: 200, IDENTITY_TOKEN: 60, ANALYTICS_INGEST: 60, + // Tighter limit for anonymous trial creation: each call spawns a DO, + // fires ~4 GitHub API calls, and consumes a monthly trial slot. + TRIAL_CREATE: 10, + // SSE events endpoint — short window to prevent connection storms. + TRIAL_SSE: 30, } as const; /** Default time window (1 hour in seconds) */ @@ -237,3 +242,17 @@ export function rateLimitAnonymous(env: Env): MiddlewareHandler<{ Bindings: Env }); } +/** + * Per-IP rate limit for `POST /api/trial/create`. 
Tighter than the general + * anonymous bucket because each trial create allocates a Durable Object, + * makes ~4 outbound GitHub API calls, and consumes a monthly trial slot. + * Default: 10 requests per hour per IP. + */ +export function rateLimitTrialCreate(env: Env): MiddlewareHandler<{ Bindings: Env }> { + return rateLimit({ + limit: getRateLimit(env, 'TRIAL_CREATE'), + keyPrefix: 'trial-create', + useIp: true, + }); +} + diff --git a/apps/api/src/routes/admin-ai-proxy.ts b/apps/api/src/routes/admin-ai-proxy.ts new file mode 100644 index 000000000..84019b236 --- /dev/null +++ b/apps/api/src/routes/admin-ai-proxy.ts @@ -0,0 +1,145 @@ +/** + * Admin AI Proxy configuration routes. + * + * GET /api/admin/ai-proxy/config — read current config (default model, available models) + * PUT /api/admin/ai-proxy/config — update default model selection + * + * Config is stored in KV so admins can change the default model without redeploying. + */ +import { + AI_PROXY_DEFAULT_MODEL_KV_KEY, + type AIProxyConfig, + DEFAULT_AI_PROXY_MODEL, + PLATFORM_AI_MODELS, +} from '@simple-agent-manager/shared'; +import { drizzle } from 'drizzle-orm/d1'; +import { Hono } from 'hono'; + +import * as schema from '../db/schema'; +import type { Env } from '../env'; +import { log } from '../lib/logger'; +import { getCredentialEncryptionKey } from '../lib/secrets'; +import { requireApproved, requireAuth, requireSuperadmin } from '../middleware/auth'; +import { errors } from '../middleware/error'; +import { getPlatformAgentCredential } from '../services/platform-credentials'; + +const adminAIProxyRoutes = new Hono<{ Bindings: Env }>(); + +adminAIProxyRoutes.use('/*', requireAuth(), requireApproved(), requireSuperadmin()); + +/** + * Check whether an Anthropic platform credential exists and is enabled. + * Anthropic models require this to be present before they can be selected. 
+ */ +async function hasAnthropicCredential(env: Env): Promise<boolean> { + try { + const db = drizzle(env.DATABASE, { schema }); + const encryptionKey = getCredentialEncryptionKey(env); + const cred = await getPlatformAgentCredential(db, 'claude-code', encryptionKey); + return !!cred?.credential; + } catch { + return false; + } +} + +/** + * GET /api/admin/ai-proxy/config + * + * Returns the current AI proxy configuration including: + * - The active default model (KV override > env var > shared constant) + * - Available models with provider info and availability status + * - Whether an Anthropic credential is configured + */ +adminAIProxyRoutes.get('/config', async (c) => { + const kvConfig = await c.env.KV.get(AI_PROXY_DEFAULT_MODEL_KV_KEY); + const parsed: AIProxyConfig | null = kvConfig ? JSON.parse(kvConfig) : null; + + const hasAnthropic = await hasAnthropicCredential(c.env); + + // Effective default: KV override > env var > shared constant + const effectiveDefault = parsed?.defaultModel + ?? c.env.AI_PROXY_DEFAULT_MODEL + ?? DEFAULT_AI_PROXY_MODEL; + + const models = PLATFORM_AI_MODELS.map((m) => ({ + ...m, + available: m.provider === 'workers-ai' || hasAnthropic, + })); + + return c.json({ + defaultModel: effectiveDefault, + source: parsed ? 'admin' as const : (c.env.AI_PROXY_DEFAULT_MODEL ? 'env' as const : 'default' as const), + updatedAt: parsed?.updatedAt ?? null, + hasAnthropicCredential: hasAnthropic, + models, + }); +}); + +/** + * PUT /api/admin/ai-proxy/config + * + * Update the default model.
Validates that: + * - The model ID is in the PLATFORM_AI_MODELS list + * - If an Anthropic model is selected, a platform credential exists + */ +adminAIProxyRoutes.put('/config', async (c) => { + const body = await c.req.json<{ defaultModel: string }>(); + + if (!body.defaultModel || typeof body.defaultModel !== 'string') { + throw errors.badRequest('defaultModel is required'); + } + + const model = PLATFORM_AI_MODELS.find((m) => m.id === body.defaultModel); + if (!model) { + throw errors.badRequest(`Unknown model: ${body.defaultModel}. Available: ${PLATFORM_AI_MODELS.map((m) => m.id).join(', ')}`); + } + + // Anthropic models require a platform credential + if (model.provider === 'anthropic') { + const hasAnthropic = await hasAnthropicCredential(c.env); + if (!hasAnthropic) { + throw errors.badRequest( + 'Cannot select an Anthropic model without a Claude Code platform credential. ' + + 'Add one on the Credentials tab first.', + ); + } + } + + const config: AIProxyConfig = { + defaultModel: body.defaultModel, + updatedAt: new Date().toISOString(), + }; + + await c.env.KV.put(AI_PROXY_DEFAULT_MODEL_KV_KEY, JSON.stringify(config)); + + log.info('admin.ai_proxy.config_updated', { + defaultModel: body.defaultModel, + provider: model.provider, + }); + + return c.json({ + defaultModel: config.defaultModel, + source: 'admin' as const, + updatedAt: config.updatedAt, + }); +}); + +/** + * DELETE /api/admin/ai-proxy/config + * + * Reset to platform default (removes KV override). + */ +adminAIProxyRoutes.delete('/config', async (c) => { + await c.env.KV.delete(AI_PROXY_DEFAULT_MODEL_KV_KEY); + + log.info('admin.ai_proxy.config_reset', {}); + + const effectiveDefault = c.env.AI_PROXY_DEFAULT_MODEL ?? DEFAULT_AI_PROXY_MODEL; + return c.json({ + defaultModel: effectiveDefault, + source: c.env.AI_PROXY_DEFAULT_MODEL ? 
'env' as const : 'default' as const, + updatedAt: null, + }); +}); + +export { adminAIProxyRoutes }; diff --git a/apps/api/src/routes/admin-ai-usage.ts b/apps/api/src/routes/admin-ai-usage.ts new file mode 100644 index 000000000..98fd52562 --- /dev/null +++ b/apps/api/src/routes/admin-ai-usage.ts @@ -0,0 +1,275 @@ +/** + * Admin AI Usage analytics — queries Cloudflare AI Gateway Logs API for + * token usage, cost, and request metrics. + * + * Mounts at /api/admin/analytics/ai-usage (registered in index.ts). + */ +import { Hono } from 'hono'; + +import type { Env } from '../env'; +import { log } from '../lib/logger'; +import { requireApproved, requireAuth, requireSuperadmin } from '../middleware/auth'; +import { errors } from '../middleware/error'; + +/** Default number of log entries to fetch per page. Override via AI_USAGE_PAGE_SIZE env var. CF max is 50. */ +const DEFAULT_PAGE_SIZE = 50; +/** Maximum pages to iterate when aggregating. Override via AI_USAGE_MAX_PAGES env var. */ +const DEFAULT_MAX_PAGES = 20; + +const VALID_PERIODS = ['24h', '7d', '30d', '90d'] as const; +type Period = (typeof VALID_PERIODS)[number]; + +function parsePeriod(raw: string | undefined): Period { + return (VALID_PERIODS as readonly string[]).includes(raw ?? '') ? (raw as Period) : '7d'; +} + +/** Convert period to ISO date string for start_date filter. 
*/ +function periodToStartDate(period: Period): string { + const now = new Date(); + switch (period) { + case '24h': + now.setDate(now.getDate() - 1); + break; + case '7d': + now.setDate(now.getDate() - 7); + break; + case '30d': + now.setDate(now.getDate() - 30); + break; + case '90d': + now.setDate(now.getDate() - 90); + break; + } + return now.toISOString(); +} + +interface AIGatewayLogEntry { + id: string; + model: string; + provider: string; + tokens_in: number; + tokens_out: number; + cost: number; + success: boolean; + cached: boolean; + created_at: string; + duration: number; + metadata: Record | null; +} + +interface AIGatewayLogsResponse { + result: AIGatewayLogEntry[]; + result_info: { + page: number; + per_page: number; + count: number; + total_count: number; + total_pages: number; + }; + success: boolean; + errors: unknown[]; +} + +/** Fetch a single page of AI Gateway logs. */ +async function fetchGatewayLogs( + env: Env, + gatewayId: string, + params: URLSearchParams, +): Promise { + const accountId = env.CF_ACCOUNT_ID; + if (!accountId) throw errors.internal('CF_ACCOUNT_ID is not configured'); + + const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai-gateway/gateways/${gatewayId}/logs?${params.toString()}`; + + const resp = await fetch(url, { + headers: { + Authorization: `Bearer ${env.CF_API_TOKEN}`, + 'Content-Type': 'application/json', + }, + }); + + if (!resp.ok) { + const body = await resp.text(); + log.error('admin_ai_usage.gateway_api_error', { + status: resp.status, + body: body.slice(0, 500), + url: url.replace(/Bearer\s+\S+/, 'Bearer [REDACTED]'), + }); + throw errors.internal(`AI Gateway API error (${resp.status}). 
Check admin logs for details.`); + } + + return resp.json() as Promise; +} + +export interface AiUsageByModel { + model: string; + provider: string; + requests: number; + inputTokens: number; + outputTokens: number; + totalTokens: number; + costUsd: number; + cachedRequests: number; + errorRequests: number; +} + +export interface AiUsageByDay { + date: string; + requests: number; + inputTokens: number; + outputTokens: number; + costUsd: number; +} + +export interface AiUsageSummary { + totalRequests: number; + totalInputTokens: number; + totalOutputTokens: number; + totalCostUsd: number; + trialRequests: number; + trialCostUsd: number; + cachedRequests: number; + errorRequests: number; + byModel: AiUsageByModel[]; + byDay: AiUsageByDay[]; + period: string; +} + +const adminAiUsageRoutes = new Hono<{ Bindings: Env }>(); + +adminAiUsageRoutes.use('/*', requireAuth(), requireApproved(), requireSuperadmin()); + +/** + * GET /ai-usage — Aggregated AI usage from AI Gateway logs. + * Query params: ?period=7d (24h|7d|30d|90d) + */ +adminAiUsageRoutes.get('/', async (c) => { + const period = parsePeriod(c.req.query('period')); + const gatewayId = c.env.AI_GATEWAY_ID; + if (!gatewayId) { + // AI Gateway is optional — self-hosters who don't use it get an empty summary + // instead of a 500. The admin dashboard renders "no data" gracefully. 
+ return c.json({ + totalRequests: 0, + totalInputTokens: 0, + totalOutputTokens: 0, + totalCostUsd: 0, + trialRequests: 0, + trialCostUsd: 0, + cachedRequests: 0, + errorRequests: 0, + byModel: [], + byDay: [], + period, + } satisfies AiUsageSummary); + } + const pageSize = parseInt(c.env.AI_USAGE_PAGE_SIZE || '', 10) || DEFAULT_PAGE_SIZE; + const maxPages = parseInt(c.env.AI_USAGE_MAX_PAGES || '', 10) || DEFAULT_MAX_PAGES; + const startDate = periodToStartDate(period); + + // Aggregate across all pages + const modelMap = new Map(); + const dayMap = new Map(); + let totalRequests = 0; + let totalInputTokens = 0; + let totalOutputTokens = 0; + let totalCostUsd = 0; + let trialRequests = 0; + let trialCostUsd = 0; + let cachedRequests = 0; + let errorRequests = 0; + + for (let page = 1; page <= maxPages; page++) { + const params = new URLSearchParams({ + page: page.toString(), + per_page: pageSize.toString(), + start_date: startDate, + order_by: 'created_at', + order_by_direction: 'desc', + }); + + const resp = await fetchGatewayLogs(c.env, gatewayId, params); + + for (const entry of resp.result) { + totalRequests++; + totalInputTokens += entry.tokens_in || 0; + totalOutputTokens += entry.tokens_out || 0; + totalCostUsd += entry.cost || 0; + + if (entry.cached) cachedRequests++; + if (!entry.success) errorRequests++; + + // Check metadata for trial flag + if (entry.metadata?.trialId) { + trialRequests++; + trialCostUsd += entry.cost || 0; + } + + // Aggregate by model + const modelKey = entry.model || 'unknown'; + const existing = modelMap.get(modelKey); + if (existing) { + existing.requests++; + existing.inputTokens += entry.tokens_in || 0; + existing.outputTokens += entry.tokens_out || 0; + existing.totalTokens += (entry.tokens_in || 0) + (entry.tokens_out || 0); + existing.costUsd += entry.cost || 0; + if (entry.cached) existing.cachedRequests++; + if (!entry.success) existing.errorRequests++; + } else { + modelMap.set(modelKey, { + model: modelKey, + provider: 
entry.provider || 'unknown', + requests: 1, + inputTokens: entry.tokens_in || 0, + outputTokens: entry.tokens_out || 0, + totalTokens: (entry.tokens_in || 0) + (entry.tokens_out || 0), + costUsd: entry.cost || 0, + cachedRequests: entry.cached ? 1 : 0, + errorRequests: entry.success ? 0 : 1, + }); + } + + // Aggregate by day + const dayKey = entry.created_at?.slice(0, 10) || 'unknown'; + const dayEntry = dayMap.get(dayKey); + if (dayEntry) { + dayEntry.requests++; + dayEntry.inputTokens += entry.tokens_in || 0; + dayEntry.outputTokens += entry.tokens_out || 0; + dayEntry.costUsd += entry.cost || 0; + } else { + dayMap.set(dayKey, { + date: dayKey, + requests: 1, + inputTokens: entry.tokens_in || 0, + outputTokens: entry.tokens_out || 0, + costUsd: entry.cost || 0, + }); + } + } + + // Stop if we've gotten all results + if (resp.result.length < pageSize || page >= resp.result_info.total_pages) { + break; + } + } + + const summary: AiUsageSummary = { + totalRequests, + totalInputTokens, + totalOutputTokens, + totalCostUsd, + trialRequests, + trialCostUsd, + cachedRequests, + errorRequests, + byModel: Array.from(modelMap.values()).sort((a, b) => b.costUsd - a.costUsd), + byDay: Array.from(dayMap.values()).sort((a, b) => a.date.localeCompare(b.date)), + period, + }; + + return c.json(summary); +}); + +export { adminAiUsageRoutes }; diff --git a/apps/api/src/routes/ai-proxy.ts b/apps/api/src/routes/ai-proxy.ts index bddb83299..27edb4c7e 100644 --- a/apps/api/src/routes/ai-proxy.ts +++ b/apps/api/src/routes/ai-proxy.ts @@ -1,10 +1,9 @@ /** - * AI inference proxy — transparent pass-through to Cloudflare AI Gateway. + * AI inference proxy — routes to Workers AI or Anthropic via Cloudflare AI Gateway. * - * The AI Gateway provides an OpenAI-compatible endpoint that natively supports - * tools, streaming, and all chat completion features. 
This proxy handles - * SAM-specific concerns (auth, rate limiting, token budgets) and forwards - * requests transparently — no format translation needed. + * For Workers AI models (@cf/*): transparent pass-through (OpenAI-compatible format). + * For Anthropic models (claude-*): translates OpenAI format → Anthropic Messages API, + * forwards through AI Gateway's /anthropic path, translates response back. * * Auth: Bearer token in Authorization header (workspace callback token). * Rate limit: per-user RPM via KV. @@ -13,6 +12,8 @@ * Mount point: app.route('/ai/v1', aiProxyRoutes) in index.ts. */ import { + AI_PROXY_DEFAULT_MODEL_KV_KEY, + type AIProxyConfig, DEFAULT_AI_PROXY_ALLOWED_MODELS, DEFAULT_AI_PROXY_MAX_INPUT_TOKENS_PER_REQUEST, DEFAULT_AI_PROXY_MODEL, @@ -26,25 +27,45 @@ import { Hono } from 'hono'; import * as schema from '../db/schema'; import type { Env } from '../env'; import { log } from '../lib/logger'; +import { getCredentialEncryptionKey } from '../lib/secrets'; import { checkRateLimit, createRateLimitKey, getCurrentWindowStart } from '../middleware/rate-limit'; +import { + createAnthropicToOpenAIStream, + translateRequestToAnthropic, + translateResponseToOpenAI, +} from '../services/ai-anthropic-translate'; import { checkTokenBudget } from '../services/ai-token-budget'; import { verifyCallbackToken } from '../services/jwt'; +import { getPlatformAgentCredential } from '../services/platform-credentials'; const aiProxyRoutes = new Hono<{ Bindings: Env }>(); +// ============================================================================= +// Model Routing +// ============================================================================= + +/** Check if a model ID is an Anthropic model (requires format translation). */ +function isAnthropicModel(modelId: string): boolean { + return modelId.startsWith('claude-'); +} + /** Parse allowed models from env or use defaults, normalizing prefixes. 
*/ function getAllowedModels(env: Env): Set { const raw = env.AI_PROXY_ALLOWED_MODELS || DEFAULT_AI_PROXY_ALLOWED_MODELS; return new Set(raw.split(',').map((m) => m.trim()).filter(Boolean).map((m) => normalizeModelId(m))); } -/** Normalize model ID: ensure @cf/ prefix for Workers AI models. */ +/** Normalize model ID: ensure @cf/ prefix for Workers AI models, leave Anthropic models as-is. */ function normalizeModelId(model: string): string { let resolved = model; // Strip workers-ai/ prefix that OpenCode may prepend if (resolved.startsWith('workers-ai/')) { resolved = resolved.slice('workers-ai/'.length); } + // Anthropic models don't get the @cf/ prefix + if (isAnthropicModel(resolved)) { + return resolved; + } // Add @cf/ prefix if missing — Workers AI requires the full @cf/ path. if (!resolved.startsWith('@cf/') && !resolved.startsWith('@hf/')) { resolved = `@cf/${resolved}`; @@ -52,29 +73,51 @@ function normalizeModelId(model: string): string { return resolved; } -/** Resolve model from request, falling back to default. */ -function resolveModelId(model: string | undefined, env: Env): string { - if (!model) return normalizeModelId(env.AI_PROXY_DEFAULT_MODEL || DEFAULT_AI_PROXY_MODEL); - return normalizeModelId(model); +/** Resolve model from request, falling back to admin KV override > env var > shared constant. */ +async function resolveModelId(model: string | undefined, env: Env): Promise { + if (model) return normalizeModelId(model); + + // Priority: KV (admin-set) > env var > shared constant + const kvConfig = await env.KV.get(AI_PROXY_DEFAULT_MODEL_KV_KEY); + if (kvConfig) { + try { + const parsed: AIProxyConfig = JSON.parse(kvConfig); + if (parsed.defaultModel) return normalizeModelId(parsed.defaultModel); + } catch { /* ignore corrupt KV data, fall through */ } + } + + return normalizeModelId(env.AI_PROXY_DEFAULT_MODEL || DEFAULT_AI_PROXY_MODEL); } -/** - * Build the upstream URL for Workers AI chat completions. 
- * - * When AI_GATEWAY_ID is set, routes through the AI Gateway for caching, - * logging, and analytics. Otherwise falls back to the Workers AI REST API. - */ -function buildUpstreamUrl(env: Env): string { +// ============================================================================= +// Upstream URL Builders +// ============================================================================= + +/** Build upstream URL for Workers AI (OpenAI-compatible). */ +function buildWorkersAIUrl(env: Env): string { const gatewayId = env.AI_GATEWAY_ID; if (gatewayId) { return `https://gateway.ai.cloudflare.com/v1/${env.CF_ACCOUNT_ID}/${gatewayId}/workers-ai/v1/chat/completions`; } - // Fallback: Workers AI OpenAI-compatible REST API (no gateway needed) return `https://api.cloudflare.com/client/v4/accounts/${env.CF_ACCOUNT_ID}/ai/v1/chat/completions`; } +/** Build upstream URL for Anthropic Messages API via AI Gateway. */ +function buildAnthropicUrl(env: Env): string { + const gatewayId = env.AI_GATEWAY_ID; + if (gatewayId) { + return `https://gateway.ai.cloudflare.com/v1/${env.CF_ACCOUNT_ID}/${gatewayId}/anthropic/v1/messages`; + } + // Fallback: direct Anthropic API (no gateway monitoring) + return 'https://api.anthropic.com/v1/messages'; +} + +// ============================================================================= +// Input Token Estimation +// ============================================================================= + /** - * Estimate input tokens from messages (rough: 1 token ≈ 4 chars). + * Estimate input tokens from messages (rough: 1 token ~ 4 chars). * Handles both string and array content formats. 
*/ function estimateInputTokens(messages: Array<{ role: string; content: unknown }>): number { @@ -90,12 +133,135 @@ function estimateInputTokens(messages: Array<{ role: string; content: unknown }> return Math.ceil(totalChars / 4); } +// ============================================================================= +// Forwarding Functions +// ============================================================================= + +/** Forward request to Workers AI (transparent OpenAI-format pass-through). */ +async function forwardToWorkersAI( + env: Env, + body: Record, + modelId: string, + aigMetadata: string, +): Promise { + const gatewayUrl = buildWorkersAIUrl(env); + const gatewayBody = { ...body, model: modelId }; + + const response = await fetch(gatewayUrl, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${env.CF_API_TOKEN}`, + 'Content-Type': 'application/json', + 'cf-aig-metadata': aigMetadata, + }, + body: JSON.stringify(gatewayBody), + }); + + if (!response.ok) { + const errorText = await response.text(); + log.error('ai_proxy.workers_ai_error', { + status: response.status, + body: errorText.slice(0, 500), + }); + return new Response(JSON.stringify({ + error: { + message: `AI inference failed (${response.status}). Please try again.`, + type: 'server_error', + }, + }), { status: response.status, headers: { 'Content-Type': 'application/json' } }); + } + + // Pass through transparently + const responseHeaders = new Headers(); + const contentType = response.headers.get('content-type'); + if (contentType) responseHeaders.set('Content-Type', contentType); + if (body.stream) { + responseHeaders.set('Cache-Control', 'no-cache'); + responseHeaders.set('Connection', 'keep-alive'); + responseHeaders.set('X-Accel-Buffering', 'no'); + } + + return new Response(response.body, { status: response.status, headers: responseHeaders }); +} + +/** Forward request to Anthropic via AI Gateway (with format translation). 
*/ +async function forwardToAnthropic( + env: Env, + body: Record, + modelId: string, + aigMetadata: string, + anthropicApiKey: string, +): Promise { + + // Translate OpenAI format → Anthropic Messages format + const anthropicRequest = translateRequestToAnthropic(body, modelId); + const gatewayUrl = buildAnthropicUrl(env); + + const response = await fetch(gatewayUrl, { + method: 'POST', + headers: { + 'x-api-key': anthropicApiKey, + 'anthropic-version': '2023-06-01', + 'Content-Type': 'application/json', + 'cf-aig-metadata': aigMetadata, + }, + body: JSON.stringify(anthropicRequest), + }); + + if (!response.ok) { + const errorText = await response.text(); + log.error('ai_proxy.anthropic_error', { + status: response.status, + body: errorText.slice(0, 500), + }); + return new Response(JSON.stringify({ + error: { + message: `AI inference failed (${response.status}). Please try again.`, + type: 'server_error', + }, + }), { status: response.status, headers: { 'Content-Type': 'application/json' } }); + } + + // Non-streaming: translate response + if (!body.stream) { + const anthropicResponse = await response.json() as Record; + const openAIResponse = translateResponseToOpenAI(anthropicResponse as never); + return new Response(JSON.stringify(openAIResponse), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }); + } + + // Streaming: pipe through format translation transform + if (!response.body) { + return new Response(JSON.stringify({ + error: { message: 'No response body from Anthropic', type: 'server_error' }, + }), { status: 502, headers: { 'Content-Type': 'application/json' } }); + } + + const transformStream = createAnthropicToOpenAIStream(modelId); + const translatedBody = response.body.pipeThrough(transformStream); + + return new Response(translatedBody, { + status: 200, + headers: { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'X-Accel-Buffering': 'no', + }, + }); +} + +// 
============================================================================= +// Main Route Handler +// ============================================================================= + /** - * POST /chat/completions — Transparent proxy to Cloudflare AI Gateway. + * POST /chat/completions — Proxy to AI Gateway (Workers AI or Anthropic). * - * Accepts the full OpenAI chat completions format (messages, tools, tool_choice, - * stream, temperature, etc.) and forwards it to the AI Gateway. The response is - * streamed back without modification. + * Accepts the full OpenAI chat completions format. For Anthropic models, + * performs format translation transparently. */ aiProxyRoutes.post('/chat/completions', async (c) => { // Kill switch @@ -124,10 +290,10 @@ aiProxyRoutes.post('/chat/completions', async (c) => { const workspaceId = tokenPayload.workspace; - // --- Resolve workspaceId → userId --- + // --- Resolve workspaceId → userId + projectId --- const db = drizzle(c.env.DATABASE, { schema }); const workspace = await db - .select({ userId: schema.workspaces.userId }) + .select({ userId: schema.workspaces.userId, projectId: schema.workspaces.projectId }) .from(schema.workspaces) .where(eq(schema.workspaces.id, workspaceId)) .get(); @@ -138,6 +304,18 @@ aiProxyRoutes.post('/chat/completions', async (c) => { } const userId = workspace.userId; + const projectId = workspace.projectId; + + // Check if this workspace belongs to a trial + let trialId: string | undefined; + if (projectId) { + const trial = await db + .select({ id: schema.trials.id }) + .from(schema.trials) + .where(eq(schema.trials.projectId, projectId)) + .get(); + trialId = trial?.id; + } // --- Rate limit: per-user RPM --- const rpmLimit = parseInt(c.env.AI_PROXY_RATE_LIMIT_RPM || '', 10) || DEFAULT_AI_PROXY_RATE_LIMIT_RPM; @@ -165,7 +343,7 @@ aiProxyRoutes.post('/chat/completions', async (c) => { ); } - // --- Parse request body (accept any valid JSON — Gateway handles validation) --- + // --- Parse request 
body --- let body: Record; try { body = await c.req.json() as Record; @@ -179,7 +357,7 @@ aiProxyRoutes.post('/chat/completions', async (c) => { } // --- Resolve and validate model --- - const modelId = resolveModelId(body.model as string | undefined, c.env); + const modelId = await resolveModelId(body.model as string | undefined, c.env); const allowedModels = getAllowedModels(c.env); if (!allowedModels.has(modelId)) { return c.json({ @@ -218,97 +396,72 @@ aiProxyRoutes.post('/chat/completions', async (c) => { }, 400); } - // --- Forward to AI Gateway (transparent pass-through) --- - // Set the resolved model in the body and forward everything else as-is. - // The Gateway handles tools, tool_choice, streaming, temperature, etc. natively. - const gatewayBody = { ...body, model: modelId }; - const gatewayUrl = buildUpstreamUrl(c.env); - - // Attach per-user metadata for AI Gateway analytics (max 5 fields). - // Enables per-user token usage tracking, cost attribution, and log filtering. + // --- Per-user metadata for AI Gateway analytics --- const aigMetadata = JSON.stringify({ userId, workspaceId, + projectId: projectId ?? undefined, + trialId: trialId ?? undefined, modelId, stream: !!body.stream, hasTools: !!body.tools, }); - log.info('ai_proxy.gateway_forward', { + const isAnthropic = isAnthropicModel(modelId); + + // For Anthropic models, resolve the API key from platform credentials (admin-managed). + // The key is stored as a platform credential for agent type 'claude-code' since + // that's the agent type that uses Anthropic API keys. + let anthropicApiKey: string | undefined; + if (isAnthropic) { + const encryptionKey = getCredentialEncryptionKey(c.env); + const platformCred = await getPlatformAgentCredential(db, 'claude-code', encryptionKey); + anthropicApiKey = platformCred?.credential; + if (!anthropicApiKey) { + return c.json({ + error: { + message: 'No Anthropic API key configured. 
An admin must add a Claude Code platform credential.', + type: 'server_error', + }, + }, 503); + } + } + + log.info('ai_proxy.forward', { userId, workspaceId, modelId, + provider: isAnthropic ? 'anthropic' : 'workers-ai', messageCount: (body.messages as unknown[]).length, hasTools: !!body.tools, stream: !!body.stream, estimatedInputTokens, - gatewayUrl, }); try { - const gatewayResponse = await fetch(gatewayUrl, { - method: 'POST', - headers: { - 'Authorization': `Bearer ${c.env.CF_API_TOKEN}`, - 'Content-Type': 'application/json', - 'cf-aig-metadata': aigMetadata, - }, - body: JSON.stringify(gatewayBody), - }); - - if (!gatewayResponse.ok) { - const errorText = await gatewayResponse.text(); - log.error('ai_proxy.gateway_error', { - userId, - workspaceId, - modelId, - status: gatewayResponse.status, - error: errorText.slice(0, 500), - cfRay: gatewayResponse.headers.get('cf-ray'), - }); - return c.json({ - error: { - message: `AI Gateway error (${gatewayResponse.status}): ${errorText.slice(0, 200)}`, - type: 'server_error', - }, - }, gatewayResponse.status as 500); - } + const response = isAnthropic + ? await forwardToAnthropic(c.env, body, modelId, aigMetadata, anthropicApiKey!) + : await forwardToWorkersAI(c.env, body, modelId, aigMetadata); - // Pass through the response transparently — including streaming SSE. - // The Gateway already returns proper OpenAI-format responses. - log.info('ai_proxy.gateway_response', { + log.info('ai_proxy.response', { userId, workspaceId, modelId, - status: gatewayResponse.status, - contentType: gatewayResponse.headers.get('content-type'), - cfRay: gatewayResponse.headers.get('cf-ray'), - aigLogId: gatewayResponse.headers.get('cf-aig-log-id'), + provider: isAnthropic ? 
'anthropic' : 'workers-ai', + status: response.status, }); - // Build response headers — preserve content-type and streaming headers from Gateway - const responseHeaders = new Headers(); - const contentType = gatewayResponse.headers.get('content-type'); - if (contentType) responseHeaders.set('Content-Type', contentType); - if (body.stream) { - responseHeaders.set('Cache-Control', 'no-cache'); - responseHeaders.set('Connection', 'keep-alive'); - responseHeaders.set('X-Accel-Buffering', 'no'); - } - - return new Response(gatewayResponse.body, { - status: gatewayResponse.status, - headers: responseHeaders, - }); + return response; } catch (err) { - log.error('ai_proxy.gateway_fetch_error', { + log.error('ai_proxy.fetch_error', { userId, workspaceId, modelId, + provider: isAnthropic ? 'anthropic' : 'workers-ai', error: err instanceof Error ? err.message : String(err), }); return c.json({ - error: { message: 'Failed to reach AI Gateway. Please try again.', type: 'server_error' }, + error: { message: 'Failed to reach upstream. Please try again.', type: 'server_error' }, }, 502); } }); @@ -324,11 +477,11 @@ aiProxyRoutes.get('/models', async (c) => { id, object: 'model' as const, created: 0, - owned_by: 'cloudflare', + owned_by: isAnthropicModel(id) ? 
'anthropic' : 'cloudflare', })); return c.json({ object: 'list', data: models }); }); -// Export resolveModelId for testing -export { aiProxyRoutes, resolveModelId }; +// Export for testing +export { aiProxyRoutes, isAnthropicModel, resolveModelId }; diff --git a/apps/api/src/routes/auth.ts b/apps/api/src/routes/auth.ts index c2e6a99d3..b83bb783a 100644 --- a/apps/api/src/routes/auth.ts +++ b/apps/api/src/routes/auth.ts @@ -4,6 +4,7 @@ import { createAuth } from '../auth'; import type { Env } from '../env'; import { log } from '../lib/logger'; import { errors } from '../middleware/error'; +import { maybeAttachTrialClaimCookie } from '../services/trial/oauth-hook'; const authRoutes = new Hono<{ Bindings: Env }>(); @@ -24,7 +25,10 @@ authRoutes.on(['GET', 'POST'], '/*', async (c) => { const body = await response.clone().text(); log.error('auth.better_auth_error', { status: response.status, body: body || '(empty body)' }); } - return response; + // Trial claim hook: if this is a successful OAuth callback and the user + // has an active trial fingerprint cookie, attach a claim cookie + redirect + // to the trial landing page. No-op when no fingerprint is present. + return await maybeAttachTrialClaimCookie(c.env, c.req.raw, response); } catch (err) { log.error('auth.better_auth_exception', { error: err instanceof Error ? err.message : String(err) }); return c.json({ error: 'AUTH_ERROR', message: 'Internal auth error' }, 500); diff --git a/apps/api/src/routes/mcp/idea-tools.ts b/apps/api/src/routes/mcp/idea-tools.ts index cbb2a173f..b1424fba0 100644 --- a/apps/api/src/routes/mcp/idea-tools.ts +++ b/apps/api/src/routes/mcp/idea-tools.ts @@ -247,6 +247,22 @@ export async function handleCreateIdea( contentLength: content?.length ?? 0, }); + // Trial bridge — surface freshly-created ideas as suggestion chips on the + // /try/:trialId page via the SSE stream. Non-trial projects no-op after a + // single KV lookup. 
+ try { + const { bridgeIdeaCreated } = await import('../../services/trial/bridge'); + await bridgeIdeaCreated( + env, + tokenData.projectId, + ideaId, + title, + (content ?? '').slice(0, 280), + ); + } catch { + // Bridge errors are logged inside the helper; never block MCP. + } + return jsonRpcSuccess(requestId, { content: [{ type: 'text', diff --git a/apps/api/src/routes/mcp/knowledge-tools.ts b/apps/api/src/routes/mcp/knowledge-tools.ts index 0c155ba27..f96bfdc15 100644 --- a/apps/api/src/routes/mcp/knowledge-tools.ts +++ b/apps/api/src/routes/mcp/knowledge-tools.ts @@ -75,6 +75,16 @@ export async function handleAddKnowledge( env, tokenData.projectId, entityId, observation, confidence, sourceType, sessionId, ); + // Trial bridge — if this project is backing an anonymous trial, fan the + // observation out as a `trial.knowledge` SSE event. Non-trial projects + // short-circuit after a single KV lookup inside the helper. + try { + const { bridgeKnowledgeAdded } = await import('../../services/trial/bridge'); + await bridgeKnowledgeAdded(env, tokenData.projectId, entityName, observation); + } catch { + // Bridge errors are already logged inside the helper; never block MCP. + } + return jsonRpcSuccess(requestId, { content: [{ type: 'text', text: JSON.stringify({ added: true, diff --git a/apps/api/src/routes/trial/claim.ts b/apps/api/src/routes/trial/claim.ts new file mode 100644 index 000000000..b81a53152 --- /dev/null +++ b/apps/api/src/routes/trial/claim.ts @@ -0,0 +1,141 @@ +/** + * POST /api/trial/claim — claim an anonymous trial project after GitHub OAuth. + * + * Flow: + * 1. Require auth (session from BetterAuth cookie). + * 2. Read + verify the HMAC-signed `sam_trial_claim` cookie. + * 3. Verify the trial record exists, is unclaimed, and is not expired. + * 4. 
Atomically re-parent `projects.user_id` from the sentinel anonymous user + * to the authenticated user — guarded by a `WHERE user_id = sentinel` + * precondition so a double-claim attempt cannot hijack a project that has + * already been taken. + * 5. Clear the claim cookie (Max-Age=0). + * 6. Mark the KV record as `claimed: true` (best-effort). + */ + +import { + TRIAL_ANONYMOUS_USER_ID, + TRIAL_COOKIE_CLAIM_NAME, + type TrialClaimResponse, +} from '@simple-agent-manager/shared'; +import { and, eq } from 'drizzle-orm'; +import { drizzle } from 'drizzle-orm/d1'; +import { Hono } from 'hono'; + +import * as schema from '../../db/schema'; +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { getUserId, requireAuth } from '../../middleware/auth'; +import { errors } from '../../middleware/error'; +import { clearClaimCookie, verifyClaimToken } from '../../services/trial/cookies'; +import { markTrialClaimed, readTrial } from '../../services/trial/trial-store'; + +const claimRoutes = new Hono<{ Bindings: Env }>(); + +claimRoutes.post('/claim', requireAuth(), async (c) => { + const userId = getUserId(c); + + const secret = c.env.TRIAL_CLAIM_TOKEN_SECRET; + if (!secret) { + log.error('trial_claim.secret_unset'); + throw errors.internal('Trial auth secret is not configured'); + } + + const cookieHeader = c.req.header('cookie') ?? 
''; + const claimCookie = parseCookie(cookieHeader, TRIAL_COOKIE_CLAIM_NAME); + if (!claimCookie) { + throw errors.badRequest('Missing trial claim cookie'); + } + + const verified = await verifyClaimToken(claimCookie, secret); + if (!verified.ok) { + log.warn('trial_claim.cookie_rejected', { reason: verified.reason, userId }); + throw errors.badRequest(`Claim cookie rejected: ${verified.reason}`); + } + + const { trialId, projectId } = verified.payload; + + // Trial record must still exist + const record = await readTrial(c.env, trialId); + if (!record) { + throw errors.notFound('Trial'); + } + if (record.projectId !== projectId) { + // Cookie-embedded projectId disagrees with KV — reject to avoid confused-deputy. + log.warn('trial_claim.projectid_mismatch', { + trialId, + cookieProjectId: projectId, + recordProjectId: record.projectId, + }); + throw errors.badRequest('Claim cookie does not match trial record'); + } + if (record.claimed) { + throw errors.conflict('Trial has already been claimed'); + } + + const db = drizzle(c.env.DATABASE, { schema }); + + // Atomic re-parent — the AND clause is the precondition that protects against + // double-claim / race conditions. If anyone else already re-parented this + // project, `update()` returns 0 rows and we surface a 409. + const anonymousUserId = c.env.TRIAL_ANONYMOUS_USER_ID ?? TRIAL_ANONYMOUS_USER_ID; + const updateResult = await db + .update(schema.projects) + .set({ userId, updatedAt: new Date().toISOString() }) + .where( + and( + eq(schema.projects.id, projectId), + eq(schema.projects.userId, anonymousUserId) + ) + ) + .run(); + + // D1 meta.changes: canonical way to detect how many rows the update hit + const meta = (updateResult as unknown as { meta?: { changes?: number } }).meta; + const changes = meta?.changes ?? 
0; + if (changes === 0) { + // Project either doesn't exist or was already claimed by a different user + log.warn('trial_claim.reparent_no_rows', { trialId, projectId, userId }); + throw errors.conflict('Trial project is no longer available to claim'); + } + + // Mark trial as claimed in KV (best-effort — the D1 write is the source of truth) + try { + await markTrialClaimed(c.env, trialId); + } catch (err) { + log.warn('trial_claim.mark_claimed_failed', { + trialId, + error: err instanceof Error ? err.message : String(err), + }); + } + + const claimedAt = Date.now(); + log.info('trial_claim.success', { trialId, projectId, userId, claimedAt }); + + // Clear the claim cookie. Domain attribute MUST match what was set when the + // cookie was issued (`.BASE_DOMAIN`), otherwise the browser treats them as + // different cookies and the original is never deleted — enabling replay. + const cookieDomain = c.env.BASE_DOMAIN ? `.${c.env.BASE_DOMAIN}` : undefined; + const response: TrialClaimResponse = { projectId, claimedAt }; + c.header('Set-Cookie', clearClaimCookie({ domain: cookieDomain })); + return c.json(response, 200); +}); + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function parseCookie(header: string, name: string): string | null { + if (!header) return null; + const parts = header.split(/;\s*/); + for (const part of parts) { + const eq = part.indexOf('='); + if (eq === -1) continue; + if (part.slice(0, eq) === name) { + return decodeURIComponent(part.slice(eq + 1)); + } + } + return null; +} + +export { claimRoutes }; diff --git a/apps/api/src/routes/trial/create.ts b/apps/api/src/routes/trial/create.ts new file mode 100644 index 000000000..249c69aeb --- /dev/null +++ b/apps/api/src/routes/trial/create.ts @@ -0,0 +1,502 @@ +/** + * POST /api/trial/create — create a new anonymous trial. + * + * Flow: + * 1. 
Validate the request body (Valibot — Wave-0 schema). + * 2. Read the kill switch; return `trials_disabled` when the KV flag is off + * or unreachable (fail-closed — see services/trial/kill-switch.ts). + * 3. Canonicalise the GitHub repo URL and probe `api.github.com/repos/...` + * to reject private, 404, or oversized (> TRIAL_REPO_MAX_KB) repositories. + * 4. Ask TrialCounter (global singleton DO) to allocate a slot for the + * current UTC month. `cap_exceeded` -> 429 with `waitlistResetsAt`. + * 5. Mint a trialId + UUID fingerprint, persist the trial row (status=pending), + * and issue the signed cookies (fingerprint 7d, claim 48h). + * 6. Return { trialId, projectId?, eventsUrl, expiresAt } — projectId is + * populated by the Track-B orchestrator and is absent until provisioned. + * + * This route is the Wave-1 Track-A hand-off point. The SSE companion + * (`GET /api/trial/events`) and the project/provisioning flow are owned by + * other tracks and consume the trial row this route creates. 
+ */ +import { + TRIAL_COOKIE_FINGERPRINT_NAME, + type TrialCreateError, + TrialCreateRequestSchema, + type TrialCreateResponse, + type TrialErrorCode, +} from '@simple-agent-manager/shared'; +import { eq } from 'drizzle-orm'; +import { drizzle } from 'drizzle-orm/d1'; +import { Hono } from 'hono'; +import * as v from 'valibot'; + +import * as schema from '../../db/schema'; +import type { TrialCounterTryIncrementResult } from '../../durable-objects/trial-counter'; +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { rateLimitTrialCreate } from '../../middleware/rate-limit'; +import { + buildClaimCookie, + buildFingerprintCookie, + DEFAULT_TRIAL_CLAIM_TTL_MS, + signClaimToken, + signFingerprint, + type TrialClaimPayload, + verifyFingerprint, +} from '../../services/trial/cookies'; +import { emitGithubKnowledgeEvents } from '../../services/trial/github-knowledge'; +import { + currentMonthKey, + getTrialCounterStub, + nextMonthResetDate, + parseGithubRepoUrl, + resolveGithubTimeoutMs, + resolveMonthlyCap, + resolveRepoMaxKb, + resolveWorkspaceTtlMs, +} from '../../services/trial/helpers'; +import { isTrialsEnabled } from '../../services/trial/kill-switch'; +import { writeTrial } from '../../services/trial/trial-store'; + +const createRoutes = new Hono<{ Bindings: Env }>(); + +// Exposed for tests (and for other tracks that want to probe GitHub the same way). +export interface GithubRepoProbe { + ok: true; + sizeKb: number; + private: boolean; +} + +export interface GithubRepoProbeError { + ok: false; + reason: Extract; +} + +/** + * Probe `https://api.github.com/repos/{owner}/{name}` to verify that the + * repository is public, exists, and fits within the trial size budget. + * + * A 404 -> repo_not_found. `private: true` -> repo_private. size > maxKb + * -> repo_too_large. Any other non-2xx (rate-limit, 5xx, network error) + * is treated as repo_not_found so the user isn't left without a signal. 
+ */ +export async function probeGithubRepo( + owner: string, + name: string, + opts: { maxKb: number; timeoutMs: number; fetchFn?: typeof fetch } +): Promise { + const fetchFn = opts.fetchFn ?? fetch; + const ac = new AbortController(); + const timer = setTimeout(() => ac.abort(), opts.timeoutMs); + try { + const resp = await fetchFn( + `https://api.github.com/repos/${owner}/${name}`, + { + // GitHub REST API requires a UA on unauthenticated requests. + headers: { + accept: 'application/vnd.github+json', + 'user-agent': 'sam-trial-onboarding', + }, + signal: ac.signal, + } + ); + if (resp.status === 404) return { ok: false, reason: 'repo_not_found' }; + if (!resp.ok) return { ok: false, reason: 'repo_not_found' }; + const body = (await resp.json()) as { + private?: boolean; + size?: number; + }; + if (body.private === true) return { ok: false, reason: 'repo_private' }; + const sizeKb = Number(body.size ?? 0); + if (sizeKb > opts.maxKb) return { ok: false, reason: 'repo_too_large' }; + return { ok: true, sizeKb, private: false }; + } catch { + return { ok: false, reason: 'repo_not_found' }; + } finally { + clearTimeout(timer); + } +} + +function errorResponse( + code: TrialErrorCode, + message: string, + status: number, + extra?: Partial +): Response { + const body: TrialCreateError = { error: code, message, ...extra }; + return new Response(JSON.stringify(body), { + status, + headers: { 'content-type': 'application/json' }, + }); +} + +function readCookie(header: string | null, name: string): string | null { + if (!header) return null; + const parts = header.split(/;\s*/); + for (const part of parts) { + const eq = part.indexOf('='); + if (eq < 0) continue; + if (part.slice(0, eq) === name) return part.slice(eq + 1); + } + return null; +} + +function randomId(prefix: string): string { + // crypto.randomUUID is available in Workers/modern Node. 
+ return `${prefix}_${crypto.randomUUID().replace(/-/g, '')}`; +} + +// Per-IP rate limit (outermost gate, before kill-switch + DO dispatch). +// Default: 10 trial creates per IP per hour (see RATE_LIMIT_TRIAL_CREATE). +createRoutes.use('/create', async (c, next) => { + const limiter = rateLimitTrialCreate(c.env); + return limiter(c, next); +}); + +createRoutes.post('/create', async (c) => { + const env = c.env; + const now = Date.now(); + + // -- 1. Validate body ------------------------------------------------------ + let body: unknown; + try { + body = await c.req.json(); + } catch { + return errorResponse('invalid_url', 'Request body must be valid JSON', 400); + } + const parsed = v.safeParse(TrialCreateRequestSchema, body); + if (!parsed.success) { + const issue = parsed.issues[0]; + return errorResponse( + 'invalid_url', + issue?.message ?? 'Must be a public GitHub repository URL', + 400 + ); + } + + const repo = parseGithubRepoUrl(parsed.output.repoUrl); + if (!repo) { + return errorResponse( + 'invalid_url', + 'Must be a public GitHub repository URL', + 400 + ); + } + + // -- 2. Kill switch -------------------------------------------------------- + const enabled = await isTrialsEnabled(env, now); + if (!enabled) { + return errorResponse( + 'trials_disabled', + 'Trial onboarding is currently disabled', + 503 + ); + } + + // Secret must be present whenever trials are enabled. + const secret = env.TRIAL_CLAIM_TOKEN_SECRET; + if (!secret) { + log.error('trial.create.missing_secret', {}); + return errorResponse( + 'trials_disabled', + 'Trial onboarding is misconfigured', + 503 + ); + } + + // -- 3. GitHub repo probe -------------------------------------------------- + const probe = await probeGithubRepo(repo.owner, repo.name, { + maxKb: resolveRepoMaxKb(env), + timeoutMs: resolveGithubTimeoutMs(env), + }); + if (!probe.ok) { + const status = + probe.reason === 'repo_not_found' + ? 404 + : probe.reason === 'repo_private' + ? 
403 + : 413; + return errorResponse( + probe.reason, + { + repo_not_found: 'Repository not found or not publicly accessible', + repo_private: 'Repository is private', + repo_too_large: 'Repository exceeds the trial size limit', + }[probe.reason], + status + ); + } + + // -- 4. Allocate a slot on the TrialCounter DO ----------------------------- + const monthKey = currentMonthKey(now); + const cap = resolveMonthlyCap(env); + const counter = getTrialCounterStub(env); + let slot: TrialCounterTryIncrementResult; + try { + // RPC path — typed method call on the DO stub. + slot = await ( + counter as unknown as { + tryIncrement( + monthKey: string, + cap: number + ): Promise; + } + ).tryIncrement(monthKey, cap); + } catch (err) { + log.error('trial.create.counter_failed', { + error: err instanceof Error ? err.message : String(err), + }); + // Fail closed if the counter DO is unreachable. + return errorResponse( + 'trials_disabled', + 'Trial slot allocation unavailable — please retry shortly', + 503 + ); + } + + if (!slot.allowed) { + return errorResponse( + 'cap_exceeded', + 'Monthly trial cap reached — join the waitlist to be notified', + 429, + { waitlistResetsAt: nextMonthResetDate(now) } + ); + } + + // -- 5. Persist the trial row --------------------------------------------- + const trialId = randomId('trial'); + const ttlMs = resolveWorkspaceTtlMs(env); + const expiresAt = now + ttlMs; + + // Determine (or mint) the visitor fingerprint. The same anonymous visitor + // reuses the same fingerprint across multiple trials within the 7d cookie + // lifetime, which makes support/abuse investigations tractable. + // + // SECURITY: the fingerprint cookie MUST be HMAC-verified before we trust its + // UUID. If we only split on the last dot and accept whatever's to the left, + // an attacker who learns a victim's fingerprint UUID (e.g. 
from logs, the + // victim's browser, or a prior trial row) can forge `.` + // and overwrite the `trial-by-fingerprint:` KV index to point at + // their own trial — the next OAuth-hook lookup then redirects the victim to + // the attacker's trial. Always call verifyFingerprint() and fall back to a + // fresh UUID on invalid/missing signature. + const existingFp = readCookie( + c.req.header('cookie') ?? null, + TRIAL_COOKIE_FINGERPRINT_NAME + ); + let fingerprintUuid: string | null = null; + if (existingFp) { + fingerprintUuid = await verifyFingerprint(existingFp, secret); + } + if (!fingerprintUuid) fingerprintUuid = crypto.randomUUID(); + + const db = drizzle(env.DATABASE, { schema }); + try { + await db.insert(schema.trials).values({ + id: trialId, + fingerprint: fingerprintUuid, + repoUrl: repo.canonical, + repoOwner: repo.owner, + repoName: repo.name, + monthKey, + status: 'pending', + projectId: null, + claimedByUserId: null, + createdAt: now, + expiresAt, + claimedAt: null, + errorCode: null, + errorMessage: null, + }); + } catch (err) { + log.error('trial.create.insert_failed', { + trialId, + error: err instanceof Error ? err.message : String(err), + }); + // Release the allocated counter slot so a failed insert doesn't burn + // a monthly slot for the user. + try { + await ( + counter as unknown as { decrement(monthKey: string): Promise } + ).decrement(monthKey); + } catch (decErr) { + log.error('trial.create.counter_decrement_failed', { + trialId, + error: decErr instanceof Error ? decErr.message : String(decErr), + }); + } + return errorResponse( + 'trials_disabled', + 'Trial creation failed — please retry shortly', + 500 + ); + } + + // Mirror the trial record into KV so Track B readers (SSE events, claim, + // orchestrator) can find it. D1 remains the durable store of record; KV is + // the fast-path lookup by trialId/fingerprint. projectId is '' here — Track B + // rewrites the record once the project row exists. 
+ try { + await writeTrial(env, { + trialId, + projectId: '', + fingerprint: fingerprintUuid, + workspaceId: null, + repoUrl: repo.canonical, + createdAt: now, + expiresAt, + claimed: false, + }); + } catch (err) { + log.error('trial.create.kv_write_failed', { + trialId, + error: err instanceof Error ? err.message : String(err), + }); + // Release the counter slot + roll back D1 so we don't leak state. + try { + await db.delete(schema.trials).where(eq(schema.trials.id, trialId)).run(); + } catch { + // best-effort rollback — D1 row is already written, and the cron sweep + // will eventually expire it via expiresAt. + } + try { + await ( + counter as unknown as { decrement(monthKey: string): Promise } + ).decrement(monthKey); + } catch { + // best-effort + } + return errorResponse( + 'trials_disabled', + 'Trial creation failed — please retry shortly', + 500 + ); + } + + // -- 6. Issue cookies + return -------------------------------------------- + const fingerprintSigned = await signFingerprint(fingerprintUuid, secret); + const claimPayload: TrialClaimPayload = { + trialId, + // Placeholder — the SSE orchestrator will re-issue a fresh claim cookie + // once the project row exists. Until then the client has a signed token + // bound to the trialId alone, which is sufficient for claim validation. + projectId: '', + issuedAt: now, + expiresAt: now + DEFAULT_TRIAL_CLAIM_TTL_MS, + }; + const claimSigned = await signClaimToken(claimPayload, secret); + + const cookieDomain = env.BASE_DOMAIN ? 
`.${env.BASE_DOMAIN}` : undefined; + const secure = true; // HTTPS-only on every Worker deployment + + const respBody: TrialCreateResponse = { + trialId, + projectId: '', // populated once Track-B provisions the project row + eventsUrl: `/api/trial/${encodeURIComponent(trialId)}/events`, + expiresAt, + }; + + const headers = new Headers({ 'content-type': 'application/json' }); + headers.append( + 'set-cookie', + buildFingerprintCookie(fingerprintSigned, { + secure, + domain: cookieDomain, + }) + ); + headers.append( + 'set-cookie', + buildClaimCookie(claimSigned, { secure, domain: cookieDomain }) + ); + + log.info('trial.create.ok', { + trialId, + monthKey, + slotCount: slot.count, + cap, + repo: repo.canonical, + // fingerprint deliberately omitted from structured logs — it's PII-adjacent + }); + + // -- 7. Fire-and-forget orchestrator dispatch + GitHub knowledge probes ---- + // Both tasks run via waitUntil so the HTTP response returns immediately + // (<100ms) while provisioning and the fast knowledge probes continue in the + // background. Errors are swallowed inside the helpers so a failure here + // never causes the response to hang. + // + // c.executionCtx may be absent in unit tests that construct Hono apps + // without a real Worker executionCtx. Hono throws on access when it's + // missing, so we guard with try/catch rather than a simple null check. 
+ let exec: ExecutionContext | undefined; + try { + exec = c.executionCtx; + } catch { + exec = undefined; + } + log.info('trial.create.dispatch_begin', { + trialId, + execCtxPresent: !!exec, + }); + const orchestratorTask = (async () => { + log.info('trial.create.orchestrator_task.enter', { trialId }); + try { + const id = env.TRIAL_ORCHESTRATOR.idFromName(trialId); + const stub = env.TRIAL_ORCHESTRATOR.get(id); + log.info('trial.create.orchestrator_task.stub_ready', { trialId }); + await ( + stub as unknown as { + start(input: { + trialId: string; + repoUrl: string; + repoOwner: string; + repoName: string; + }): Promise; + } + ).start({ + trialId, + repoUrl: repo.canonical, + repoOwner: repo.owner, + repoName: repo.name, + }); + log.info('trial.create.orchestrator_task.start_returned', { trialId }); + } catch (err) { + log.error('trial.create.orchestrator_dispatch_failed', { + trialId, + error: err instanceof Error ? err.message : String(err), + stack: err instanceof Error ? err.stack : undefined, + }); + } + })(); + const knowledgeTask = (async () => { + log.info('trial.create.knowledge_task.enter', { trialId }); + try { + await emitGithubKnowledgeEvents(env, trialId, { + owner: repo.owner, + name: repo.name, + }); + log.info('trial.create.knowledge_task.done', { trialId }); + } catch (err) { + log.error('trial.create.knowledge_task.failed', { + trialId, + error: err instanceof Error ? err.message : String(err), + stack: err instanceof Error ? err.stack : undefined, + }); + } + })(); + if (exec) { + exec.waitUntil(orchestratorTask); + exec.waitUntil(knowledgeTask); + log.info('trial.create.waitUntil_registered', { trialId }); + } else { + // Prevent unhandled-rejection warnings in tests without swallowing logs. 
+ void orchestratorTask.catch(() => {}); + void knowledgeTask.catch(() => {}); + } + + return new Response(JSON.stringify(respBody), { + status: 201, + headers, + }); +}); + +export { createRoutes }; diff --git a/apps/api/src/routes/trial/events.ts b/apps/api/src/routes/trial/events.ts new file mode 100644 index 000000000..b7dec6656 --- /dev/null +++ b/apps/api/src/routes/trial/events.ts @@ -0,0 +1,231 @@ +/** + * GET /api/trial/:trialId/events — Server-Sent Events stream for a trial. + * + * Auth model: fingerprint-cookie-based. The endpoint rejects any request whose + * `sam_trial_fingerprint` cookie does NOT cryptographically bind to the trial's + * recorded fingerprint. This fails closed — no fingerprint cookie => 401. + * + * Stream semantics: + * - `content-type: text/event-stream` + `cache-control: no-cache` + * - Comment-only heartbeats every TRIAL_SSE_HEARTBEAT_MS (default 15s) + * - Events are sourced from the per-trial TrialEventBus DO via long-poll + * - Terminates cleanly after `trial.ready` / `trial.error` is observed + * - Hard duration cap (TRIAL_SSE_MAX_DURATION_MS, default 30 min) to avoid + * runaway connections + */ + +import type { TrialEvent } from '@simple-agent-manager/shared'; +import { TRIAL_COOKIE_FINGERPRINT_NAME } from '@simple-agent-manager/shared'; +import { Hono } from 'hono'; + +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { errors } from '../../middleware/error'; +import { + checkRateLimit, + createRateLimitKey, + getCurrentWindowStart, + getRateLimit, +} from '../../middleware/rate-limit'; +import { verifyFingerprint } from '../../services/trial/cookies'; +import { readTrial } from '../../services/trial/trial-store'; + +const eventsRoutes = new Hono<{ Bindings: Env }>(); + +const DEFAULT_HEARTBEAT_MS = 15_000; +const DEFAULT_POLL_TIMEOUT_MS = 15_000; +const DEFAULT_MAX_DURATION_MS = 30 * 60 * 1000; // 30 min +const SSE_RATE_LIMIT_WINDOW_SECONDS = 300; // 5-minute window + 
+eventsRoutes.get('/:trialId/events', async (c) => { + const trialId = c.req.param('trialId'); + if (!trialId) throw errors.badRequest('trialId is required'); + + // Rate limit: per-IP to prevent SSE connection storms from a single source. + // CF-Connecting-IP is always present on Cloudflare Workers for real traffic; + // X-Forwarded-For covers reverse-proxy setups. If neither is present, reject + // rather than sharing a single "unknown" bucket across all headerless clients. + const clientIp = c.req.header('CF-Connecting-IP') + ?? c.req.header('X-Forwarded-For')?.split(',')[0]?.trim() + ?? null; + if (!clientIp) { + log.warn('trial_events.missing_client_ip', { trialId }); + return c.json({ error: 'Unable to determine client IP for rate limiting.' }, 400); + } + // KV-based rate limiting is not atomic (no CAS primitive), so concurrent + // requests from the same IP can overshoot by ~1. This is acceptable for SSE + // storm prevention — exact enforcement is not required. For strict limits + // (e.g. credential rotation), the codebase uses DO-based counters instead + // (see CodexRefreshLock). + const sseLimit = getRateLimit(c.env, 'TRIAL_SSE'); + const windowStart = getCurrentWindowStart(SSE_RATE_LIMIT_WINDOW_SECONDS); + const rateLimitKey = createRateLimitKey('trial-sse', clientIp, windowStart); + const { allowed } = await checkRateLimit(c.env.KV, rateLimitKey, sseLimit, SSE_RATE_LIMIT_WINDOW_SECONDS); + if (!allowed) { + return c.json({ error: 'Too many SSE connections. Please try again later.' }, 429); + } + + // Resolve trial record + const record = await readTrial(c.env, trialId); + if (!record) throw errors.notFound('Trial'); + + // Fingerprint auth — fail closed + const secret = c.env.TRIAL_CLAIM_TOKEN_SECRET; + if (!secret) { + log.error('trial_events.secret_unset', { trialId }); + throw errors.internal('Trial auth secret is not configured'); + } + + const cookieHeader = c.req.header('cookie') ?? 
''; + const fingerprintCookie = parseCookie(cookieHeader, TRIAL_COOKIE_FINGERPRINT_NAME); + if (!fingerprintCookie) { + throw errors.unauthorized('Missing trial fingerprint'); + } + const uuid = await verifyFingerprint(fingerprintCookie, secret); + if (!uuid || uuid !== record.fingerprint) { + throw errors.unauthorized('Trial fingerprint does not match this trial'); + } + + // Build SSE stream + const heartbeatMs = parseIntSafe(c.env.TRIAL_SSE_HEARTBEAT_MS, DEFAULT_HEARTBEAT_MS); + const pollTimeoutMs = parseIntSafe(c.env.TRIAL_SSE_POLL_TIMEOUT_MS, DEFAULT_POLL_TIMEOUT_MS); + const maxDurationMs = parseIntSafe(c.env.TRIAL_SSE_MAX_DURATION_MS, DEFAULT_MAX_DURATION_MS); + + const encoder = new TextEncoder(); + const busStub = c.env.TRIAL_EVENT_BUS.get(c.env.TRIAL_EVENT_BUS.idFromName(trialId)); + + const stream = new ReadableStream({ + async start(controller) { + let closed = false; + let cursor = 0; + const startedAt = Date.now(); + + const enqueue = (bytes: Uint8Array) => { + if (closed) return; + try { + controller.enqueue(bytes); + } catch { + closed = true; + } + }; + + // Initial comment — flushes headers and helps with proxies. + enqueue(encoder.encode(': connected\n\n')); + + // Heartbeat timer — comment frames are ignored by EventSource but keep + // the TCP connection alive across CDN buffering. + const heartbeat = setInterval(() => { + enqueue(encoder.encode(`: heartbeat ${Date.now()}\n\n`)); + }, heartbeatMs); + + try { + while (!closed && Date.now() - startedAt < maxDurationMs) { + let pollResp: Response; + try { + pollResp = await busStub.fetch( + `https://trial-event-bus/poll?cursor=${cursor}&timeoutMs=${pollTimeoutMs}`, + { method: 'GET' } + ); + } catch (err) { + log.warn('trial_events.poll_error', { + trialId, + error: err instanceof Error ? err.message : String(err), + }); + // Brief back-off so we don't spin on DO errors. 
+ await sleep(1000); + continue; + } + + if (!pollResp.ok) { + enqueue(encoder.encode(formatSse({ + type: 'trial.error', + error: 'invalid_url', + message: `Event bus poll failed: ${pollResp.status}`, + at: Date.now(), + }))); + break; + } + + const data = (await pollResp.json()) as { + events: { cursor: number; event: TrialEvent }[]; + cursor: number; + closed: boolean; + }; + + for (const { cursor: c2, event } of data.events) { + enqueue(encoder.encode(formatSse(event))); + cursor = c2; + } + + if (data.closed) break; + } + } finally { + clearInterval(heartbeat); + if (!closed) { + closed = true; + try { + controller.close(); + } catch { + // already closed + } + } + } + }, + cancel() { + // client disconnected — nothing to clean up beyond the start() finally. + }, + }); + + return new Response(stream, { + status: 200, + headers: { + 'content-type': 'text/event-stream; charset=utf-8', + 'cache-control': 'no-cache, no-transform', + connection: 'keep-alive', + 'x-accel-buffering': 'no', + }, + }); +}); + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function parseCookie(header: string, name: string): string | null { + if (!header) return null; + const parts = header.split(/;\s*/); + for (const part of parts) { + const eq = part.indexOf('='); + if (eq === -1) continue; + if (part.slice(0, eq) === name) { + return decodeURIComponent(part.slice(eq + 1)); + } + } + return null; +} + +function parseIntSafe(value: string | undefined, fallback: number): number { + if (!value) return fallback; + const n = Number.parseInt(value, 10); + return Number.isFinite(n) && n > 0 ? n : fallback; +} + +export function formatSse(data: unknown): string { + // Emit as the default ("message") SSE event so that EventSource consumers + // receive the payload via `source.onmessage`. 
Using a named `event:` field + // would require the client to register `addEventListener(, ...)` for + // every TrialEvent variant, and `onmessage` would silently never fire — + // the exact "zero events on staging" symptom that motivated this fix. + // + // The TrialEvent JSON payload itself carries a `type` discriminator, so + // dropping the `event:` line loses no information. Data is JSON-encoded + // which already escapes newlines, preventing SSE-frame injection. + const json = JSON.stringify(data); + return `data: ${json}\n\n`; +} + +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +export { eventsRoutes }; diff --git a/apps/api/src/routes/trial/index.ts b/apps/api/src/routes/trial/index.ts new file mode 100644 index 000000000..bfd60adc0 --- /dev/null +++ b/apps/api/src/routes/trial/index.ts @@ -0,0 +1,29 @@ +import { Hono } from 'hono'; + +import type { Env } from '../../env'; +import { claimRoutes } from './claim'; +import { createRoutes } from './create'; +import { eventsRoutes } from './events'; +import { statusRoutes } from './status'; +import { waitlistRoutes } from './waitlist'; + +/** + * Trial onboarding routes — mounted at `/api/trial`. + * + * Wave-0 stubs only; each handler returns 501. Wave-1 wires real behaviour. + * The five subrouters are kept in separate files so the parallel Wave-1 + * tracks can evolve them without merge conflicts. + * + * Note: per-route middleware is used inside each subrouter (not a wildcard + * middleware on this parent) — see `.claude/rules/06-api-patterns.md` on + * Hono middleware scope leakage. 
+ */ +const trialOnboardingRoutes = new Hono<{ Bindings: Env }>(); + +trialOnboardingRoutes.route('/', createRoutes); +trialOnboardingRoutes.route('/', eventsRoutes); +trialOnboardingRoutes.route('/', claimRoutes); +trialOnboardingRoutes.route('/', waitlistRoutes); +trialOnboardingRoutes.route('/', statusRoutes); + +export { trialOnboardingRoutes }; diff --git a/apps/api/src/routes/trial/status.ts b/apps/api/src/routes/trial/status.ts new file mode 100644 index 000000000..869c95fc5 --- /dev/null +++ b/apps/api/src/routes/trial/status.ts @@ -0,0 +1,67 @@ +/** + * GET /api/trial/status — public availability snapshot (no auth). + * + * Returns: + * - `enabled`: kill-switch state (fail-closed on KV error) + * - `remaining`: cap - current count for the UTC month (0 when at/over cap) + * - `resetsAt`: ISO date of the next UTC month's first day + * + * Callers: the unauthenticated landing page and the "join waitlist" CTA. + */ +import type { TrialStatusResponse } from '@simple-agent-manager/shared'; +import { Hono } from 'hono'; + +import type { TrialCounterState } from '../../durable-objects/trial-counter'; +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { + currentMonthKey, + getTrialCounterStub, + nextMonthResetDate, + resolveMonthlyCap, +} from '../../services/trial/helpers'; +import { isTrialsEnabled } from '../../services/trial/kill-switch'; + +const statusRoutes = new Hono<{ Bindings: Env }>(); + +statusRoutes.get('/status', async (c) => { + const env = c.env; + const now = Date.now(); + const enabled = await isTrialsEnabled(env, now); + const cap = resolveMonthlyCap(env); + const monthKey = currentMonthKey(now); + const resetsAt = nextMonthResetDate(now); + + let count = 0; + try { + const stub = getTrialCounterStub(env); + const state: TrialCounterState = await ( + stub as unknown as { get(monthKey: string): Promise } + ).get(monthKey); + count = state.count; + } catch (err) { + // DO unreachable -> surface enabled=false 
with zero remaining so the + // client falls through to the waitlist CTA. This mirrors the fail-closed + // posture of the kill-switch. + log.error('trial.status.counter_failed', { + error: err instanceof Error ? err.message : String(err), + }); + const resp: TrialStatusResponse = { + enabled: false, + remaining: 0, + resetsAt, + }; + return c.json(resp, 200); + } + + const remaining = cap > 0 ? Math.max(0, cap - count) : Number.MAX_SAFE_INTEGER; + + const resp: TrialStatusResponse = { + enabled, + remaining, + resetsAt, + }; + return c.json(resp, 200); +}); + +export { statusRoutes }; diff --git a/apps/api/src/routes/trial/waitlist.ts b/apps/api/src/routes/trial/waitlist.ts new file mode 100644 index 000000000..0e5f28617 --- /dev/null +++ b/apps/api/src/routes/trial/waitlist.ts @@ -0,0 +1,88 @@ +/** + * POST /api/trial/waitlist — queue for notification when the monthly cap + * resets. + * + * Public endpoint (no auth). Upserts a row into `trial_waitlist` using the + * UNIQUE(email, reset_date) index so repeated submits within the same + * reset window are idempotent (returned as `{ queued: true }` without + * creating duplicates). The monthly notifier cron flips `notified_at` + * when the window opens and an email is sent. 
+ */ +import { + TrialWaitlistRequestSchema, + type TrialWaitlistResponse, +} from '@simple-agent-manager/shared'; +import { drizzle } from 'drizzle-orm/d1'; +import { Hono } from 'hono'; +import * as v from 'valibot'; + +import * as schema from '../../db/schema'; +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { nextMonthResetDate } from '../../services/trial/helpers'; + +const waitlistRoutes = new Hono<{ Bindings: Env }>(); + +function randomId(prefix: string): string { + return `${prefix}_${crypto.randomUUID().replace(/-/g, '')}`; +} + +waitlistRoutes.post('/waitlist', async (c) => { + const env = c.env; + const now = Date.now(); + + // Parse + validate + let body: unknown; + try { + body = await c.req.json(); + } catch { + return c.json( + { error: 'BAD_REQUEST', message: 'Request body must be valid JSON' }, + 400 + ); + } + const parsed = v.safeParse(TrialWaitlistRequestSchema, body); + if (!parsed.success) { + return c.json( + { + error: 'BAD_REQUEST', + message: parsed.issues[0]?.message ?? 'Invalid email', + }, + 400 + ); + } + + // Normalise the email — UNIQUE index is case-sensitive by default; fold + // to lowercase so repeated submits with different casing dedupe cleanly. + const email = parsed.output.email.toLowerCase(); + const resetDate = nextMonthResetDate(now); + + const db = drizzle(env.DATABASE, { schema }); + try { + await db + .insert(schema.trialWaitlist) + .values({ + id: randomId('wl'), + email, + submittedAt: now, + resetDate, + notifiedAt: null, + }) + .onConflictDoNothing({ + target: [schema.trialWaitlist.email, schema.trialWaitlist.resetDate], + }); + } catch (err) { + log.error('trial.waitlist.insert_failed', { + error: err instanceof Error ? 
err.message : String(err), + }); + return c.json( + { error: 'INTERNAL_ERROR', message: 'Failed to queue waitlist entry' }, + 500 + ); + } + + const resp: TrialWaitlistResponse = { queued: true, resetsAt: resetDate }; + return c.json(resp, 200); +}); + +export { waitlistRoutes }; diff --git a/apps/api/src/routes/workspaces/runtime.ts b/apps/api/src/routes/workspaces/runtime.ts index 8bd165b69..600e0199d 100644 --- a/apps/api/src/routes/workspaces/runtime.ts +++ b/apps/api/src/routes/workspaces/runtime.ts @@ -1,4 +1,4 @@ -import { type BootstrapTokenData, DEFAULT_AI_PROXY_MODEL, getAgentDefinition, isValidAgentType } from '@simple-agent-manager/shared'; +import { AI_PROXY_DEFAULT_MODEL_KV_KEY, type AIProxyConfig, type BootstrapTokenData, DEFAULT_AI_PROXY_MODEL, getAgentDefinition, isValidAgentType } from '@simple-agent-manager/shared'; import { and, eq, isNull } from 'drizzle-orm'; import { drizzle } from 'drizzle-orm/d1'; import { Hono } from 'hono'; @@ -19,6 +19,7 @@ import { persistError } from '../../services/observability'; import { resolveProjectAgentDefault } from '../../services/project-agent-defaults'; import * as projectDataService from '../../services/project-data'; import { extractScalewaySecretKey } from '../../services/provider-credentials'; +import { bridgeAgentActivity } from '../../services/trial/bridge'; import { getDecryptedAgentKey, getDecryptedCredential } from '../credentials'; import { getWorkspaceRuntimeAssets, @@ -78,7 +79,16 @@ runtimeRoutes.post('/:id/agent-key', jsonValidator(AgentTypeBodySchema), async ( if (!credentialData && body.agentType === 'opencode' && aiProxyEnabled) { const baseDomain = c.env.BASE_DOMAIN; const proxyBaseUrl = `https://api.${baseDomain}/ai/v1`; - const defaultModel = c.env.AI_PROXY_DEFAULT_MODEL ?? DEFAULT_AI_PROXY_MODEL; + + // Resolve default model: KV (admin-set) > env var > shared constant + let defaultModel = c.env.AI_PROXY_DEFAULT_MODEL ?? 
DEFAULT_AI_PROXY_MODEL; + try { + const kvConfig = await c.env.KV.get(AI_PROXY_DEFAULT_MODEL_KV_KEY); + if (kvConfig) { + const parsed: AIProxyConfig = JSON.parse(kvConfig); + if (parsed.defaultModel) defaultModel = parsed.defaultModel; + } + } catch { /* KV unavailable or corrupt data — use env/default */ } log.info('agent_key.ai_proxy_fallback', { workspaceId, userId: workspace.userId, proxyBaseUrl }); @@ -604,6 +614,23 @@ runtimeRoutes.post('/:id/messages', jsonValidator(MessageBatchSchema), async (c) ); } + // Fire-and-forget: pipe agent activity to the trial SSE feed (if this + // workspace belongs to a trial project). Non-trial projects short-circuit + // with a single KV lookup inside the bridge. + if (workspace.projectId && result.persisted > 0) { + c.executionCtx.waitUntil( + bridgeAgentActivity( + c.env, + workspace.projectId, + body.messages.map((m) => ({ + role: m.role, + content: m.content, + toolMetadata: m.toolMetadata ? safeParseJson(m.toolMetadata) : undefined, + })), + ), + ); + } + return c.json({ persisted: result.persisted, duplicates: result.duplicates, diff --git a/apps/api/src/scheduled/trial-expire.ts b/apps/api/src/scheduled/trial-expire.ts new file mode 100644 index 000000000..174d16eea --- /dev/null +++ b/apps/api/src/scheduled/trial-expire.ts @@ -0,0 +1,59 @@ +/** + * Cron handler: expire stale pending / ready trials. + * + * Runs on the 5-minute operational sweep. Any `trials` row whose + * status is still `pending` or `ready` and whose `expires_at` is in + * the past is transitioned to `expired`. The TrialCounter DO is NOT + * decremented — the slot was genuinely consumed for the month. + * + * Returns a summary for the `cron.completed` log message. 
+ */ +import { and, inArray, lt } from 'drizzle-orm'; +import { drizzle } from 'drizzle-orm/d1'; + +import * as schema from '../db/schema'; +import type { Env } from '../env'; + +export interface TrialExpireResult { + expired: number; +} + +export async function runTrialExpireSweep( + env: Env, + now: number = Date.now() +): Promise { + const db = drizzle(env.DATABASE, { schema }); + + // D1 supports UPDATE ... RETURNING; we use an explicit SELECT/UPDATE pair + // so we can count rows without relying on provider-specific affected-row + // counts. + const candidates = await db + .select({ id: schema.trials.id }) + .from(schema.trials) + .where( + and( + inArray(schema.trials.status, ['pending', 'ready']), + lt(schema.trials.expiresAt, now) + ) + ) + .limit(1000); + + if (candidates.length === 0) return { expired: 0 }; + + await db + .update(schema.trials) + .set({ status: 'expired' }) + .where( + inArray( + schema.trials.id, + candidates.map((r) => r.id) + ) + ); + + // We deliberately DO NOT also `eq(status, ...)` here — the candidate + // IDs are scoped to the snapshot we just read, and overriding a concurrent + // status transition would be wrong. The subsequent cron invocation will + // re-evaluate any rows we skipped. + + return { expired: candidates.length }; +} diff --git a/apps/api/src/scheduled/trial-rollover.ts b/apps/api/src/scheduled/trial-rollover.ts new file mode 100644 index 000000000..0b45867ed --- /dev/null +++ b/apps/api/src/scheduled/trial-rollover.ts @@ -0,0 +1,63 @@ +/** + * Cron handler: monthly rollover audit. + * + * Runs on the `TRIAL_CRON_ROLLOVER_CRON` schedule (default: first of + * the month at 03:00 UTC). Two responsibilities: + * + * 1. Prune old counter rows from the TrialCounter DO's SQLite so the + * DO's storage footprint stays bounded. Retention is controlled by + * TRIAL_COUNTER_KEEP_MONTHS. + * 2. 
Verify that the DO's current month key matches the Worker's
+ *    current month key; log a warning otherwise (drift indicates a
+ *    timezone misconfiguration — the cap would be mis-applied).
+ *
+ * Returns a summary for the `cron.completed` log message.
+ */
+import type { Env } from '../env';
+import { log } from '../lib/logger';
+import {
+  currentMonthKey,
+  getTrialCounterStub,
+  resolveCounterKeepMonths,
+  shiftMonthKey,
+} from '../services/trial/helpers';
+
+export interface TrialRolloverResult {
+  monthKey: string;
+  pruned: number;
+}
+
+// Shape of the TrialCounter DO's RPC surface as used by this audit. The stub
+// is cast to this locally; keep it in sync with the DO implementation.
+interface CounterRpc {
+  // prune() returns the number of rows deleted — assumed from the usage
+  // below; TODO confirm against the TrialCounter DO implementation.
+  prune(keepMonthKey: string): Promise<number>;
+  get(monthKey: string): Promise<{ monthKey: string; count: number }>;
+}
+
+export async function runTrialRolloverAudit(
+  env: Env,
+  now: number = Date.now()
+): Promise<TrialRolloverResult> {
+  const monthKey = currentMonthKey(now);
+  const keepMonths = resolveCounterKeepMonths(env);
+  // Keep `keepMonths` rows: the current month plus (keepMonths - 1) prior.
+  // Prune everything strictly older than `oldestKept`.
+  const oldestKept = shiftMonthKey(monthKey, -(keepMonths - 1));
+
+  try {
+    const stub = getTrialCounterStub(env) as unknown as CounterRpc;
+    const pruned = await stub.prune(oldestKept);
+    // Sanity check — fetch the current month; no action on mismatch, just log.
+    // NOTE(review): for `state.monthKey` to differ from the key we just
+    // passed in, the DO would have to normalize/derive the key itself —
+    // confirm this check can actually fire, otherwise it is a tautology.
+    const state = await stub.get(monthKey);
+    if (state.monthKey !== monthKey) {
+      log.warn('trial.rollover.monthKey_drift', {
+        expected: monthKey,
+        actual: state.monthKey,
+      });
+    }
+    return { monthKey, pruned };
+  } catch (err) {
+    // Swallow-and-report: a failed audit must not crash the cron tick; the
+    // next scheduled run retries. `pruned: 0` signals the no-op in the log.
+    log.error('trial.rollover.failed', {
+      error: err instanceof Error ? err.message : String(err),
+    });
+    return { monthKey, pruned: 0 };
+  }
+}
diff --git a/apps/api/src/scheduled/trial-waitlist-cleanup.ts b/apps/api/src/scheduled/trial-waitlist-cleanup.ts
new file mode 100644
index 000000000..e251a066f
--- /dev/null
+++ b/apps/api/src/scheduled/trial-waitlist-cleanup.ts
@@ -0,0 +1,59 @@
+/**
+ * Cron handler: daily trial_waitlist purge.
+ * + * Runs on the `TRIAL_CRON_WAITLIST_CLEANUP` schedule (default: daily + * at 04:00 UTC). Deletes rows whose `notified_at` is older than + * `TRIAL_WAITLIST_PURGE_DAYS` (default: 30). Rows that have not yet + * been notified are preserved regardless of age — the monthly + * notifier cron will eventually flip `notified_at`. + * + * Returns a summary for the `cron.completed` log message. + */ +import { and, isNotNull, lt } from 'drizzle-orm'; +import { drizzle } from 'drizzle-orm/d1'; + +import * as schema from '../db/schema'; +import type { Env } from '../env'; +import { resolveWaitlistPurgeDays } from '../services/trial/helpers'; + +export interface TrialWaitlistCleanupResult { + purged: number; +} + +export async function runTrialWaitlistCleanup( + env: Env, + now: number = Date.now() +): Promise { + const db = drizzle(env.DATABASE, { schema }); + const purgeDays = resolveWaitlistPurgeDays(env); + const threshold = now - purgeDays * 24 * 60 * 60 * 1000; + + // Select first so we can report a precise count (D1 drivers vary on + // affected-row reporting; counting IDs is reliable). + const candidates = await db + .select({ id: schema.trialWaitlist.id }) + .from(schema.trialWaitlist) + .where( + and( + isNotNull(schema.trialWaitlist.notifiedAt), + lt(schema.trialWaitlist.notifiedAt, threshold) + ) + ) + .limit(5000); + + if (candidates.length === 0) return { purged: 0 }; + + // Batch delete. `inArray` here would re-import; use a loop with a single + // predicate on a stable boundary instead — simpler and avoids an extra + // 5000-element IN list. 
+ await db + .delete(schema.trialWaitlist) + .where( + and( + isNotNull(schema.trialWaitlist.notifiedAt), + lt(schema.trialWaitlist.notifiedAt, threshold) + ) + ); + + return { purged: candidates.length }; +} diff --git a/apps/api/src/services/ai-anthropic-translate.ts b/apps/api/src/services/ai-anthropic-translate.ts new file mode 100644 index 000000000..60ec487c8 --- /dev/null +++ b/apps/api/src/services/ai-anthropic-translate.ts @@ -0,0 +1,464 @@ +/** + * OpenAI ↔ Anthropic Messages format translation for the AI proxy. + * + * OpenCode speaks OpenAI-compatible chat format. When the resolved model is + * an Anthropic model (claude-*), the proxy must translate to/from Anthropic's + * Messages API format before forwarding through the AI Gateway's /anthropic path. + */ + +// ============================================================================= +// Types +// ============================================================================= + +interface OpenAIMessage { + role: 'system' | 'user' | 'assistant' | 'tool'; + content: string | ContentPart[] | null; + name?: string; + tool_calls?: OpenAIToolCall[]; + tool_call_id?: string; +} + +interface ContentPart { + type: 'text' | 'image_url'; + text?: string; + image_url?: { url: string }; +} + +interface OpenAIToolCall { + id: string; + type: 'function'; + function: { name: string; arguments: string }; +} + +interface OpenAITool { + type: 'function'; + function: { + name: string; + description?: string; + parameters?: Record; + }; +} + +interface AnthropicMessage { + role: 'user' | 'assistant'; + content: AnthropicContentBlock[]; +} + +type AnthropicContentBlock = + | { type: 'text'; text: string } + | { type: 'tool_use'; id: string; name: string; input: unknown } + | { type: 'tool_result'; tool_use_id: string; content: string | AnthropicContentBlock[] }; + +interface AnthropicTool { + name: string; + description?: string; + input_schema: Record; +} + +interface AnthropicRequest { + model: string; + messages: 
AnthropicMessage[]; + system?: string; + max_tokens: number; + stream?: boolean; + tools?: AnthropicTool[]; + temperature?: number; + top_p?: number; + stop_sequences?: string[]; +} + +interface AnthropicResponse { + id: string; + type: 'message'; + role: 'assistant'; + content: AnthropicContentBlock[]; + model: string; + stop_reason: string | null; + usage: { input_tokens: number; output_tokens: number }; +} + +// ============================================================================= +// Request Translation: OpenAI → Anthropic +// ============================================================================= + +/** Default max_tokens for Anthropic (required field). */ +const DEFAULT_MAX_TOKENS = 4096; + +/** + * Translate an OpenAI chat completions request body into an Anthropic Messages API request. + */ +export function translateRequestToAnthropic( + body: Record, + modelId: string, +): AnthropicRequest { + const messages = body.messages as OpenAIMessage[]; + const systemMessages: string[] = []; + const anthropicMessages: AnthropicMessage[] = []; + + for (const msg of messages) { + if (msg.role === 'system') { + // Anthropic uses a top-level `system` field, not a system message role + const text = typeof msg.content === 'string' + ? msg.content + : (msg.content as ContentPart[] | null)?.map((p) => p.text || '').join('') || ''; + if (text) systemMessages.push(text); + continue; + } + + if (msg.role === 'user') { + const content = normalizeContent(msg.content); + anthropicMessages.push({ role: 'user', content }); + continue; + } + + if (msg.role === 'assistant') { + const blocks: AnthropicContentBlock[] = []; + // Text content + const text = typeof msg.content === 'string' + ? 
msg.content + : (msg.content as ContentPart[] | null)?.map((p) => p.text || '').join('') || ''; + if (text) blocks.push({ type: 'text', text }); + // Tool calls + if (msg.tool_calls) { + for (const tc of msg.tool_calls) { + blocks.push({ + type: 'tool_use', + id: tc.id, + name: tc.function.name, + input: safeParseJson(tc.function.arguments), + }); + } + } + if (blocks.length > 0) { + anthropicMessages.push({ role: 'assistant', content: blocks }); + } + continue; + } + + if (msg.role === 'tool') { + // Tool results must be in a user message in Anthropic format + const toolResultBlock: AnthropicContentBlock = { + type: 'tool_result', + tool_use_id: msg.tool_call_id || '', + content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content), + }; + // If the last message is already a user message, append to it (merge consecutive tool results) + const lastMsg = anthropicMessages[anthropicMessages.length - 1]; + if (lastMsg && lastMsg.role === 'user') { + lastMsg.content.push(toolResultBlock); + } else { + anthropicMessages.push({ role: 'user', content: [toolResultBlock] }); + } + continue; + } + } + + // Anthropic requires alternating user/assistant messages. Merge consecutive same-role messages. 
+ const merged = mergeConsecutiveMessages(anthropicMessages); + + const request: AnthropicRequest = { + model: modelId, + messages: merged, + max_tokens: (body.max_tokens as number) || DEFAULT_MAX_TOKENS, + stream: !!body.stream, + }; + + if (systemMessages.length > 0) { + request.system = systemMessages.join('\n\n'); + } + + // Translate tools + if (body.tools && Array.isArray(body.tools)) { + request.tools = (body.tools as OpenAITool[]).map((t) => ({ + name: t.function.name, + description: t.function.description, + input_schema: t.function.parameters || { type: 'object', properties: {} }, + })); + } + + if (body.temperature != null) request.temperature = body.temperature as number; + if (body.top_p != null) request.top_p = body.top_p as number; + if (body.stop && Array.isArray(body.stop)) request.stop_sequences = body.stop as string[]; + + return request; +} + +// ============================================================================= +// Response Translation: Anthropic → OpenAI (non-streaming) +// ============================================================================= + +/** + * Translate an Anthropic Messages API response into OpenAI chat completions format. + */ +export function translateResponseToOpenAI(response: AnthropicResponse): Record { + const toolCalls: OpenAIToolCall[] = []; + const textParts: string[] = []; + + for (const block of response.content) { + if (block.type === 'text') { + textParts.push(block.text); + } else if (block.type === 'tool_use') { + toolCalls.push({ + id: block.id, + type: 'function', + function: { + name: block.name, + arguments: typeof block.input === 'string' ? 
block.input : JSON.stringify(block.input), + }, + }); + } + } + + const message: Record = { + role: 'assistant', + content: textParts.join('') || null, + }; + if (toolCalls.length > 0) { + message.tool_calls = toolCalls; + } + + return { + id: `chatcmpl-${response.id}`, + object: 'chat.completion', + created: Math.floor(Date.now() / 1000), + model: response.model, + choices: [{ + index: 0, + message, + finish_reason: mapStopReason(response.stop_reason), + }], + usage: { + prompt_tokens: response.usage.input_tokens, + completion_tokens: response.usage.output_tokens, + total_tokens: response.usage.input_tokens + response.usage.output_tokens, + }, + }; +} + +// ============================================================================= +// Streaming Translation: Anthropic SSE → OpenAI SSE +// ============================================================================= + +/** + * Create a TransformStream that converts Anthropic streaming events into + * OpenAI-compatible SSE delta format. 
+ */ +export function createAnthropicToOpenAIStream(model: string): TransformStream { + const encoder = new TextEncoder(); + const decoder = new TextDecoder(); + let buffer = ''; + let currentToolCallIndex = 0; + let messageId = ''; + + return new TransformStream({ + transform(chunk, controller) { + buffer += decoder.decode(chunk, { stream: true }); + const lines = buffer.split('\n'); + // Keep incomplete last line in buffer + buffer = lines.pop() || ''; + + for (const line of lines) { + if (!line.startsWith('data: ')) continue; + const data = line.slice(6).trim(); + if (data === '[DONE]') { + controller.enqueue(encoder.encode('data: [DONE]\n\n')); + continue; + } + + let event: Record; + try { + event = JSON.parse(data); + } catch { + continue; + } + + const openAIDelta = translateStreamEvent(event, model, messageId, currentToolCallIndex); + if (openAIDelta) { + if (openAIDelta.messageId) messageId = openAIDelta.messageId; + if (openAIDelta.toolCallIndex !== undefined) currentToolCallIndex = openAIDelta.toolCallIndex; + if (openAIDelta.chunk) { + controller.enqueue(encoder.encode(`data: ${JSON.stringify(openAIDelta.chunk)}\n\n`)); + } + } + } + }, + flush(controller) { + // Send final [DONE] if not already sent + controller.enqueue(encoder.encode('data: [DONE]\n\n')); + }, + }); +} + +// ============================================================================= +// Helpers +// ============================================================================= + +function normalizeContent(content: string | ContentPart[] | null): AnthropicContentBlock[] { + if (!content) return [{ type: 'text', text: '' }]; + if (typeof content === 'string') return [{ type: 'text', text: content }]; + return content + .filter((p) => p.type === 'text' && p.text) + .map((p) => ({ type: 'text' as const, text: p.text! 
})); +} + +function mergeConsecutiveMessages(messages: AnthropicMessage[]): AnthropicMessage[] { + const merged: AnthropicMessage[] = []; + for (const msg of messages) { + const last = merged[merged.length - 1]; + if (last && last.role === msg.role) { + last.content.push(...msg.content); + } else { + merged.push({ ...msg, content: [...msg.content] }); + } + } + return merged; +} + +function safeParseJson(str: string): unknown { + try { + return JSON.parse(str); + } catch { + return str; + } +} + +function mapStopReason(reason: string | null): string { + switch (reason) { + case 'end_turn': return 'stop'; + case 'max_tokens': return 'length'; + case 'tool_use': return 'tool_calls'; + case 'stop_sequence': return 'stop'; + default: return 'stop'; + } +} + +interface StreamDelta { + chunk?: Record; + messageId?: string; + toolCallIndex?: number; +} + +function translateStreamEvent( + event: Record, + model: string, + messageId: string, + toolCallIndex: number, +): StreamDelta | null { + const type = event.type as string; + + switch (type) { + case 'message_start': { + const msg = event.message as Record | undefined; + const id = (msg?.id as string) || `chatcmpl-${Date.now()}`; + return { + messageId: id, + chunk: { + id: `chatcmpl-${id}`, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model, + choices: [{ + index: 0, + delta: { role: 'assistant', content: '' }, + finish_reason: null, + }], + }, + }; + } + + case 'content_block_start': { + const block = event.content_block as Record | undefined; + if (block?.type === 'tool_use') { + const newIndex = toolCallIndex; + return { + toolCallIndex: newIndex + 1, + chunk: { + id: `chatcmpl-${messageId}`, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model, + choices: [{ + index: 0, + delta: { + tool_calls: [{ + index: newIndex, + id: block.id as string, + type: 'function', + function: { name: block.name as string, arguments: '' }, + }], + }, + finish_reason: 
null, + }], + }, + }; + } + return null; + } + + case 'content_block_delta': { + const delta = event.delta as Record | undefined; + if (!delta) return null; + + if (delta.type === 'text_delta') { + return { + chunk: { + id: `chatcmpl-${messageId}`, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model, + choices: [{ + index: 0, + delta: { content: delta.text as string }, + finish_reason: null, + }], + }, + }; + } + + if (delta.type === 'input_json_delta') { + return { + chunk: { + id: `chatcmpl-${messageId}`, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model, + choices: [{ + index: 0, + delta: { + tool_calls: [{ + index: toolCallIndex - 1, + function: { arguments: delta.partial_json as string }, + }], + }, + finish_reason: null, + }], + }, + }; + } + + return null; + } + + case 'message_delta': { + const delta = event.delta as Record | undefined; + const stopReason = delta?.stop_reason as string | null; + return { + chunk: { + id: `chatcmpl-${messageId}`, + object: 'chat.completion.chunk', + created: Math.floor(Date.now() / 1000), + model, + choices: [{ + index: 0, + delta: {}, + finish_reason: mapStopReason(stopReason), + }], + usage: event.usage as Record | undefined, + }, + }; + } + + default: + return null; + } +} diff --git a/apps/api/src/services/trial/bridge.ts b/apps/api/src/services/trial/bridge.ts new file mode 100644 index 000000000..d3c85961f --- /dev/null +++ b/apps/api/src/services/trial/bridge.ts @@ -0,0 +1,174 @@ +/** + * Trial event bridge — hooks that fan ACP / MCP events into `trial.*` SSE. + * + * These are fire-and-forget, catch-everything helpers designed to be dropped + * into hot paths (ProjectData DO, MCP handlers) without any risk of blocking + * or crashing the primary flow. Every call path: + * 1. Looks up the trial record for the project (KV round-trip). + * 2. If no trial exists, returns silently (this is the common case — most + * projects are NOT trials). + * 3. 
Emits the appropriate `trial.*` event via the TrialEventBus DO. + * + * The non-trial short-circuit means the overhead on normal project traffic + * is a single `env.KV.get()` — acceptable, and scopable-down later by caching + * a boolean flag on the ProjectData DO if needed. + */ +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { emitTrialEventForProject } from './trial-runner'; +import { readTrialByProject } from './trial-store'; + +/** + * Emit `trial.ready` or `trial.error` when an ACP session transitions. + * + * Ready is inferred from the first `running` transition on a trial's discovery + * session — that's when the VM agent has confirmed the agent process is + * attached and actively producing turns. `trial.error` is emitted on `failed`. + * + * Safe to call on non-trial projects; the helper silently no-ops. + */ +export async function bridgeAcpSessionTransition( + env: Env, + projectId: string, + toStatus: string, + opts: { workspaceUrl?: string | null; errorMessage?: string | null } = {}, +): Promise { + try { + const record = await readTrialByProject(env, projectId); + if (!record) return; + + if (toStatus === 'running') { + const workspaceUrl = + opts.workspaceUrl ?? + (record.workspaceId ? `https://ws-${record.workspaceId}.${env.BASE_DOMAIN}` : ''); + await emitTrialEventForProject(env, projectId, { + type: 'trial.ready', + trialId: record.trialId, + projectId: record.projectId, + workspaceUrl, + at: Date.now(), + }); + return; + } + + if (toStatus === 'failed') { + await emitTrialEventForProject(env, projectId, { + type: 'trial.error', + // TrialErrorCode is a shared enum; pass a free-form string through + // the same cast-through-unknown pattern used by the DO failure path. + error: 'acp_session_failed' as never, + message: opts.errorMessage ?? 
'Discovery agent session failed', + at: Date.now(), + }); + } + } catch (err) { + log.warn('trial_bridge.acp_transition_failed', { + projectId, + toStatus, + error: err instanceof Error ? err.message : String(err), + }); + } +} + +/** + * Emit `trial.knowledge` when the discovery agent adds a knowledge observation + * via MCP. Called from `handleAddKnowledge` in the MCP tool handler. + */ +export async function bridgeKnowledgeAdded( + env: Env, + projectId: string, + entity: string, + observation: string, +): Promise { + try { + const record = await readTrialByProject(env, projectId); + if (!record) return; + await emitTrialEventForProject(env, projectId, { + type: 'trial.knowledge', + entity, + observation, + at: Date.now(), + }); + } catch (err) { + log.warn('trial_bridge.knowledge_failed', { + projectId, + error: err instanceof Error ? err.message : String(err), + }); + } +} + +/** + * Emit `trial.agent_activity` when the discovery agent produces output. + * Called from the message persistence path for assistant/tool/thinking roles. + * + * Truncates long text to keep SSE payloads small — the feed only needs a + * summary of what the agent is doing, not the full output. + */ +export async function bridgeAgentActivity( + env: Env, + projectId: string, + messages: Array<{ + role: string; + content: string; + toolMetadata?: unknown; + }>, +): Promise { + try { + const record = await readTrialByProject(env, projectId); + if (!record) return; + + for (const msg of messages) { + // Only surface agent-facing roles, skip user messages + if (msg.role !== 'assistant' && msg.role !== 'tool' && msg.role !== 'thinking') continue; + // Skip empty content + const text = (msg.content || '').trim(); + if (!text) continue; + + const toolName = + msg.role === 'tool' && msg.toolMetadata && typeof msg.toolMetadata === 'object' + ? 
(msg.toolMetadata as Record).toolName as string | undefined + : undefined; + + await emitTrialEventForProject(env, projectId, { + type: 'trial.agent_activity', + role: msg.role as 'assistant' | 'tool' | 'thinking', + text: text.length > 200 ? text.slice(0, 200) + '…' : text, + ...(toolName ? { toolName } : {}), + at: Date.now(), + }); + } + } catch (err) { + log.warn('trial_bridge.agent_activity_failed', { + projectId, + error: err instanceof Error ? err.message : String(err), + }); + } +} + +/** + * Emit `trial.idea` when the discovery agent creates an idea via MCP. + */ +export async function bridgeIdeaCreated( + env: Env, + projectId: string, + ideaId: string, + title: string, + summary: string, +): Promise { + try { + const record = await readTrialByProject(env, projectId); + if (!record) return; + await emitTrialEventForProject(env, projectId, { + type: 'trial.idea', + ideaId, + title, + summary, + at: Date.now(), + }); + } catch (err) { + log.warn('trial_bridge.idea_failed', { + projectId, + error: err instanceof Error ? err.message : String(err), + }); + } +} diff --git a/apps/api/src/services/trial/cookies.ts b/apps/api/src/services/trial/cookies.ts new file mode 100644 index 000000000..eb5a72450 --- /dev/null +++ b/apps/api/src/services/trial/cookies.ts @@ -0,0 +1,233 @@ +/** + * Trial cookie signing / verification. + * + * Two cookies are issued to anonymous trial visitors: + * - sam_trial_fingerprint (7-day lifetime) — signed UUID; identifies the same + * anonymous visitor across subsequent page loads. + * - sam_trial_claim (48-hour lifetime) — signed JSON payload + * { trialId, projectId, issuedAt, expiresAt } — presented on the OAuth callback + * so the API can re-parent `projects.user_id` from the sentinel system user to + * the newly-authenticated GitHub user. + * + * HMAC-SHA256, base64url, constant-time compare. The secret is a Worker secret + * (`TRIAL_CLAIM_TOKEN_SECRET`) that MUST be set when trials are enabled. 
+ */ + +import { + TRIAL_COOKIE_CLAIM_NAME, + TRIAL_COOKIE_FINGERPRINT_NAME, +} from '@simple-agent-manager/shared'; + +// --------------------------------------------------------------------------- +// Defaults (Principle XI: all limits configurable, with DEFAULT_* constants) +// --------------------------------------------------------------------------- + +export const DEFAULT_TRIAL_FINGERPRINT_TTL_SEC = 60 * 60 * 24 * 7; // 7 days +export const DEFAULT_TRIAL_CLAIM_TTL_MS = 1000 * 60 * 60 * 48; // 48 hours + +// --------------------------------------------------------------------------- +// Claim payload +// --------------------------------------------------------------------------- + +export interface TrialClaimPayload { + trialId: string; + projectId: string; + /** epoch ms — when the token was issued */ + issuedAt: number; + /** epoch ms — absolute expiry */ + expiresAt: number; +} + +// --------------------------------------------------------------------------- +// HMAC helpers +// --------------------------------------------------------------------------- + +function base64urlEncode(bytes: Uint8Array): string { + // btoa requires a binary string; Uint8Array iteration is safe for bytes. + let binary = ''; + for (let i = 0; i < bytes.length; i++) binary += String.fromCharCode(bytes[i]!); + return btoa(binary).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, ''); +} + +function base64urlDecode(s: string): Uint8Array { + const pad = s.length % 4 === 0 ? 
'' : '='.repeat(4 - (s.length % 4)); + const b64 = s.replace(/-/g, '+').replace(/_/g, '/') + pad; + const binary = atob(b64); + const out = new Uint8Array(binary.length); + for (let i = 0; i < binary.length; i++) out[i] = binary.charCodeAt(i); + return out; +} + +async function hmacKey(secret: string): Promise { + return crypto.subtle.importKey( + 'raw', + new TextEncoder().encode(secret), + { name: 'HMAC', hash: 'SHA-256' }, + false, + ['sign', 'verify'] + ); +} + +async function hmacSign(secret: string, data: string): Promise { + const key = await hmacKey(secret); + const sig = await crypto.subtle.sign('HMAC', key, new TextEncoder().encode(data)); + return base64urlEncode(new Uint8Array(sig)); +} + +/** Constant-time comparison of two strings (prevents timing oracles on HMAC). */ +function timingSafeEqual(a: string, b: string): boolean { + if (a.length !== b.length) return false; + let diff = 0; + for (let i = 0; i < a.length; i++) { + diff |= a.charCodeAt(i) ^ b.charCodeAt(i); + } + return diff === 0; +} + +// --------------------------------------------------------------------------- +// Fingerprint (signed UUID) +// --------------------------------------------------------------------------- + +/** + * Sign a fingerprint UUID — returns "." (dotted, base64url sig). + * The value is opaque to callers; only `verifyFingerprint` re-reads it. + */ +export async function signFingerprint( + uuid: string, + secret: string +): Promise { + const sig = await hmacSign(secret, uuid); + return `${uuid}.${sig}`; +} + +/** Returns the UUID on success, or `null` if the signature is invalid. */ +export async function verifyFingerprint( + value: string, + secret: string +): Promise { + const dot = value.lastIndexOf('.'); + if (dot <= 0 || dot === value.length - 1) return null; + const uuid = value.slice(0, dot); + const providedSig = value.slice(dot + 1); + const expectedSig = await hmacSign(secret, uuid); + return timingSafeEqual(providedSig, expectedSig) ? 
uuid : null; +} + +// --------------------------------------------------------------------------- +// Claim token (signed JSON) +// --------------------------------------------------------------------------- + +/** + * Sign a claim payload — returns ".". The entire string + * is then placed in the `sam_trial_claim` cookie. + */ +export async function signClaimToken( + payload: TrialClaimPayload, + secret: string +): Promise { + const body = base64urlEncode(new TextEncoder().encode(JSON.stringify(payload))); + const sig = await hmacSign(secret, body); + return `${body}.${sig}`; +} + +export type ClaimVerifyFailure = + | { ok: false; reason: 'malformed' } + | { ok: false; reason: 'bad_signature' } + | { ok: false; reason: 'expired' }; + +export type ClaimVerifyResult = + | { ok: true; payload: TrialClaimPayload } + | ClaimVerifyFailure; + +/** + * Verify a signed claim token. Returns the decoded payload or a reason tag on + * failure. On `bad_signature` the provided HMAC did NOT match — constant-time + * compared to prevent timing oracles. 
+ */ +export async function verifyClaimToken( + token: string, + secret: string, + now: number = Date.now() +): Promise { + const dot = token.lastIndexOf('.'); + if (dot <= 0 || dot === token.length - 1) return { ok: false, reason: 'malformed' }; + const body = token.slice(0, dot); + const providedSig = token.slice(dot + 1); + + const expectedSig = await hmacSign(secret, body); + if (!timingSafeEqual(providedSig, expectedSig)) { + return { ok: false, reason: 'bad_signature' }; + } + + let payload: TrialClaimPayload; + try { + const json = new TextDecoder().decode(base64urlDecode(body)); + payload = JSON.parse(json) as TrialClaimPayload; + } catch { + return { ok: false, reason: 'malformed' }; + } + + if ( + typeof payload.trialId !== 'string' || + typeof payload.projectId !== 'string' || + typeof payload.issuedAt !== 'number' || + typeof payload.expiresAt !== 'number' + ) { + return { ok: false, reason: 'malformed' }; + } + + if (now >= payload.expiresAt) return { ok: false, reason: 'expired' }; + return { ok: true, payload }; +} + +// --------------------------------------------------------------------------- +// Cookie string builders (HttpOnly; Secure; SameSite=Lax) +// --------------------------------------------------------------------------- + +function buildCookieString( + name: string, + value: string, + opts: { maxAgeSec: number; secure?: boolean; domain?: string } +): string { + const parts = [ + `${name}=${value}`, + 'Path=/', + 'HttpOnly', + 'SameSite=Lax', + `Max-Age=${opts.maxAgeSec}`, + ]; + if (opts.secure !== false) parts.push('Secure'); + if (opts.domain) parts.push(`Domain=${opts.domain}`); + return parts.join('; '); +} + +export function buildFingerprintCookie( + signedValue: string, + opts: { secure?: boolean; domain?: string; maxAgeSec?: number } = {} +): string { + return buildCookieString(TRIAL_COOKIE_FINGERPRINT_NAME, signedValue, { + maxAgeSec: opts.maxAgeSec ?? 
DEFAULT_TRIAL_FINGERPRINT_TTL_SEC, + secure: opts.secure, + domain: opts.domain, + }); +} + +export function buildClaimCookie( + token: string, + opts: { secure?: boolean; domain?: string; maxAgeSec?: number } = {} +): string { + return buildCookieString(TRIAL_COOKIE_CLAIM_NAME, token, { + maxAgeSec: opts.maxAgeSec ?? Math.floor(DEFAULT_TRIAL_CLAIM_TTL_MS / 1000), + secure: opts.secure, + domain: opts.domain, + }); +} + +export function clearClaimCookie(opts: { domain?: string } = {}): string { + // Max-Age=0 instructs the browser to drop the cookie immediately. + return buildCookieString(TRIAL_COOKIE_CLAIM_NAME, '', { + maxAgeSec: 0, + secure: true, + domain: opts.domain, + }); +} diff --git a/apps/api/src/services/trial/discovery-prompt.ts b/apps/api/src/services/trial/discovery-prompt.ts new file mode 100644 index 000000000..bbe2372eb --- /dev/null +++ b/apps/api/src/services/trial/discovery-prompt.ts @@ -0,0 +1,24 @@ +/** + * Discovery prompt — the canned first message sent to the agent when a trial + * workspace becomes ready. The agent uses its ~45s demo budget to explore the + * repo, surface high-value knowledge entities, and seed a few starter ideas. + * + * When this prompt is edited, bump DISCOVERY_PROMPT_VERSION so downstream + * caches / analytics keyed on the version can refresh. + */ + +export const DISCOVERY_PROMPT_VERSION = '2026-04-18-v1'; + +export const DISCOVERY_PROMPT = `You are exploring a codebase for a developer who has just discovered SAM. You have about 45 seconds to demonstrate SAM's capabilities and produce an immediately useful first impression. The developer has not authenticated yet — they landed here from a URL and chose to try SAM with this repository. + +Your goals, in order: + +1. **Orient quickly.** Read the repo root (\`README.md\`, \`package.json\` or equivalent, top-level directory listing) and form a compact mental model of what this project is and what stack it uses. Keep this to ~3 files. + +2. 
**Seed 3–5 knowledge entities.** Use \`add_knowledge\` to record the most load-bearing facts you discover: the primary language, build system, entry points, notable domain concepts, and any conventions that would trip up a new contributor. Include a source file reference where possible. Prefer depth over breadth — five well-grounded entries beat fifteen shallow ones. + +3. **Propose 2–3 starter ideas.** Use \`create_idea\` to suggest small, concrete improvements or explorations a developer could dispatch next (e.g. "Add a CHANGELOG generator", "Document the retry policy in \`src/client.ts\`"). Each idea should have a clear title, a one-paragraph summary, and — if possible — a pointer to the specific file(s) involved. + +4. **Stop before you over-explore.** Leave quickly once you have delivered the above; do not attempt to refactor code or run long-form tasks. The human will take over when they claim the workspace. + +Do NOT try to run build or test commands. Do NOT edit or create files in the repo. Your entire output for this turn is: a short human-facing summary (<= 4 sentences) describing what you found, plus the MCP calls above.`; diff --git a/apps/api/src/services/trial/github-knowledge.ts b/apps/api/src/services/trial/github-knowledge.ts new file mode 100644 index 000000000..429b60a46 --- /dev/null +++ b/apps/api/src/services/trial/github-knowledge.ts @@ -0,0 +1,202 @@ +/** + * Fast-path GitHub knowledge probe. + * + * Before the discovery agent even boots, we hit a handful of unauthenticated + * GitHub REST endpoints to surface quick wins (description, languages, topics, + * stars, README first paragraph) as `trial.knowledge` events. These arrive on + * the SSE stream within ~3s of POST /api/trial/create returning, giving the + * user immediate feedback while the VM provisions in the background. + * + * Design notes: + * - Fire-and-forget: every call is wrapped in try/catch; errors never bubble. 
+ * - Per-request timeout via AbortController; total events capped via env. + * - No auth header: the ~60 req/hour unauthenticated rate limit is acceptable + * for trial onboarding (one probe per trial, five GH calls per probe). + * - All knobs are env-configurable (Constitution Principle XI). + */ +import { + DEFAULT_TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS, + DEFAULT_TRIAL_KNOWLEDGE_MAX_EVENTS, +} from '@simple-agent-manager/shared'; + +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { emitTrialEvent } from './trial-runner'; + +function resolveTimeoutMs(env: Env): number { + const raw = env.TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS; + if (!raw) return DEFAULT_TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS; + const n = Number(raw); + return Number.isFinite(n) && n > 0 ? n : DEFAULT_TRIAL_KNOWLEDGE_GITHUB_TIMEOUT_MS; +} + +function resolveMaxEvents(env: Env): number { + const raw = env.TRIAL_KNOWLEDGE_MAX_EVENTS; + if (!raw) return DEFAULT_TRIAL_KNOWLEDGE_MAX_EVENTS; + const n = Number(raw); + return Number.isFinite(n) && n > 0 ? n : DEFAULT_TRIAL_KNOWLEDGE_MAX_EVENTS; +} + +/** Standard headers for unauthenticated GitHub REST calls from the Worker. 
*/ +const GH_HEADERS: Record = { + accept: 'application/vnd.github+json', + 'user-agent': 'sam-trial-onboarding', +}; + +async function fetchJson( + url: string, + timeoutMs: number, + fetchFn: typeof fetch = fetch +): Promise { + const ac = new AbortController(); + const timer = setTimeout(() => ac.abort(), timeoutMs); + try { + const resp = await fetchFn(url, { headers: GH_HEADERS, signal: ac.signal }); + if (!resp.ok) return null; + return await resp.json(); + } catch { + return null; + } finally { + clearTimeout(timer); + } +} + +async function fetchText( + url: string, + timeoutMs: number, + fetchFn: typeof fetch = fetch +): Promise { + const ac = new AbortController(); + const timer = setTimeout(() => ac.abort(), timeoutMs); + try { + const resp = await fetchFn(url, { + headers: { ...GH_HEADERS, accept: 'application/vnd.github.raw' }, + signal: ac.signal, + }); + if (!resp.ok) return null; + return await resp.text(); + } catch { + return null; + } finally { + clearTimeout(timer); + } +} + +/** + * Probe GitHub for quick-win observations and emit them as `trial.knowledge` + * events. Completes within ~timeoutMs × N-calls (worst case) even if the + * endpoints hang, because every request is AbortController-bounded. + * + * Errors never bubble — callers should kick this off via `waitUntil` and + * return their response immediately. + */ +export async function emitGithubKnowledgeEvents( + env: Env, + trialId: string, + repo: { owner: string; name: string }, + opts: { fetchFn?: typeof fetch } = {} +): Promise { + const timeoutMs = resolveTimeoutMs(env); + const maxEvents = resolveMaxEvents(env); + const fetchFn = opts.fetchFn ?? 
fetch; + let emitted = 0; + + const emit = async (entity: string, observation: string): Promise => { + if (emitted >= maxEvents) return; + emitted++; + try { + await emitTrialEvent(env, trialId, { + type: 'trial.knowledge', + entity, + observation, + at: Date.now(), + }); + } catch (err) { + log.warn('trial.github_knowledge.emit_failed', { + trialId, + entity, + error: err instanceof Error ? err.message : String(err), + }); + } + }; + + try { + // Run the three independent probes in parallel. Each is independently + // bounded by timeoutMs — a hanging readme fetch won't block repo metadata. + const base = `https://api.github.com/repos/${repo.owner}/${repo.name}`; + const [repoMeta, languages, readme] = await Promise.all([ + fetchJson(base, timeoutMs, fetchFn) as Promise< + | { + description?: string | null; + language?: string | null; + stargazers_count?: number; + topics?: string[]; + license?: { spdx_id?: string | null } | null; + default_branch?: string; + } + | null + >, + fetchJson(`${base}/languages`, timeoutMs, fetchFn) as Promise< + Record | null + >, + fetchText(`${base}/readme`, timeoutMs, fetchFn), + ]); + + if (repoMeta) { + if (typeof repoMeta.description === 'string' && repoMeta.description.trim()) { + await emit('repository', `Description: ${repoMeta.description.trim()}`); + } + if (typeof repoMeta.language === 'string' && repoMeta.language) { + await emit('repository', `Primary language: ${repoMeta.language}`); + } + if ( + typeof repoMeta.stargazers_count === 'number' && + repoMeta.stargazers_count > 0 + ) { + await emit('repository', `Stars: ${repoMeta.stargazers_count}`); + } + if (Array.isArray(repoMeta.topics) && repoMeta.topics.length > 0) { + await emit('repository', `Topics: ${repoMeta.topics.slice(0, 8).join(', ')}`); + } + if (repoMeta.license && typeof repoMeta.license.spdx_id === 'string') { + await emit('repository', `License: ${repoMeta.license.spdx_id}`); + } + } + + if (languages && typeof languages === 'object') { + const entries = 
Object.entries(languages).sort((a, b) => b[1] - a[1]); + if (entries.length > 0) { + const top = entries + .slice(0, 5) + .map(([lang]) => lang) + .join(', '); + await emit('repository', `Languages (by bytes): ${top}`); + } + } + + if (readme && readme.length > 0) { + // First non-empty paragraph, stripped of markdown heading markers. + const paragraphs = readme + .split(/\n{2,}/) + .map((p) => p.trim()) + .filter((p) => p.length > 0); + const firstParagraph = paragraphs.find( + (p) => !/^#{1,6}\s/.test(p) && !/^!\[/.test(p) && p.length > 20 + ); + if (firstParagraph) { + const snippet = + firstParagraph.length > 280 + ? `${firstParagraph.slice(0, 280)}…` + : firstParagraph; + await emit('repository', `README: ${snippet}`); + } + } + } catch (err) { + // Belt-and-suspenders — individual probes already swallow, but any + // unforeseen throw in the aggregation layer must not escape. + log.warn('trial.github_knowledge.unexpected_error', { + trialId, + error: err instanceof Error ? err.message : String(err), + }); + } +} diff --git a/apps/api/src/services/trial/helpers.ts b/apps/api/src/services/trial/helpers.ts new file mode 100644 index 000000000..e4b1ce1a0 --- /dev/null +++ b/apps/api/src/services/trial/helpers.ts @@ -0,0 +1,180 @@ +/** + * Shared helpers for the trial onboarding routes. + * + * Kept in a single module so create.ts / status.ts / waitlist.ts share the + * same month-key math, repo-URL canonicalisation, and env-var fallback + * constants (Principle XI — no hardcoded values; every limit has a + * DEFAULT_* constant and an env override). + */ +import type { Env } from '../../env'; + +// --------------------------------------------------------------------------- +// Defaults (Principle XI) +// --------------------------------------------------------------------------- + +/** Monthly cap on anonymous trial creations. */ +export const DEFAULT_TRIAL_MONTHLY_CAP = 1500; +/** Trial workspace TTL in ms (20 min). 
*/ +export const DEFAULT_TRIAL_WORKSPACE_TTL_MS = 20 * 60 * 1000; +/** GitHub repo size upper bound in KB (GitHub `size` field is reported in KB). */ +export const DEFAULT_TRIAL_REPO_MAX_KB = 500 * 1024; // 500 MB +/** Timeout for the GitHub repo metadata probe. */ +export const DEFAULT_TRIAL_GITHUB_TIMEOUT_MS = 5_000; +/** Retention window (hours) after which expired/failed/claimed trials are reaped. */ +export const DEFAULT_TRIAL_DATA_RETENTION_HOURS = 24 * 7; // 7 days +/** Stale-counter months to keep in TrialCounter DO SQLite. */ +export const DEFAULT_TRIAL_COUNTER_KEEP_MONTHS = 3; +/** How long to wait after the reset date before purging notified waitlist rows. */ +export const DEFAULT_TRIAL_WAITLIST_PURGE_DAYS = 30; + +// --------------------------------------------------------------------------- +// Env var resolution +// --------------------------------------------------------------------------- + +export function resolveMonthlyCap(env: Env): number { + const raw = env.TRIAL_MONTHLY_CAP; + if (raw === undefined || raw === null || raw === '') { + return DEFAULT_TRIAL_MONTHLY_CAP; + } + const parsed = Number(raw); + return Number.isFinite(parsed) && parsed >= 0 ? parsed : DEFAULT_TRIAL_MONTHLY_CAP; +} + +export function resolveWorkspaceTtlMs(env: Env): number { + const raw = env.TRIAL_WORKSPACE_TTL_MS; + if (raw === undefined || raw === null || raw === '') { + return DEFAULT_TRIAL_WORKSPACE_TTL_MS; + } + const parsed = Number(raw); + return Number.isFinite(parsed) && parsed > 0 + ? parsed + : DEFAULT_TRIAL_WORKSPACE_TTL_MS; +} + +export function resolveRepoMaxKb(env: Env): number { + const raw = env.TRIAL_REPO_MAX_KB; + if (raw === undefined || raw === null || raw === '') { + return DEFAULT_TRIAL_REPO_MAX_KB; + } + const parsed = Number(raw); + return Number.isFinite(parsed) && parsed > 0 ? 
parsed : DEFAULT_TRIAL_REPO_MAX_KB; +} + +export function resolveGithubTimeoutMs(env: Env): number { + const raw = env.TRIAL_GITHUB_TIMEOUT_MS; + if (raw === undefined || raw === null || raw === '') { + return DEFAULT_TRIAL_GITHUB_TIMEOUT_MS; + } + const parsed = Number(raw); + return Number.isFinite(parsed) && parsed > 0 + ? parsed + : DEFAULT_TRIAL_GITHUB_TIMEOUT_MS; +} + +export function resolveRetentionHours(env: Env): number { + const raw = env.TRIAL_DATA_RETENTION_HOURS; + if (raw === undefined || raw === null || raw === '') { + return DEFAULT_TRIAL_DATA_RETENTION_HOURS; + } + const parsed = Number(raw); + return Number.isFinite(parsed) && parsed > 0 + ? parsed + : DEFAULT_TRIAL_DATA_RETENTION_HOURS; +} + +export function resolveCounterKeepMonths(env: Env): number { + const raw = env.TRIAL_COUNTER_KEEP_MONTHS; + if (raw === undefined || raw === null || raw === '') { + return DEFAULT_TRIAL_COUNTER_KEEP_MONTHS; + } + const parsed = Number(raw); + return Number.isFinite(parsed) && parsed > 0 + ? Math.floor(parsed) + : DEFAULT_TRIAL_COUNTER_KEEP_MONTHS; +} + +export function resolveWaitlistPurgeDays(env: Env): number { + const raw = env.TRIAL_WAITLIST_PURGE_DAYS; + if (raw === undefined || raw === null || raw === '') { + return DEFAULT_TRIAL_WAITLIST_PURGE_DAYS; + } + const parsed = Number(raw); + return Number.isFinite(parsed) && parsed > 0 + ? Math.floor(parsed) + : DEFAULT_TRIAL_WAITLIST_PURGE_DAYS; +} + +// --------------------------------------------------------------------------- +// Month key math (UTC) +// --------------------------------------------------------------------------- + +/** `YYYY-MM` (UTC) for `now` — matches TrialCounter keyspace. */ +export function currentMonthKey(now: number = Date.now()): string { + const d = new Date(now); + const y = d.getUTCFullYear(); + const m = d.getUTCMonth() + 1; + return `${y}-${m.toString().padStart(2, '0')}`; +} + +/** ISO date (YYYY-MM-01) of the first day of the next UTC month. 
*/ +export function nextMonthResetDate(now: number = Date.now()): string { + const d = new Date(now); + const next = new Date(Date.UTC(d.getUTCFullYear(), d.getUTCMonth() + 1, 1)); + const y = next.getUTCFullYear(); + const m = next.getUTCMonth() + 1; + return `${y}-${m.toString().padStart(2, '0')}-01`; +} + +/** + * Shift a month key by `delta` months (negative = earlier). Used by the + * monthly rollover audit to compute the oldest key we want to retain. + */ +export function shiftMonthKey(monthKey: string, delta: number): string { + const match = /^(\d{4})-(\d{2})$/.exec(monthKey); + if (!match) throw new Error(`invalid month key: ${monthKey}`); + const y = Number(match[1]); + const m = Number(match[2]); + // JS Date handles month overflow/underflow cleanly. + const d = new Date(Date.UTC(y, m - 1 + delta, 1)); + const ny = d.getUTCFullYear(); + const nm = d.getUTCMonth() + 1; + return `${ny}-${nm.toString().padStart(2, '0')}`; +} + +// --------------------------------------------------------------------------- +// GitHub repo URL canonicalisation +// --------------------------------------------------------------------------- + +export interface ParsedRepoUrl { + owner: string; + name: string; + canonical: string; // https://github.com/owner/name +} + +/** + * Parse a GitHub public-repo URL. Accepts the forms matched by + * GITHUB_REPO_URL_REGEX in the shared Valibot schema; strips `.git` and + * trailing slashes. Returns null on structural mismatch — the route + * should already have validated the shape; this is a defense-in-depth parse. 
+ */ +export function parseGithubRepoUrl(url: string): ParsedRepoUrl | null { + const trimmed = url.trim().replace(/\/+$/, '').replace(/\.git$/, ''); + const match = + /^https:\/\/github\.com\/([A-Za-z0-9](?:[A-Za-z0-9-]{0,37}[A-Za-z0-9])?)\/([A-Za-z0-9_.-]{1,100})$/.exec( + trimmed + ); + if (!match) return null; + const owner = match[1]!; + const name = match[2]!; + return { owner, name, canonical: `https://github.com/${owner}/${name}` }; +} + +// --------------------------------------------------------------------------- +// TrialCounter DO access +// --------------------------------------------------------------------------- + +/** Return the singleton `TrialCounter` stub (keyed by `global`). */ +export function getTrialCounterStub(env: Env): DurableObjectStub { + const id = env.TRIAL_COUNTER.idFromName('global'); + return env.TRIAL_COUNTER.get(id); +} diff --git a/apps/api/src/services/trial/kill-switch.ts b/apps/api/src/services/trial/kill-switch.ts new file mode 100644 index 000000000..5bafa75ab --- /dev/null +++ b/apps/api/src/services/trial/kill-switch.ts @@ -0,0 +1,61 @@ +/** + * Trial kill-switch. + * + * Trials can be disabled without deploying by setting KV key `trials:enabled` + * (configurable via env.TRIALS_ENABLED_KV_KEY) to anything other than `"true"`. + * + * The value is cached in-memory for 30s (configurable via + * env.TRIAL_KILL_SWITCH_CACHE_MS) because every POST /api/trial/create reads + * it on the hot path. + */ +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; + +const DEFAULT_KILL_SWITCH_CACHE_MS = 30_000; +const DEFAULT_TRIALS_ENABLED_KV_KEY = 'trials:enabled'; + +interface CacheEntry { + enabled: boolean; + /** epoch ms when this entry becomes stale */ + expiresAt: number; +} + +// Module-scoped cache — Workers re-use the isolate across requests within an +// instance, so this gives us the intended "last value for up to TTL" behavior. +let cache: CacheEntry | null = null; + +/** Exported for tests only. 
*/ +export function __resetKillSwitchCacheForTest(): void { + cache = null; +} + +/** + * Returns `true` when trials are enabled. Default is **disabled** — an + * operator must explicitly set the KV flag to `"true"` to turn trials on. + */ +export async function isTrialsEnabled( + env: Env, + now: number = Date.now() +): Promise { + if (cache && now < cache.expiresAt) return cache.enabled; + + const key = env.TRIALS_ENABLED_KV_KEY ?? DEFAULT_TRIALS_ENABLED_KV_KEY; + const ttl = Number(env.TRIAL_KILL_SWITCH_CACHE_MS ?? DEFAULT_KILL_SWITCH_CACHE_MS); + + let enabled = false; + try { + const value = await env.KV.get(key); + enabled = value === 'true'; + } catch (err) { + // KV outages MUST fail closed — the monthly cap is defended by the DO + // counter, so preferring "disabled" over "enabled" during an outage is safe. + log.error('trial.kill_switch.kv_read_failed', { + key, + error: err instanceof Error ? err.message : String(err), + }); + enabled = false; + } + + cache = { enabled, expiresAt: now + ttl }; + return enabled; +} diff --git a/apps/api/src/services/trial/oauth-hook.ts b/apps/api/src/services/trial/oauth-hook.ts new file mode 100644 index 000000000..7f2a06d18 --- /dev/null +++ b/apps/api/src/services/trial/oauth-hook.ts @@ -0,0 +1,110 @@ +/** + * OAuth callback trial-claim hook. + * + * After BetterAuth handles a successful `GET /api/auth/callback/github`, this + * hook inspects the request for a trial fingerprint cookie and — if the + * fingerprint binds to an active unclaimed trial — issues a signed + * `sam_trial_claim` cookie and rewrites the OAuth redirect to the trial's + * claim landing page (`https://app.${BASE_DOMAIN}/try/:trialId?claim=1`). + * + * Nothing in BetterAuth's flow is mutated: we only mutate the Response it + * returns. If the user is not mid-trial, the response is returned untouched. 
+ */ + +import { TRIAL_COOKIE_FINGERPRINT_NAME } from '@simple-agent-manager/shared'; + +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import { + buildClaimCookie, + DEFAULT_TRIAL_CLAIM_TTL_MS, + signClaimToken, + type TrialClaimPayload, + verifyFingerprint, +} from './cookies'; +import { readTrialByFingerprint } from './trial-store'; + +const CALLBACK_PATH_SUFFIX = '/callback/github'; + +/** + * Wrap a BetterAuth response. Returns the original Response if no trial + * fingerprint is detected, or a cloned+modified response that sets the + * claim cookie and redirects to the trial claim landing page. + */ +export async function maybeAttachTrialClaimCookie( + env: Env, + request: Request, + response: Response +): Promise { + // Only run on GitHub OAuth callback + const url = new URL(request.url); + if (!url.pathname.endsWith(CALLBACK_PATH_SUFFIX)) return response; + + // BetterAuth signals a successful OAuth callback by issuing a 302 redirect + // (to the origin URL). We piggyback on that — on 4xx/5xx we leave the + // response alone. + if (response.status < 300 || response.status >= 400) return response; + + const secret = env.TRIAL_CLAIM_TOKEN_SECRET; + if (!secret) return response; + + const cookieHeader = request.headers.get('cookie') ?? 
''; + const fingerprintCookie = parseCookie(cookieHeader, TRIAL_COOKIE_FINGERPRINT_NAME); + if (!fingerprintCookie) return response; + + const uuid = await verifyFingerprint(fingerprintCookie, secret); + if (!uuid) return response; + + const record = await readTrialByFingerprint(env, uuid); + if (!record || record.claimed) return response; + if (record.expiresAt <= Date.now()) return response; + + // Build the claim cookie + const now = Date.now(); + const payload: TrialClaimPayload = { + trialId: record.trialId, + projectId: record.projectId, + issuedAt: now, + expiresAt: now + DEFAULT_TRIAL_CLAIM_TTL_MS, + }; + const token = await signClaimToken(payload, secret); + const cookieDomain = env.BASE_DOMAIN ? `.${env.BASE_DOMAIN}` : undefined; + const cookie = buildClaimCookie(token, { domain: cookieDomain }); + + // Rewrite the redirect Location to the app's claim landing page. + const baseDomain = env.BASE_DOMAIN; + const claimUrl = `https://app.${baseDomain}/try/${record.trialId}?claim=1`; + + // Copy existing headers, set/append Set-Cookie and overwrite Location. 
+ const headers = new Headers(response.headers); + headers.append('Set-Cookie', cookie); + headers.set('Location', claimUrl); + + log.info('trial_oauth_hook.claim_cookie_attached', { + trialId: record.trialId, + projectId: record.projectId, + }); + + return new Response(response.body, { + status: response.status, + statusText: response.statusText, + headers, + }); +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function parseCookie(header: string, name: string): string | null { + if (!header) return null; + const parts = header.split(/;\s*/); + for (const part of parts) { + const eq = part.indexOf('='); + if (eq === -1) continue; + if (part.slice(0, eq) === name) { + return decodeURIComponent(part.slice(eq + 1)); + } + } + return null; +} diff --git a/apps/api/src/services/trial/trial-runner.ts b/apps/api/src/services/trial/trial-runner.ts new file mode 100644 index 000000000..b8c182f02 --- /dev/null +++ b/apps/api/src/services/trial/trial-runner.ts @@ -0,0 +1,231 @@ +/** + * Trial Runner — bootstraps the anonymous discovery agent for a trial. + * + * Invoked by Track A after the trial workspace is provisioned and reachable. + * Responsibilities: + * 1. Resolve the correct agent type based on deployment mode (staging vs production). + * 2. Create a chat session + ACP session on the trial's project. + * 3. Seed the discovery prompt as the initial prompt. + * 4. Return the ACP session id so Track A can wire the VM-agent bootstrap. + * + * Event streaming: the Track B SSE endpoint (/api/trial/:trialId/events) polls + * the TrialEventBus DO; callers that produce trial events should invoke + * `emitTrialEvent()` to fan them out. Bridge from ACP session notifications to + * the event bus is wired in Track A / ACP status handlers (separate concern). 
+ */ + +import type { Env } from '../../env'; +import { log } from '../../lib/logger'; +import * as projectDataService from '../project-data'; +import { DISCOVERY_PROMPT, DISCOVERY_PROMPT_VERSION } from './discovery-prompt'; + +// --------------------------------------------------------------------------- +// Defaults (Principle XI) +// --------------------------------------------------------------------------- + +/** Deployment mode — staging uses cheaper models, production uses Anthropic. */ +const DEFAULT_ENVIRONMENT = 'staging'; +const DEFAULT_TRIAL_AGENT_TYPE_STAGING = 'opencode'; +const DEFAULT_TRIAL_AGENT_TYPE_PRODUCTION = 'claude-code'; +const DEFAULT_TRIAL_MODEL_STAGING = '@cf/meta/llama-4-scout-17b-16e-instruct'; +const DEFAULT_TRIAL_MODEL_PRODUCTION = 'claude-sonnet-4-5'; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface TrialRunnerConfig { + /** "staging" | "production". Falls back to ENVIRONMENT env var then DEFAULT. */ + mode: 'staging' | 'production'; + /** Agent runtime to launch in the workspace (e.g. opencode, claude-code). */ + agentType: string; + /** Model name / provider-specific identifier for the discovery agent. */ + model: string; + /** LLM provider that the VM agent should route requests through. */ + provider: 'anthropic' | 'workers-ai'; +} + +export interface StartDiscoveryAgentOptions { + /** Project id — already created by Track A under `system_anonymous_trials`. */ + projectId: string; + /** Workspace id bound to the trial. Used for session->workspace linkage. */ + workspaceId: string; + /** Short human-facing topic recorded on the chat session (e.g. repo slug). 
*/ + sessionTopic?: string | null; +} + +export interface StartDiscoveryAgentResult { + chatSessionId: string; + acpSessionId: string; + agentType: string; + model: string; + provider: 'anthropic' | 'workers-ai'; + promptVersion: string; +} + +// --------------------------------------------------------------------------- +// Config resolution +// --------------------------------------------------------------------------- + +export function resolveTrialRunnerConfig(env: Env): TrialRunnerConfig { + const modeStr = (env.ENVIRONMENT ?? DEFAULT_ENVIRONMENT).toLowerCase(); + const mode: 'staging' | 'production' = + modeStr === 'production' || modeStr === 'prod' ? 'production' : 'staging'; + + const agentType = + mode === 'production' + ? (env.TRIAL_AGENT_TYPE_PRODUCTION ?? DEFAULT_TRIAL_AGENT_TYPE_PRODUCTION) + : (env.TRIAL_AGENT_TYPE_STAGING ?? DEFAULT_TRIAL_AGENT_TYPE_STAGING); + + const providerOverride = env.TRIAL_LLM_PROVIDER?.toLowerCase(); + const provider: 'anthropic' | 'workers-ai' = + providerOverride === 'anthropic' || providerOverride === 'workers-ai' + ? providerOverride + : mode === 'production' + ? 'anthropic' + : 'workers-ai'; + + const model = + env.TRIAL_MODEL ?? + (provider === 'anthropic' ? DEFAULT_TRIAL_MODEL_PRODUCTION : DEFAULT_TRIAL_MODEL_STAGING); + + return { mode, agentType, model, provider }; +} + +// --------------------------------------------------------------------------- +// Start the discovery agent +// --------------------------------------------------------------------------- + +/** + * Create chat + ACP sessions on the trial project and seed the discovery prompt. + * + * Does NOT actually boot the VM-agent subprocess — Track A owns the workspace + * lifecycle and will call the VM agent once the workspace is provisioned. This + * function only records the intent in Durable Object state so it can be picked + * up from the project dashboard after the trial is claimed. 
+ */ +export async function startDiscoveryAgent( + env: Env, + opts: StartDiscoveryAgentOptions +): Promise { + const config = resolveTrialRunnerConfig(env); + + // Validate production config: Anthropic mode REQUIRES the API key. + if (config.mode === 'production' && config.provider === 'anthropic') { + if (!env.ANTHROPIC_API_KEY_TRIAL) { + throw new Error( + 'ANTHROPIC_API_KEY_TRIAL is required for production trial runner (Anthropic provider)' + ); + } + } + + // Chat session — groups messages/activity on the project page. + const chatSessionId = await projectDataService.createSession( + env, + opts.projectId, + opts.workspaceId, + opts.sessionTopic ?? 'Exploring repository', + null /* taskId */ + ); + + // ACP session — represents the agent run in the workspace. + const acpSession = await projectDataService.createAcpSession( + env, + opts.projectId, + chatSessionId, + DISCOVERY_PROMPT, + config.agentType + ); + + log.info('trial_runner.discovery_started', { + projectId: opts.projectId, + workspaceId: opts.workspaceId, + chatSessionId, + acpSessionId: acpSession.id, + agentType: config.agentType, + model: config.model, + provider: config.provider, + promptVersion: DISCOVERY_PROMPT_VERSION, + }); + + return { + chatSessionId, + acpSessionId: acpSession.id, + agentType: config.agentType, + model: config.model, + provider: config.provider, + promptVersion: DISCOVERY_PROMPT_VERSION, + }; +} + +// --------------------------------------------------------------------------- +// Event emission helpers (for use by any code path that produces TrialEvents) +// --------------------------------------------------------------------------- + +import type { TrialEvent } from '@simple-agent-manager/shared'; + +import { readTrial, readTrialByProject } from './trial-store'; + +/** + * Append a TrialEvent to the trial's event bus DO. + * Safe to call from any code path — silently no-ops if the trial does not exist + * or the bus is already closed (terminal event emitted). 
+ */ +export async function emitTrialEvent( + env: Env, + trialId: string, + event: TrialEvent +): Promise { + log.info('trial_event_bus.emit_begin', { trialId, type: event.type }); + try { + const id = env.TRIAL_EVENT_BUS.idFromName(trialId); + const stub = env.TRIAL_EVENT_BUS.get(id); + const resp = await stub.fetch('https://trial-event-bus/append', { + method: 'POST', + headers: { 'content-type': 'application/json' }, + body: JSON.stringify(event), + }); + if (!resp.ok && resp.status !== 409 /* closed */) { + log.warn('trial_event_bus.append_failed', { + trialId, + type: event.type, + status: resp.status, + }); + } else { + log.info('trial_event_bus.emit_ok', { + trialId, + type: event.type, + status: resp.status, + }); + } + } catch (err) { + log.warn('trial_event_bus.append_error', { + trialId, + type: event.type, + error: err instanceof Error ? err.message : String(err), + stack: err instanceof Error ? err.stack : undefined, + }); + } +} + +/** + * Look up the active trialId for a given projectId and emit the event. Used by + * ACP notification hooks that only know the project id. + */ +export async function emitTrialEventForProject( + env: Env, + projectId: string, + event: TrialEvent +): Promise { + const record = await readTrialByProject(env, projectId); + if (!record) return; + await emitTrialEvent(env, record.trialId, event); +} + +/** Look up trial record by id — passthrough for callers that don't import trial-store directly. */ +export async function getTrialRecord( + env: Env, + trialId: string +): Promise> { + return readTrial(env, trialId); +} diff --git a/apps/api/src/services/trial/trial-store.ts b/apps/api/src/services/trial/trial-store.ts new file mode 100644 index 000000000..1759e000e --- /dev/null +++ b/apps/api/src/services/trial/trial-store.ts @@ -0,0 +1,131 @@ +/** + * Trial metadata store (KV-backed). 
+ * + * Track A (`POST /api/trial/create`) writes trial records here; Track B + * (`GET /api/trial/:trialId/events` and `POST /api/trial/claim`) reads them. + * + * Keys: + * trial:${trialId} -> TrialRecord (JSON) + * trial-by-project:${projectId} -> trialId (string) + * trial-by-fingerprint:${fpUuid} -> trialId (string; most-recent active trial) + * + * Per-key TTL is derived from `expiresAt` so stale records are evicted + * automatically by Cloudflare KV. + */ + +import type { Env } from '../../env'; + +/** + * The canonical trial record. Only fields both tracks agree on live here — any + * track-specific metadata (VM IP, workspace URL) should live on the workspace + * row or the project row. + */ +export interface TrialRecord { + trialId: string; + projectId: string; + /** Decoded fingerprint UUID (NOT the signed cookie value). */ + fingerprint: string; + /** workspace id owned by the trial, if provisioned. Null before workspace_ready. */ + workspaceId: string | null; + /** Public GitHub repo URL that triggered the trial. */ + repoUrl: string; + /** epoch ms — when the trial was created (POST /api/trial/create). */ + createdAt: number; + /** epoch ms — hard expiry; matches TRIAL_WORKSPACE_TTL_MS. */ + expiresAt: number; + /** true once the user has claimed the project (OAuth callback -> POST /claim). 
*/ + claimed: boolean; +} + +// --------------------------------------------------------------------------- +// Key helpers +// --------------------------------------------------------------------------- + +export function trialKey(trialId: string): string { + return `trial:${trialId}`; +} +export function trialByProjectKey(projectId: string): string { + return `trial-by-project:${projectId}`; +} +export function trialByFingerprintKey(fingerprint: string): string { + return `trial-by-fingerprint:${fingerprint}`; +} + +// --------------------------------------------------------------------------- +// Reads +// --------------------------------------------------------------------------- + +export async function readTrial( + env: Env, + trialId: string +): Promise { + const raw = await env.KV.get(trialKey(trialId)); + if (!raw) return null; + try { + return JSON.parse(raw) as TrialRecord; + } catch { + return null; + } +} + +export async function readTrialByProject( + env: Env, + projectId: string +): Promise { + const trialId = await env.KV.get(trialByProjectKey(projectId)); + if (!trialId) return null; + return readTrial(env, trialId); +} + +export async function readTrialByFingerprint( + env: Env, + fingerprint: string +): Promise { + const trialId = await env.KV.get(trialByFingerprintKey(fingerprint)); + if (!trialId) return null; + return readTrial(env, trialId); +} + +// --------------------------------------------------------------------------- +// Writes +// --------------------------------------------------------------------------- + +/** + * Persist/overwrite a trial record, plus its two index keys. Expiration is + * auto-set to `record.expiresAt` so KV purges stale records. 
+ */ +export async function writeTrial(env: Env, record: TrialRecord): Promise { + const nowSec = Math.floor(Date.now() / 1000); + const ttlSec = Math.max(60, Math.floor(record.expiresAt / 1000) - nowSec); + const json = JSON.stringify(record); + const writes: Promise[] = [ + env.KV.put(trialKey(record.trialId), json, { expirationTtl: ttlSec }), + env.KV.put(trialByFingerprintKey(record.fingerprint), record.trialId, { + expirationTtl: ttlSec, + }), + ]; + // Only index by project when a real projectId exists. Track A creates the + // trial before provisioning the project, so the first write carries an empty + // projectId — we'd otherwise collide all pending trials on `trial-by-project:`. + if (record.projectId) { + writes.push( + env.KV.put(trialByProjectKey(record.projectId), record.trialId, { + expirationTtl: ttlSec, + }) + ); + } + await Promise.all(writes); +} + +/** Mark a trial as claimed (idempotent). */ +export async function markTrialClaimed( + env: Env, + trialId: string +): Promise { + const record = await readTrial(env, trialId); + if (!record) return null; + if (record.claimed) return record; + const updated: TrialRecord = { ...record, claimed: true }; + await writeTrial(env, updated); + return updated; +} diff --git a/apps/api/tests/unit/durable-objects/trial-counter.test.ts b/apps/api/tests/unit/durable-objects/trial-counter.test.ts new file mode 100644 index 000000000..80e7ff812 --- /dev/null +++ b/apps/api/tests/unit/durable-objects/trial-counter.test.ts @@ -0,0 +1,328 @@ +/** + * Unit tests for TrialCounter DO. 
+ * + * Covers: + * - Incrementing from 0 up to the cap + * - Reaching the cap returns { ok: false, count: } and does NOT mutate + * - cap = 0 disables the limit (any number of increments allowed) + * - decrement clamps at 0 + * - get() returns current state without mutating + * - Separate month keys maintain independent counters + * - Atomicity: transactionSync passes through sql.exec writes + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +// Provide a minimal DurableObject base class so the implementation module can +// import from 'cloudflare:workers' in a node test runner. +vi.mock('cloudflare:workers', () => ({ + DurableObject: class { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ctx: any; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + env: any; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + constructor(ctx: any, env: any) { + this.ctx = ctx; + this.env = env; + } + }, +})); + +const { TrialCounter } = await import( + '../../../src/durable-objects/trial-counter' +); + +// --------------------------------------------------------------------------- +// In-memory SQL stub +// --------------------------------------------------------------------------- + +/** + * A very small fake of SqlStorage that handles the exact three queries the DO + * uses: CREATE TABLE IF NOT EXISTS, SELECT ... WHERE month_key = ?, and + * INSERT ... ON CONFLICT DO UPDATE. The counter state lives in a Map. + */ +function createFakeSql() { + const rows = new Map(); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const exec = vi.fn((query: string, ...args: any[]) => { + const q = query.trim().toLowerCase(); + + if (q.startsWith('create table')) { + return { toArray: () => [] }; + } + + if (q.startsWith('select')) { + const key = args[0] as string; + const count = rows.get(key); + return { + toArray: () => (count !== undefined ? 
[{ count }] : []), + }; + } + + if (q.startsWith('insert')) { + const [key, count] = args as [string, number]; + rows.set(key, count); + return { toArray: () => [] }; + } + + return { toArray: () => [] }; + }); + + return { exec, _rows: rows }; +} + +function createFakeCtx() { + const sql = createFakeSql(); + // transactionSync runs the callback synchronously and returns its result. + const transactionSync = vi.fn((fn: () => T): T => fn()); + // blockConcurrencyWhile runs the callback immediately (ignoring the promise + // ordering guarantee since this is a pure in-memory fake). + const blockConcurrencyWhile = vi.fn(async (fn: () => Promise): Promise => { + return fn(); + }); + return { + storage: { + sql, + transactionSync, + }, + blockConcurrencyWhile, + _sql: sql, + }; +} + +function createDO() { + const ctx = createFakeCtx(); + const env = {} as unknown as Env; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const instance = new TrialCounter(ctx as any, env); + return { instance, ctx }; +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +describe('TrialCounter DO', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('creates the trial_counter table on construction', async () => { + const { ctx } = createDO(); + // Allow the blockConcurrencyWhile callback to run. 
+ await Promise.resolve(); + const execCalls = ctx._sql.exec.mock.calls.map((c: unknown[]) => (c[0] as string)); + expect( + execCalls.some((q) => /create table if not exists trial_counter/i.test(q)) + ).toBe(true); + }); + + it('increments from 0 up to the cap', async () => { + const { instance } = createDO(); + const r1 = await instance.increment('2026-04', 3); + const r2 = await instance.increment('2026-04', 3); + const r3 = await instance.increment('2026-04', 3); + expect(r1).toEqual({ ok: true, count: 1 }); + expect(r2).toEqual({ ok: true, count: 2 }); + expect(r3).toEqual({ ok: true, count: 3 }); + }); + + it('refuses to exceed the cap and does not mutate the counter', async () => { + const { instance } = createDO(); + await instance.increment('2026-04', 2); + await instance.increment('2026-04', 2); + + const blocked = await instance.increment('2026-04', 2); + expect(blocked).toEqual({ ok: false, count: 2 }); + + // Follow-on read must still show 2, not 3. + const state = await instance.get('2026-04'); + expect(state).toEqual({ monthKey: '2026-04', count: 2 }); + }); + + it('disables the cap when cap <= 0', async () => { + const { instance } = createDO(); + for (let i = 0; i < 10; i++) { + const r = await instance.increment('2026-04', 0); + expect(r.ok).toBe(true); + expect(r.count).toBe(i + 1); + } + }); + + it('treats negative caps the same as cap=0 (unlimited)', async () => { + const { instance } = createDO(); + const r = await instance.increment('2026-04', -1); + expect(r).toEqual({ ok: true, count: 1 }); + }); + + it('decrements and clamps at 0', async () => { + const { instance } = createDO(); + await instance.increment('2026-04', 10); + await instance.increment('2026-04', 10); + + expect(await instance.decrement('2026-04')).toBe(1); + expect(await instance.decrement('2026-04')).toBe(0); + // Clamped at zero — does NOT go negative. 
+ expect(await instance.decrement('2026-04')).toBe(0); + expect(await instance.decrement('2026-04')).toBe(0); + }); + + it('decrement on an unknown month key yields 0 (no underflow)', async () => { + const { instance } = createDO(); + expect(await instance.decrement('2099-12')).toBe(0); + }); + + it('maintains independent counters per month key', async () => { + const { instance } = createDO(); + await instance.increment('2026-04', 5); + await instance.increment('2026-04', 5); + await instance.increment('2026-05', 5); + + expect(await instance.get('2026-04')).toEqual({ monthKey: '2026-04', count: 2 }); + expect(await instance.get('2026-05')).toEqual({ monthKey: '2026-05', count: 1 }); + // Cap on month A does not block month B. + const bump = await instance.increment('2026-05', 5); + expect(bump).toEqual({ ok: true, count: 2 }); + }); + + it('get() returns 0 for an unseen month without mutating storage', async () => { + const { instance, ctx } = createDO(); + const state = await instance.get('2099-01'); + expect(state).toEqual({ monthKey: '2099-01', count: 0 }); + // No insert happened for this key. + expect(ctx._sql._rows.has('2099-01')).toBe(false); + }); + + it('each mutating call runs inside transactionSync', async () => { + const { instance, ctx } = createDO(); + await instance.increment('2026-04', 10); + await instance.increment('2026-04', 10); + await instance.decrement('2026-04'); + // 2 increments + 1 decrement = 3 transactionSync invocations. 
+ expect(ctx.storage.transactionSync).toHaveBeenCalledTimes(3); + }); + + describe('tryIncrement', () => { + it('returns { allowed, count } on success', async () => { + const { instance } = createDO(); + const r = await instance.tryIncrement('2026-04', 3); + expect(r).toEqual({ allowed: true, count: 1 }); + }); + + it('returns { allowed: false, count } when at cap', async () => { + const { instance } = createDO(); + await instance.tryIncrement('2026-04', 1); + const blocked = await instance.tryIncrement('2026-04', 1); + expect(blocked).toEqual({ allowed: false, count: 1 }); + }); + + it('shares storage with increment (same underlying counter)', async () => { + const { instance } = createDO(); + await instance.increment('2026-04', 10); + const r = await instance.tryIncrement('2026-04', 10); + expect(r).toEqual({ allowed: true, count: 2 }); + }); + }); + + describe('prune', () => { + it('deletes rows older than keepMonthKey and returns count', async () => { + // Extend the fake SQL with range predicate support for prune(). + const { instance, ctx } = createDO(); + // Seed several months. + await instance.increment('2026-01', 10); + await instance.increment('2026-02', 10); + await instance.increment('2026-03', 10); + await instance.increment('2026-04', 10); + + // Patch the fake exec to handle DELETE/COUNT with month_key < ? 
+ const rows = ctx._sql._rows; + ctx._sql.exec.mockImplementation((query: string, ...args: unknown[]) => { + const q = query.trim().toLowerCase(); + if (q.startsWith('select count(*)')) { + const keep = args[0] as string; + const c = [...rows.keys()].filter((k) => k < keep).length; + return { toArray: () => [{ c }] }; + } + if (q.startsWith('delete from trial_counter where month_key <')) { + const keep = args[0] as string; + for (const k of [...rows.keys()]) if (k < keep) rows.delete(k); + return { toArray: () => [] }; + } + if (q.startsWith('select count')) { + // unused — handled above + return { toArray: () => [] }; + } + return { toArray: () => [] }; + }); + + const deleted = await instance.prune('2026-03'); + expect(deleted).toBe(2); + expect(rows.has('2026-01')).toBe(false); + expect(rows.has('2026-02')).toBe(false); + expect(rows.has('2026-03')).toBe(true); + expect(rows.has('2026-04')).toBe(true); + }); + }); + + describe('fetch HTTP surface', () => { + it('GET /state returns current counter', async () => { + const { instance } = createDO(); + await instance.increment('2026-04', 10); + const res = await instance.fetch( + new Request('https://do/state?monthKey=2026-04') + ); + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ monthKey: '2026-04', count: 1 }); + }); + + it('GET /state without monthKey returns 400', async () => { + const { instance } = createDO(); + const res = await instance.fetch(new Request('https://do/state')); + expect(res.status).toBe(400); + }); + + it('POST /tryIncrement returns { allowed, count }', async () => { + const { instance } = createDO(); + const res = await instance.fetch( + new Request('https://do/tryIncrement', { + method: 'POST', + body: JSON.stringify({ monthKey: '2026-04', cap: 2 }), + }) + ); + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ allowed: true, count: 1 }); + }); + + it('POST /tryIncrement without body returns 400', async () => { + const { instance } = createDO(); + const 
res = await instance.fetch( + new Request('https://do/tryIncrement', { + method: 'POST', + body: JSON.stringify({}), + }) + ); + expect(res.status).toBe(400); + }); + + it('POST /decrement returns { count }', async () => { + const { instance } = createDO(); + await instance.increment('2026-04', 10); + await instance.increment('2026-04', 10); + const res = await instance.fetch( + new Request('https://do/decrement', { + method: 'POST', + body: JSON.stringify({ monthKey: '2026-04' }), + }) + ); + expect(res.status).toBe(200); + expect(await res.json()).toEqual({ count: 1 }); + }); + + it('unknown path returns 404', async () => { + const { instance } = createDO(); + const res = await instance.fetch(new Request('https://do/nope')); + expect(res.status).toBe(404); + }); + }); +}); diff --git a/apps/api/tests/unit/durable-objects/trial-event-bus.test.ts b/apps/api/tests/unit/durable-objects/trial-event-bus.test.ts new file mode 100644 index 000000000..177494f14 --- /dev/null +++ b/apps/api/tests/unit/durable-objects/trial-event-bus.test.ts @@ -0,0 +1,195 @@ +/** + * Unit tests for TrialEventBus DO. 
// --- apps/api/tests/unit/durable-objects/trial-event-bus.test.ts (new file) -
// NOTE(review): this file is truncated at the end of the reviewed hunk; the
// final test and the describe closer continue past it and are left open here.

/**
 * Unit tests for TrialEventBus DO.
 *
 * Covers:
 * - POST /append stores event and returns cursor
 * - GET /poll with no events returns immediately after timeout with empty
 *   events array
 * - Appending an event wakes a pending poller
 * - GET /poll with cursor only returns events newer than cursor
 * - Appending a terminal event (trial.error) closes the bus and subsequent
 *   polls return closed=true; trial.ready is NOT terminal — the agent keeps
 *   producing events (FIX(review): the old wording listed trial.ready as
 *   terminal, contradicting the test below)
 * - POST /close sets closed=true and wakes waiters
 * - POST /append on a closed bus returns 409
 * - Buffer evicts oldest beyond MAX_BUFFERED_EVENTS (500)
 */
import { describe, expect, it, vi } from 'vitest';

// Base DO class shim
vi.mock('cloudflare:workers', () => ({
  DurableObject: class {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    ctx: any;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    env: any;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    constructor(ctx: any, env: any) {
      this.ctx = ctx;
      this.env = env;
    }
  },
}));

const { TrialEventBus } = await import('../../../src/durable-objects/trial-event-bus');

function makeDO() {
  const store = new Map<string, unknown>();
  const ctx = {
    storage: {
      // FIX(review): the `<T>` type parameter was stripped by markup
      // sanitisation (`): Promise =>`); restored.
      get: async <T>(key: string): Promise<T | undefined> =>
        store.get(key) as T | undefined,
      put: async (key: string, value: unknown): Promise<void> => {
        store.set(key, value);
      },
    },
  } as unknown;
  const env = {} as unknown;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  return new TrialEventBus(ctx as any, env as any);
}

function appendReq(event: unknown): Request {
  return new Request('https://trial-event-bus/append', {
    method: 'POST',
    headers: { 'content-type': 'application/json' },
    body: JSON.stringify(event),
  });
}

function pollReq(cursor = 0, timeoutMs = 100): Request {
  return new Request(
    `https://trial-event-bus/poll?cursor=${cursor}&timeoutMs=${timeoutMs}`,
    { method: 'GET' }
  );
}

function closeReq(): Request {
  return new Request('https://trial-event-bus/close', { method: 'POST' });
}

describe('TrialEventBus DO', () => {
  it('POST /append stores an event and returns a cursor', async () => {
    const bus = makeDO();
    const r = await bus.fetch(appendReq({ type: 'trial.progress', message: 'x', at: 1 }));
    expect(r.status).toBe(200);
    const body = (await r.json()) as { cursor: number };
    expect(body.cursor).toBe(1);
  });

  it('POST /append rejects invalid JSON', async () => {
    const bus = makeDO();
    const req = new Request('https://trial-event-bus/append', {
      method: 'POST',
      body: 'not-json',
    });
    const r = await bus.fetch(req);
    expect(r.status).toBe(400);
  });

  it('GET /poll with empty buffer waits and returns empty events', async () => {
    const bus = makeDO();
    const r = await bus.fetch(pollReq(0, 100));
    expect(r.status).toBe(200);
    const body = (await r.json()) as {
      events: unknown[];
      cursor: number;
      closed: boolean;
    };
    expect(body.events).toEqual([]);
    expect(body.closed).toBe(false);
  });

  it('GET /poll returns events newer than cursor', async () => {
    const bus = makeDO();
    await bus.fetch(appendReq({ type: 'trial.progress', message: 'a', at: 1 }));
    await bus.fetch(appendReq({ type: 'trial.progress', message: 'b', at: 2 }));

    const r = await bus.fetch(pollReq(1, 100));
    const body = (await r.json()) as {
      events: { cursor: number; event: { message: string } }[];
    };
    expect(body.events.length).toBe(1);
    expect(body.events[0].event.message).toBe('b');
  });

  it('Appending trial.ready does NOT close the bus (agent continues producing events)', async () => {
    const bus = makeDO();
    await bus.fetch(appendReq({ type: 'trial.ready', projectId: 'p', at: 1 }));

    const poll = await bus.fetch(pollReq(0, 100));
    const body = (await poll.json()) as { closed: boolean; events: unknown[] };
    expect(body.closed).toBe(false);
    expect(body.events.length).toBe(1);
  });

  it('Appending a terminal event (trial.error) closes the bus', async () => {
    const bus = makeDO();
    await bus.fetch(
      appendReq({ type: 'trial.error', error: 'kaboom', message: 'x', at: 1 })
    );

    const poll = await bus.fetch(pollReq(0, 100));
    const body = (await poll.json()) as { closed: boolean };
    expect(body.closed).toBe(true);
  });

  it('POST /close sets closed=true', async () => {
    const bus = makeDO();
    const r = await bus.fetch(closeReq());
    expect(r.status).toBe(200);

    const poll = await bus.fetch(pollReq(0, 100));
    const body = (await poll.json()) as { closed: boolean };
    expect(body.closed).toBe(true);
  });

  it('POST /append on a closed bus returns 409', async () => {
    const bus = makeDO();
    await bus.fetch(closeReq());

    const r = await bus.fetch(appendReq({ type: 'trial.progress', message: 'x', at: 0 }));
    expect(r.status).toBe(409);
  });

  it('Appending an event wakes a pending long-poll', async () => {
    const bus = makeDO();
    // Start a long poll with 2s timeout but we'll append within 50ms.
    const pollPromise = bus.fetch(pollReq(0, 2000));

    // Give the poll a tick to register the waiter.
+ await new Promise((r) => setTimeout(r, 10)); + await bus.fetch(appendReq({ type: 'trial.progress', message: 'wake', at: 1 })); + + const start = Date.now(); + const resp = await pollPromise; + const elapsed = Date.now() - start; + + expect(elapsed).toBeLessThan(500); // short-circuited, not timed out + const body = (await resp.json()) as { + events: { event: { message: string } }[]; + }; + expect(body.events.length).toBe(1); + expect(body.events[0].event.message).toBe('wake'); + }); + + it('POST /close wakes pending pollers', async () => { + const bus = makeDO(); + const pollPromise = bus.fetch(pollReq(0, 2000)); + + await new Promise((r) => setTimeout(r, 10)); + await bus.fetch(closeReq()); + + const start = Date.now(); + const resp = await pollPromise; + const elapsed = Date.now() - start; + + expect(elapsed).toBeLessThan(500); + const body = (await resp.json()) as { closed: boolean }; + expect(body.closed).toBe(true); + }); + + it('unknown path returns 404', async () => { + const bus = makeDO(); + const r = await bus.fetch( + new Request('https://trial-event-bus/garbage', { method: 'GET' }) + ); + expect(r.status).toBe(404); + }); +}); diff --git a/apps/api/tests/unit/durable-objects/trial-orchestrator-agent-boot.test.ts b/apps/api/tests/unit/durable-objects/trial-orchestrator-agent-boot.test.ts new file mode 100644 index 000000000..ff33b7af3 --- /dev/null +++ b/apps/api/tests/unit/durable-objects/trial-orchestrator-agent-boot.test.ts @@ -0,0 +1,421 @@ +/** + * Capability tests for the TrialOrchestrator's discovery_agent_start step and + * the `fetchDefaultBranch` GitHub probe used during project_creation. + * + * These assert the class-of-bug regressions described in + * tasks/active/2026-04-19-trial-orchestrator-actually-start-agent.md: + * + * 1. handleDiscoveryAgentStart previously only created ACP session rows and + * never told the VM agent to actually launch the subprocess. The ACP + * session sat in `pending` forever, so `trial.ready` never fired. 
+ * 2. The orchestrator hardcoded `'main'` for both projects.default_branch + * and the workspace's `git clone --branch`, breaking trials on + * master-default repos (e.g. octocat/Hello-World). + * + * The tests mock the cross-boundary calls (node-agent HTTP, MCP token KV, + * ProjectData DO transitions, GitHub API) and assert the payloads + the + * idempotency flags. + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +const { startDiscoveryAgentMock, resolveTrialRunnerConfigMock, emitTrialEventMock } = vi.hoisted( + () => ({ + startDiscoveryAgentMock: vi.fn(), + resolveTrialRunnerConfigMock: vi.fn(() => ({ + mode: 'staging' as const, + agentType: 'opencode', + model: '@cf/meta/llama-4-scout-17b-16e-instruct', + provider: 'workers-ai' as const, + })), + emitTrialEventMock: vi.fn(async () => {}), + }), +); +vi.mock('../../../src/services/trial/trial-runner', () => ({ + emitTrialEvent: emitTrialEventMock, + emitTrialEventForProject: vi.fn(async () => {}), + startDiscoveryAgent: startDiscoveryAgentMock, + resolveTrialRunnerConfig: resolveTrialRunnerConfigMock, +})); + +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrial: vi.fn(async () => null), + readTrialByProject: vi.fn(async () => null), + writeTrial: vi.fn(async () => {}), +})); + +const { linkSessionToWorkspaceMock, transitionAcpSessionMock } = vi.hoisted(() => ({ + linkSessionToWorkspaceMock: vi.fn(async () => {}), + transitionAcpSessionMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/project-data', () => ({ + linkSessionToWorkspace: linkSessionToWorkspaceMock, + transitionAcpSession: transitionAcpSessionMock, +})); + +const { createAgentSessionOnNodeMock, startAgentSessionOnNodeMock, createWorkspaceOnNodeMock } = + vi.hoisted(() => ({ + createAgentSessionOnNodeMock: vi.fn(async () => {}), + startAgentSessionOnNodeMock: vi.fn(async () => {}), + 
createWorkspaceOnNodeMock: vi.fn(async () => {}), + })); +vi.mock('../../../src/services/node-agent', () => ({ + createAgentSessionOnNode: createAgentSessionOnNodeMock, + startAgentSessionOnNode: startAgentSessionOnNodeMock, + createWorkspaceOnNode: createWorkspaceOnNodeMock, +})); + +const { generateMcpTokenMock, storeMcpTokenMock } = vi.hoisted(() => ({ + generateMcpTokenMock: vi.fn(() => 'mcp_tok_fixture_abc123'), + storeMcpTokenMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/mcp-token', () => ({ + generateMcpToken: generateMcpTokenMock, + storeMcpToken: storeMcpTokenMock, +})); + +const { handleDiscoveryAgentStart, handleProjectCreation } = await import( + '../../../src/durable-objects/trial-orchestrator/steps' +); + +type Storage = Map; + +function makeState(overrides: Record = {}) { + return { + version: 1, + trialId: 'trial_boot_test', + repoUrl: 'https://github.com/octocat/Hello-World', + repoOwner: 'octocat', + repoName: 'Hello-World', + currentStep: 'discovery_agent_start', + projectId: 'proj_X', + nodeId: 'node_X', + autoProvisionedNode: false, + workspaceId: 'ws_X', + chatSessionId: null, + acpSessionId: null, + defaultBranch: null, + mcpToken: null, + agentSessionCreatedOnVm: false, + agentStartedOnVm: false, + acpAssignedOnVm: false, + acpRunningOnVm: false, + retryCount: 0, + createdAt: Date.now(), + lastStepAt: Date.now(), + nodeAgentReadyStartedAt: null, + workspaceReadyStartedAt: null, + completed: false, + failureReason: null, + ...overrides, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any; +} + +function makeCtx(storage: Storage = new Map()) { + return { + storage: { + get: vi.fn(async (k: string) => storage.get(k)), + put: vi.fn(async (k: string, v: unknown) => { + storage.set(k, v); + }), + }, + _storage: storage, + }; +} + +function makeRc(ctx: ReturnType, advanced: string[]) { + return { + env: { + BASE_DOMAIN: 'sammy.party', + KV: { put: vi.fn(async () => {}) }, + DATABASE: { + prepare: 
vi.fn(() => ({ + bind: vi.fn(() => ({ run: vi.fn(async () => {}) })), + })), + }, + TRIAL_ANONYMOUS_USER_ID: 'u_anon_trial', + } as unknown as Parameters[1]['env'], + ctx: ctx as unknown as Parameters[1]['ctx'], + advanceToStep: vi.fn(async (state, step: string) => { + advanced.push(step); + state.currentStep = step; + state.lastStepAt = Date.now(); + await ctx.storage.put('state', state); + }), + getAgentReadyTimeoutMs: () => 60_000, + getWorkspaceReadyTimeoutMs: () => 180_000, + getWorkspaceReadyPollIntervalMs: () => 5_000, + getNodeReadyTimeoutMs: () => 180_000, + getHeartbeatSkewMs: () => 30_000, + } as unknown as Parameters[1]; +} + +describe('handleDiscoveryAgentStart — VM agent boot', () => { + beforeEach(() => { + vi.clearAllMocks(); + startDiscoveryAgentMock.mockResolvedValue({ + chatSessionId: 'cs_new', + acpSessionId: 'acp_new', + agentType: 'opencode', + model: '@cf/meta/llama-4-scout-17b-16e-instruct', + provider: 'workers-ai' as const, + promptVersion: 'v1', + }); + }); + + it('calls createAgentSessionOnNode with the correct payload after startDiscoveryAgent', async () => { + const ctx = makeCtx(); + const rc = makeRc(ctx, []); + const state = makeState(); + + await handleDiscoveryAgentStart(state, rc); + + expect(startDiscoveryAgentMock).toHaveBeenCalledTimes(1); + expect(createAgentSessionOnNodeMock).toHaveBeenCalledTimes(1); + const [nodeId, workspaceId, acpSessionId, label, , userId, chatSessionId, projectId] = + createAgentSessionOnNodeMock.mock.calls[0]; + expect(nodeId).toBe('node_X'); + expect(workspaceId).toBe('ws_X'); + expect(acpSessionId).toBe('acp_new'); + expect(label).toContain('octocat/Hello-World'); + expect(userId).toBe('u_anon_trial'); + expect(chatSessionId).toBe('cs_new'); + expect(projectId).toBe('proj_X'); + + expect(state.agentSessionCreatedOnVm).toBe(true); + }); + + it('mints + stores an MCP token keyed on trialId (synthetic taskId)', async () => { + const ctx = makeCtx(); + const rc = makeRc(ctx, []); + const state = 
makeState(); + + await handleDiscoveryAgentStart(state, rc); + + expect(generateMcpTokenMock).toHaveBeenCalledTimes(1); + expect(storeMcpTokenMock).toHaveBeenCalledTimes(1); + const [, token, data] = storeMcpTokenMock.mock.calls[0]; + expect(token).toBe('mcp_tok_fixture_abc123'); + // Trial synthetic taskId = trialId. + expect((data as { taskId: string }).taskId).toBe(state.trialId); + expect((data as { projectId: string }).projectId).toBe('proj_X'); + expect((data as { workspaceId: string }).workspaceId).toBe('ws_X'); + + expect(state.mcpToken).toBe('mcp_tok_fixture_abc123'); + }); + + it('calls startAgentSessionOnNode with the discovery prompt + MCP server URL', async () => { + const ctx = makeCtx(); + const rc = makeRc(ctx, []); + const state = makeState(); + + await handleDiscoveryAgentStart(state, rc); + + expect(startAgentSessionOnNodeMock).toHaveBeenCalledTimes(1); + const [nodeId, workspaceId, acpSessionId, agentType, initialPrompt, , userId, mcpServer] = + startAgentSessionOnNodeMock.mock.calls[0]; + expect(nodeId).toBe('node_X'); + expect(workspaceId).toBe('ws_X'); + expect(acpSessionId).toBe('acp_new'); + expect(agentType).toBe('opencode'); + expect(typeof initialPrompt).toBe('string'); + expect(initialPrompt as string).toContain('octocat/Hello-World'); + expect(userId).toBe('u_anon_trial'); + expect(mcpServer).toEqual({ + url: 'https://api.sammy.party/mcp', + token: 'mcp_tok_fixture_abc123', + }); + + expect(state.agentStartedOnVm).toBe(true); + }); + + it('drives the ACP session pending → assigned → running → advances to running', async () => { + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState(); + + await handleDiscoveryAgentStart(state, rc); + + // Two transitions in order: assigned, then running. Both from the + // orchestrator (actorType: system). 
+ expect(transitionAcpSessionMock).toHaveBeenCalledTimes(2); + expect(transitionAcpSessionMock.mock.calls[0][3]).toBe('assigned'); + expect(transitionAcpSessionMock.mock.calls[1][3]).toBe('running'); + expect(transitionAcpSessionMock.mock.calls[0][4]).toMatchObject({ actorType: 'system' }); + + expect(state.acpAssignedOnVm).toBe(true); + expect(state.acpRunningOnVm).toBe(true); + + // Finally, the step advances the orchestrator FSM to the `running` step + // (terminal) so no further alarms fire. + expect(advanced).toEqual(['running']); + }); + + it('is idempotent across crash-and-retry — re-entering with flags set does not re-call VM', async () => { + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + // Simulate a crash after every idempotency flag has already been set. + const state = makeState({ + chatSessionId: 'cs_persist', + acpSessionId: 'acp_persist', + mcpToken: 'mcp_tok_persist', + agentSessionCreatedOnVm: true, + agentStartedOnVm: true, + acpAssignedOnVm: true, + acpRunningOnVm: true, + }); + + await handleDiscoveryAgentStart(state, rc); + + // None of the side-effectful cross-boundary calls should fire again. + expect(startDiscoveryAgentMock).not.toHaveBeenCalled(); + expect(createAgentSessionOnNodeMock).not.toHaveBeenCalled(); + expect(generateMcpTokenMock).not.toHaveBeenCalled(); + expect(storeMcpTokenMock).not.toHaveBeenCalled(); + expect(startAgentSessionOnNodeMock).not.toHaveBeenCalled(); + expect(transitionAcpSessionMock).not.toHaveBeenCalled(); + + // But the state machine MUST still advance to `running` so the alarm loop + // doesn't re-fire the step forever. 
+ expect(advanced).toEqual(['running']); + }); + + it('re-entering after a partial crash (mcpToken set but subprocess not started) resumes from step 4', async () => { + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState({ + chatSessionId: 'cs_partial', + acpSessionId: 'acp_partial', + agentSessionCreatedOnVm: true, + mcpToken: 'mcp_tok_partial', + agentStartedOnVm: false, // crashed before step 4 + acpAssignedOnVm: false, + acpRunningOnVm: false, + }); + + await handleDiscoveryAgentStart(state, rc); + + // Steps 1–3 must NOT fire again. + expect(startDiscoveryAgentMock).not.toHaveBeenCalled(); + expect(createAgentSessionOnNodeMock).not.toHaveBeenCalled(); + expect(generateMcpTokenMock).not.toHaveBeenCalled(); + + // Steps 4–5 MUST fire to complete the flow. + expect(startAgentSessionOnNodeMock).toHaveBeenCalledTimes(1); + expect(transitionAcpSessionMock).toHaveBeenCalledTimes(2); + // MCP token passed to startAgentSessionOnNode is the persisted one, not a fresh mint. 
+ const startArgs = startAgentSessionOnNodeMock.mock.calls[0]; + expect(startArgs[7]).toMatchObject({ token: 'mcp_tok_partial' }); + + expect(advanced).toEqual(['running']); + }); +}); + +describe('handleProjectCreation — default branch detection', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('uses GitHub API default_branch (master) when the probe succeeds', async () => { + const fetchSpy = vi.spyOn(globalThis, 'fetch').mockResolvedValue( + new Response(JSON.stringify({ default_branch: 'master' }), { + status: 200, + headers: { 'content-type': 'application/json' }, + }), + ); + + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState({ + currentStep: 'project_creation', + projectId: null, + workspaceId: null, + nodeId: null, + defaultBranch: null, + }); + + await handleProjectCreation(state, rc); + + expect(state.defaultBranch).toBe('master'); + expect(fetchSpy).toHaveBeenCalledTimes(1); + // Request targets the canonical GitHub public repos endpoint. + const req = fetchSpy.mock.calls[0][0]; + const url = typeof req === 'string' ? 
req : (req as Request | URL).toString(); + expect(url).toContain('api.github.com/repos/octocat/Hello-World'); + + fetchSpy.mockRestore(); + }); + + it('falls back to "main" when the probe fails (network error)', async () => { + const fetchSpy = vi + .spyOn(globalThis, 'fetch') + .mockRejectedValue(new Error('network down')); + + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState({ + currentStep: 'project_creation', + projectId: null, + workspaceId: null, + nodeId: null, + defaultBranch: null, + }); + + await handleProjectCreation(state, rc); + + expect(state.defaultBranch).toBe('main'); + fetchSpy.mockRestore(); + }); + + it('falls back to "main" when GitHub returns 404', async () => { + const fetchSpy = vi.spyOn(globalThis, 'fetch').mockResolvedValue( + new Response('Not Found', { status: 404 }), + ); + + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState({ + currentStep: 'project_creation', + projectId: null, + workspaceId: null, + nodeId: null, + defaultBranch: null, + }); + + await handleProjectCreation(state, rc); + + expect(state.defaultBranch).toBe('main'); + fetchSpy.mockRestore(); + }); + + it('skips the probe when state.defaultBranch is already set (idempotent re-entry)', async () => { + const fetchSpy = vi.spyOn(globalThis, 'fetch'); + + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState({ + currentStep: 'project_creation', + projectId: null, + workspaceId: null, + nodeId: null, + defaultBranch: 'trunk', + }); + + await handleProjectCreation(state, rc); + + expect(state.defaultBranch).toBe('trunk'); + expect(fetchSpy).not.toHaveBeenCalled(); + fetchSpy.mockRestore(); + }); +}); diff --git a/apps/api/tests/unit/durable-objects/trial-orchestrator-steps.test.ts b/apps/api/tests/unit/durable-objects/trial-orchestrator-steps.test.ts new file mode 100644 index 
000000000..2f7a60a21 --- /dev/null +++ b/apps/api/tests/unit/durable-objects/trial-orchestrator-steps.test.ts @@ -0,0 +1,239 @@ +/** + * Unit tests for TrialOrchestrator step handlers. + * + * Covers the highest-value invariants that do not depend on D1/DO plumbing: + * - `handleRunning` marks state completed (terminal-for-orchestrator) + * - `handleDiscoveryAgentStart` throws a permanent error when the required + * projectId / workspaceId are missing (invariant: we never start an agent + * without its target workspace). + * - `handleDiscoveryAgentStart` idempotency: already-linked session skips + * the `startDiscoveryAgent` call and advances straight to `running`. + * + * Broader per-handler coverage (project_creation D1 inserts, node selection + * branching, workspace readiness polling, etc.) is tracked separately in + * tasks/backlog/2026-04-19-trial-orchestrator-step-handler-coverage.md — those + * paths require mocking drizzle + node provisioning + project-data services, + * which is out of scope for this PR. + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +// Stub the trial-runner so handleDiscoveryAgentStart doesn't reach the real +// worker-side session bootstrap. 
+const { startDiscoveryAgentMock, emitTrialEventMock } = vi.hoisted(() => ({ + startDiscoveryAgentMock: vi.fn(), + emitTrialEventMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/trial/trial-runner', () => ({ + emitTrialEvent: emitTrialEventMock, + emitTrialEventForProject: vi.fn(async () => {}), + startDiscoveryAgent: startDiscoveryAgentMock, + resolveTrialRunnerConfig: vi.fn(() => ({ + mode: 'staging' as const, + agentType: 'opencode', + model: '@cf/meta/llama-4-scout-17b-16e-instruct', + provider: 'workers-ai' as const, + })), +})); + +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrial: vi.fn(async () => null), + readTrialByProject: vi.fn(async () => null), + writeTrial: vi.fn(async () => {}), +})); + +vi.mock('../../../src/services/project-data', () => ({ + linkSessionToWorkspace: vi.fn(async () => {}), + transitionAcpSession: vi.fn(async () => {}), +})); + +vi.mock('../../../src/services/node-agent', () => ({ + createAgentSessionOnNode: vi.fn(async () => {}), + startAgentSessionOnNode: vi.fn(async () => {}), + createWorkspaceOnNode: vi.fn(async () => {}), +})); + +vi.mock('../../../src/services/mcp-token', () => ({ + generateMcpToken: vi.fn(() => 'mcp_tok_idempotent'), + storeMcpToken: vi.fn(async () => {}), +})); + +// Mock services/nodes so handleNodeProvisioning doesn't reach real provider code. +const { createNodeRecordMock, provisionNodeMock } = vi.hoisted(() => ({ + createNodeRecordMock: vi.fn(), + provisionNodeMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/nodes', () => ({ + createNodeRecord: createNodeRecordMock, + provisionNode: provisionNodeMock, +})); + +// getRuntimeLimits returns a small fixture — handleNodeProvisioning only reads +// `nodeHeartbeatStaleSeconds` for the createNodeRecord call. 
+vi.mock('../../../src/services/limits', () => ({ + getRuntimeLimits: vi.fn(() => ({ nodeHeartbeatStaleSeconds: 120 })), +})); + +const { handleRunning, handleDiscoveryAgentStart, handleNodeProvisioning } = await import( + '../../../src/durable-objects/trial-orchestrator/steps' +); + +type Storage = Map; + +function makeState(overrides: Record = {}) { + return { + version: 1, + trialId: 'trial_steps_test', + repoUrl: '', + repoOwner: 'alice', + repoName: 'repo', + currentStep: 'discovery_agent_start', + projectId: null, + nodeId: null, + autoProvisionedNode: false, + workspaceId: null, + chatSessionId: null, + acpSessionId: null, + retryCount: 0, + createdAt: Date.now(), + lastStepAt: Date.now(), + nodeAgentReadyStartedAt: null, + workspaceReadyStartedAt: null, + completed: false, + failureReason: null, + ...overrides, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } as any; +} + +function makeCtx(storage: Storage = new Map()) { + return { + storage: { + get: vi.fn(async (k: string) => storage.get(k)), + put: vi.fn(async (k: string, v: unknown) => { + storage.set(k, v); + }), + }, + _storage: storage, + }; +} + +function makeRc(ctx: ReturnType, advanced: string[]) { + return { + env: { + DATABASE: { + prepare: vi.fn(() => ({ + bind: vi.fn(() => ({ run: vi.fn(async () => {}) })), + })), + }, + } as unknown as Parameters[1]['env'], + ctx: ctx as unknown as Parameters[1]['ctx'], + advanceToStep: vi.fn(async (state, step: string) => { + advanced.push(step); + state.currentStep = step; + state.lastStepAt = Date.now(); + await ctx.storage.put('state', state); + }), + getAgentReadyTimeoutMs: () => 60_000, + getWorkspaceReadyTimeoutMs: () => 180_000, + getWorkspaceReadyPollIntervalMs: () => 5_000, + getNodeReadyTimeoutMs: () => 180_000, + getHeartbeatSkewMs: () => 30_000, + } as unknown as Parameters[1]; +} + +describe('handleRunning', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('marks state.completed = true and persists', async () 
=> { + const ctx = makeCtx(); + const rc = makeRc(ctx, []); + const state = makeState({ currentStep: 'running' }); + await handleRunning(state, rc); + expect(state.completed).toBe(true); + expect(ctx.storage.put).toHaveBeenCalledWith('state', state); + }); +}); + +describe('handleNodeProvisioning', () => { + beforeEach(() => { + vi.clearAllMocks(); + createNodeRecordMock.mockResolvedValue({ id: 'node_new_123' }); + provisionNodeMock.mockResolvedValue(undefined); + }); + + // Regression for the async-IP provider bug: provisionNode() returns while + // the node is still in 'creating' status for Scaleway/GCP (VM boots, IP + // arrives on first heartbeat). The step MUST advance to `node_agent_ready` + // unconditionally — the heartbeat polling in that step is what waits for + // the VM to come up. Synchronously requiring status='running' here would + // force every async-IP trial through the retry/backoff cycle until the + // heartbeat landed, wasting the retry budget and risking permanent failure. 
+ it('advances to node_agent_ready even when provisionNode leaves status=creating', async () => { + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState({ + currentStep: 'node_provisioning', + nodeId: null, + autoProvisionedNode: false, + }); + + await handleNodeProvisioning(state, rc); + + expect(createNodeRecordMock).toHaveBeenCalledTimes(1); + expect(provisionNodeMock).toHaveBeenCalledWith('node_new_123', expect.anything()); + expect(state.nodeId).toBe('node_new_123'); + expect(state.autoProvisionedNode).toBe(true); + expect(advanced).toEqual(['node_agent_ready']); + }); +}); + +describe('handleDiscoveryAgentStart', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('throws a permanent error when projectId, workspaceId, or nodeId is missing', async () => { + const ctx = makeCtx(); + const rc = makeRc(ctx, []); + const state = makeState({ projectId: null, workspaceId: null, nodeId: null }); + + let caught: Error & { permanent?: boolean } = new Error('never'); + try { + await handleDiscoveryAgentStart(state, rc); + } catch (err) { + caught = err as Error & { permanent?: boolean }; + } + expect(caught.message).toMatch(/projectId, workspaceId, and nodeId/); + expect(caught.permanent).toBe(true); + // startDiscoveryAgent must NOT have been called. 
+ expect(startDiscoveryAgentMock).not.toHaveBeenCalled(); + }); + + it('is idempotent: already-booted session skips all VM calls and advances to running', async () => { + const ctx = makeCtx(); + const advanced: string[] = []; + const rc = makeRc(ctx, advanced); + const state = makeState({ + projectId: 'proj_X', + workspaceId: 'ws_X', + nodeId: 'node_X', + chatSessionId: 'cs_X', + acpSessionId: 'acp_X', + mcpToken: 'mcp_tok_X', + agentSessionCreatedOnVm: true, + agentStartedOnVm: true, + acpAssignedOnVm: true, + acpRunningOnVm: true, + }); + await handleDiscoveryAgentStart(state, rc); + expect(startDiscoveryAgentMock).not.toHaveBeenCalled(); + expect(advanced).toEqual(['running']); + }); +}); diff --git a/apps/api/tests/unit/durable-objects/trial-orchestrator.test.ts b/apps/api/tests/unit/durable-objects/trial-orchestrator.test.ts new file mode 100644 index 000000000..767c6e897 --- /dev/null +++ b/apps/api/tests/unit/durable-objects/trial-orchestrator.test.ts @@ -0,0 +1,546 @@ +/** + * Unit tests for TrialOrchestrator DO. + * + * Covers: + * - `start()` is idempotent: a second call with the same input no-ops and + * does not re-schedule the alarm. + * - `start()` persists initial state with currentStep='project_creation'. + * - `alarm()` on a completed state is a no-op (terminal guard). + * - `alarm()` on overall-timeout emits trial.error and marks completed. + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +// Base DO class shim — the real Cloudflare DurableObject base is only +// available at runtime in Workers; the shim gives us a constructible class. 
+vi.mock('cloudflare:workers', () => ({ + DurableObject: class { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ctx: any; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + env: any; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + constructor(ctx: any, env: any) { + this.ctx = ctx; + this.env = env; + } + }, +})); + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +// Stub the trial-runner event emitter to observe failTrial → trial.error. +const { emitTrialEventMock } = vi.hoisted(() => ({ + emitTrialEventMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/trial/trial-runner', () => ({ + emitTrialEvent: emitTrialEventMock, + emitTrialEventForProject: vi.fn(async () => {}), +})); + +// Stub trial-store so failTrial's bookkeeping path doesn't blow up on KV. +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrial: vi.fn(async () => null), + readTrialByProject: vi.fn(async () => null), + writeTrial: vi.fn(async () => {}), +})); + +// Stub mcp-token service so failTrial's revocation path is observable. +const { revokeMcpTokenMock } = vi.hoisted(() => ({ + revokeMcpTokenMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/mcp-token', () => ({ + revokeMcpToken: revokeMcpTokenMock, +})); + +// Stub every step handler so alarm() dispatch can be controlled per-test. +// Individual tests override these via `stepMocks..mockImplementationOnce`. 
+const { stepMocks } = vi.hoisted(() => ({ + stepMocks: { + handleProjectCreation: vi.fn(async () => {}), + handleNodeSelection: vi.fn(async () => {}), + handleNodeProvisioning: vi.fn(async () => {}), + handleNodeAgentReady: vi.fn(async () => {}), + handleWorkspaceCreation: vi.fn(async () => {}), + handleWorkspaceReady: vi.fn(async () => {}), + handleDiscoveryAgentStart: vi.fn(async () => {}), + handleRunning: vi.fn(async () => {}), + }, +})); +vi.mock('../../../src/durable-objects/trial-orchestrator/steps', () => stepMocks); + +const { TrialOrchestrator } = await import( + '../../../src/durable-objects/trial-orchestrator' +); + +type Storage = Map; + +function makeCtx(storage: Storage = new Map()) { + let alarmTime: number | null = null; + return { + storage: { + get: vi.fn(async (k: string) => storage.get(k)), + put: vi.fn(async (k: string, v: unknown) => { + storage.set(k, v); + }), + delete: vi.fn(async (k: string) => storage.delete(k)), + setAlarm: vi.fn(async (t: number) => { + alarmTime = t; + }), + getAlarm: vi.fn(async () => alarmTime), + deleteAlarm: vi.fn(async () => { + alarmTime = null; + }), + }, + _alarmTime: () => alarmTime, + _storage: storage, + }; +} + +function makeEnv() { + return { + DATABASE: {}, + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'stub-id'), + get: vi.fn(() => ({ + fetch: vi.fn(async () => new Response('ok')), + })), + }, + } as unknown as Parameters[1]; +} + +describe('TrialOrchestrator.start()', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('persists initial state and schedules an alarm on first call', async () => { + const ctx = makeCtx(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.start({ + trialId: 'trial_abc', + repoUrl: 'https://github.com/alice/repo', + repoOwner: 'alice', + repoName: 'repo', + }); + const stored = ctx._storage.get('state') as { currentStep: string; trialId: string }; + 
expect(stored).toBeTruthy(); + expect(stored.currentStep).toBe('project_creation'); + expect(stored.trialId).toBe('trial_abc'); + expect(ctx._alarmTime()).not.toBeNull(); + }); + + it('emits trial.started event so the SSE stream signals immediately', async () => { + const ctx = makeCtx(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.start({ + trialId: 'trial_started_evt', + repoUrl: 'https://github.com/alice/repo', + repoOwner: 'alice', + repoName: 'repo', + }); + expect(emitTrialEventMock).toHaveBeenCalledTimes(1); + const [env, trialId, event] = emitTrialEventMock.mock.calls[0]; + expect(env).toBeTruthy(); + expect(trialId).toBe('trial_started_evt'); + expect(event).toMatchObject({ + type: 'trial.started', + trialId: 'trial_started_evt', + repoUrl: 'https://github.com/alice/repo', + }); + expect(typeof (event as { startedAt: number }).startedAt).toBe('number'); + }); + + it('is idempotent — second start() call is a no-op (no re-schedule)', async () => { + const ctx = makeCtx(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.start({ + trialId: 'trial_abc', + repoUrl: 'https://github.com/alice/repo', + repoOwner: 'alice', + repoName: 'repo', + }); + const firstPutCount = ctx.storage.put.mock.calls.length; + const firstAlarmCount = ctx.storage.setAlarm.mock.calls.length; + + // Second call — must not re-persist or re-alarm. 
+ await orch.start({ + trialId: 'trial_abc', + repoUrl: 'https://github.com/alice/repo', + repoOwner: 'alice', + repoName: 'repo', + }); + + expect(ctx.storage.put.mock.calls.length).toBe(firstPutCount); + expect(ctx.storage.setAlarm.mock.calls.length).toBe(firstAlarmCount); + }); +}); + +describe('TrialOrchestrator.alarm()', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('is a no-op when state is completed (terminal guard)', async () => { + const storage: Storage = new Map(); + storage.set('state', { + version: 1, + trialId: 'trial_done', + currentStep: 'failed', + completed: true, + createdAt: Date.now(), + lastStepAt: Date.now(), + }); + const ctx = makeCtx(storage); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.alarm(); + // Must not have emitted a new event. + expect(emitTrialEventMock).not.toHaveBeenCalled(); + }); + + it('capability: start() → overall-timeout alarm() emits trial.error through event bus', async () => { + // End-to-end capability: exercises the full DO state machine from + // `start()` through a terminal `alarm()`, asserting that: + // 1. start() persists initial state + schedules alarm + // 2. start() emits trial.started via the event bus + // 3. alarm() with an expired overall budget transitions to `failed` + // 4. alarm() emits trial.error via the same event bus + // + // This is the cross-boundary capability test required by rule 10: + // it verifies that the orchestrator DO (via the mocked `emitTrialEvent` + // seam) actually pushes events into the downstream TrialEventBus that + // the SSE route reads from. Mocking `emitTrialEvent` is acceptable + // here because the companion route tests (trial-events.test.ts) + // verify the bus → SSE side of the same seam. 
+ const ctx = makeCtx(); + const env = makeEnv(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, env); + + // 1. start() — persists state + alarm + emits trial.started. + await orch.start({ + trialId: 'trial_cap_e2e', + repoUrl: 'https://github.com/alice/repo', + repoOwner: 'alice', + repoName: 'repo', + }); + expect(emitTrialEventMock).toHaveBeenCalledTimes(1); + expect(emitTrialEventMock.mock.calls[0][2]).toMatchObject({ + type: 'trial.started', + trialId: 'trial_cap_e2e', + }); + expect(ctx._alarmTime()).not.toBeNull(); + + // 2. Simulate the overall-timeout budget expiring by rewinding createdAt. + const state = ctx._storage.get('state') as { createdAt: number; lastStepAt: number }; + state.createdAt = Date.now() - 24 * 60 * 60 * 1000; + state.lastStepAt = state.createdAt; + + // 3. alarm() — detects timeout, fails the trial, emits trial.error. + await orch.alarm(); + + const finalState = ctx._storage.get('state') as { + currentStep: string; + completed: boolean; + failureReason: string | null; + }; + expect(finalState.currentStep).toBe('failed'); + expect(finalState.completed).toBe(true); + expect(finalState.failureReason).toMatch(/timed out/); + + // 4. trial.error was dispatched through the event bus seam. 
+ const errorCall = emitTrialEventMock.mock.calls.find( + (c) => (c[2] as { type: string }).type === 'trial.error' + ); + expect(errorCall).toBeTruthy(); + expect(errorCall?.[1]).toBe('trial_cap_e2e'); + }); + + it('fails the trial with timeout error when overall budget exceeded', async () => { + const storage: Storage = new Map(); + const farPast = Date.now() - 24 * 60 * 60 * 1000; // 24h ago + storage.set('state', { + version: 1, + trialId: 'trial_slow', + repoUrl: '', + repoOwner: '', + repoName: '', + currentStep: 'workspace_ready', + projectId: 'proj_1', + nodeId: null, + autoProvisionedNode: false, + workspaceId: null, + chatSessionId: null, + acpSessionId: null, + retryCount: 0, + createdAt: farPast, + lastStepAt: farPast, + nodeAgentReadyStartedAt: null, + workspaceReadyStartedAt: null, + completed: false, + failureReason: null, + }); + const ctx = makeCtx(storage); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.alarm(); + + const updated = ctx._storage.get('state') as { + currentStep: string; + completed: boolean; + failureReason: string | null; + }; + expect(updated.currentStep).toBe('failed'); + expect(updated.completed).toBe(true); + expect(updated.failureReason).toMatch(/timed out/); + }); +}); + +// --------------------------------------------------------------------------- +// alarm() step-error retry/backoff branches +// +// These exercise the catch block at index.ts:190–220, which is the only path +// that exercises isTransientError(), retry counting, backoff scheduling, and +// the permanent-error fallback via failTrial. 
+// --------------------------------------------------------------------------- + +function makeRunningState(overrides: Partial> = {}) { + const now = Date.now(); + return { + version: 1, + trialId: 'trial_retry', + repoUrl: '', + repoOwner: '', + repoName: '', + currentStep: 'project_creation', + projectId: null, + nodeId: null, + autoProvisionedNode: false, + workspaceId: null, + chatSessionId: null, + acpSessionId: null, + retryCount: 0, + createdAt: now, + lastStepAt: now, + nodeAgentReadyStartedAt: null, + workspaceReadyStartedAt: null, + completed: false, + failureReason: null, + ...overrides, + }; +} + +describe('TrialOrchestrator.alarm() — step error retry/backoff', () => { + beforeEach(() => { + vi.clearAllMocks(); + // Reset each step mock to a successful no-op. + for (const fn of Object.values(stepMocks)) { + fn.mockReset(); + fn.mockImplementation(async () => {}); + } + }); + + it('transient error with retries remaining increments retryCount and schedules backoff (no failTrial)', async () => { + const storage: Storage = new Map(); + storage.set('state', makeRunningState()); + stepMocks.handleProjectCreation.mockRejectedValueOnce( + new Error('fetch failed — network timeout') + ); + const ctx = makeCtx(storage); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.alarm(); + + const updated = ctx._storage.get('state') as { + currentStep: string; + completed: boolean; + retryCount: number; + }; + // State advances retry counter but does NOT complete. + expect(updated.currentStep).toBe('project_creation'); + expect(updated.completed).toBe(false); + expect(updated.retryCount).toBe(1); + // A backoff alarm is scheduled. + expect(ctx._alarmTime()).not.toBeNull(); + // No trial.error fired — retry path is internal. 
+    const errorEmit = emitTrialEventMock.mock.calls.find(
+      (c) => (c[2] as { type: string }).type === 'trial.error'
+    );
+    expect(errorEmit).toBeUndefined();
+  });
+
+  it('permanent error fails the trial immediately regardless of retry budget', async () => {
+    const storage: Storage = new Map();
+    storage.set('state', makeRunningState({ retryCount: 0 }));
+    // Messages containing "invalid" classify as permanent per isTransientError.
+    stepMocks.handleProjectCreation.mockRejectedValueOnce(
+      new Error('invalid configuration: missing sentinel')
+    );
+    const ctx = makeCtx(storage);
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const orch = new TrialOrchestrator(ctx as any, makeEnv());
+    await orch.alarm();
+
+    const updated = ctx._storage.get('state') as {
+      currentStep: string;
+      completed: boolean;
+      failureReason: string | null;
+    };
+    expect(updated.currentStep).toBe('failed');
+    expect(updated.completed).toBe(true);
+    expect(updated.failureReason).toMatch(/invalid/);
+
+    const errorEmit = emitTrialEventMock.mock.calls.find(
+      (c) => (c[2] as { type: string }).type === 'trial.error'
+    );
+    expect(errorEmit).toBeTruthy();
+  });
+
+  it('transient error with retries exhausted promotes to permanent failure', async () => {
+    const storage: Storage = new Map();
+    // Budget exhausted: retryCount (99) is already far beyond any retry budget.
+ storage.set('state', makeRunningState({ retryCount: 99 })); + stepMocks.handleProjectCreation.mockRejectedValueOnce( + new Error('fetch failed — upstream 503') + ); + const ctx = makeCtx(storage); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.alarm(); + + const updated = ctx._storage.get('state') as { + currentStep: string; + completed: boolean; + }; + expect(updated.currentStep).toBe('failed'); + expect(updated.completed).toBe(true); + }); + + it('returns early when state has not been initialized (null-state guard)', async () => { + // No `state` key in storage — alarm fires before start() has run. + const ctx = makeCtx(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.alarm(); + // No events emitted, no step handler called, no alarm rescheduled. + expect(emitTrialEventMock).not.toHaveBeenCalled(); + expect(stepMocks.handleProjectCreation).not.toHaveBeenCalled(); + }); +}); + +// --------------------------------------------------------------------------- +// Security boundary tests — MCP token lifecycle + redaction +// Covers security-auditor HIGH findings from PR #760 follow-up review. +// --------------------------------------------------------------------------- + +describe('TrialOrchestrator — MCP token security boundary', () => { + beforeEach(() => { + vi.clearAllMocks(); + for (const fn of Object.values(stepMocks)) { + fn.mockReset(); + fn.mockImplementation(async () => {}); + } + revokeMcpTokenMock.mockReset(); + revokeMcpTokenMock.mockImplementation(async () => {}); + }); + + it('failTrial revokes state.mcpToken and clears it from persisted state', async () => { + const storage: Storage = new Map(); + storage.set( + 'state', + makeRunningState({ + mcpToken: 'tok_live_secret_xyz', + // Force permanent failure path so failTrial runs synchronously. 
+ retryCount: 99, + }), + ); + stepMocks.handleProjectCreation.mockRejectedValueOnce( + new Error('invalid configuration — forces failTrial'), + ); + const ctx = makeCtx(storage); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.alarm(); + + expect(revokeMcpTokenMock).toHaveBeenCalledTimes(1); + expect(revokeMcpTokenMock.mock.calls[0][1]).toBe('tok_live_secret_xyz'); + + // Post-revocation, state.mcpToken must be cleared so a later read cannot + // leak the now-dead token through getStatus or any other surface. + const updated = ctx._storage.get('state') as { mcpToken: string | null }; + expect(updated.mcpToken).toBeNull(); + }); + + it('failTrial tolerates revokeMcpToken errors and still emits trial.error', async () => { + const storage: Storage = new Map(); + storage.set( + 'state', + makeRunningState({ + mcpToken: 'tok_live_flaky_kv', + retryCount: 99, + }), + ); + stepMocks.handleProjectCreation.mockRejectedValueOnce( + new Error('invalid — forces failTrial'), + ); + revokeMcpTokenMock.mockRejectedValueOnce(new Error('KV hiccup')); + + const ctx = makeCtx(storage); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + await orch.alarm(); + + // Failure emission must still happen even though revoke threw. 
+ const errorEmit = emitTrialEventMock.mock.calls.find( + (c) => (c[2] as { type: string }).type === 'trial.error' + ); + expect(errorEmit).toBeTruthy(); + const updated = ctx._storage.get('state') as { + currentStep: string; + completed: boolean; + }; + expect(updated.currentStep).toBe('failed'); + expect(updated.completed).toBe(true); + }); + + it('getStatus() redacts mcpToken so debug surfaces cannot leak the bearer credential', async () => { + const storage: Storage = new Map(); + storage.set( + 'state', + makeRunningState({ + mcpToken: 'tok_should_never_leak', + currentStep: 'running', + completed: true, + }), + ); + const ctx = makeCtx(storage); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + const status = await orch.getStatus(); + expect(status).not.toBeNull(); + expect(status!.mcpToken).toBe('[redacted]'); + // Other state fields should pass through so getStatus stays useful for + // debugging (currentStep, completed). + expect(status!.currentStep).toBe('running'); + expect(status!.completed).toBe(true); + + // Defence-in-depth: raw storage still has the real token (revocation is + // the caller's responsibility) — the redaction is a response-shaping + // guard, not a state mutation. 
+ const raw = ctx._storage.get('state') as { mcpToken: string | null }; + expect(raw.mcpToken).toBe('tok_should_never_leak'); + }); + + it('getStatus() returns null when state is uninitialized (no accidental leak)', async () => { + const ctx = makeCtx(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const orch = new TrialOrchestrator(ctx as any, makeEnv()); + const status = await orch.getStatus(); + expect(status).toBeNull(); + }); +}); diff --git a/apps/api/tests/unit/routes/ai-proxy.test.ts b/apps/api/tests/unit/routes/ai-proxy.test.ts index 3701051ec..835b21b0c 100644 --- a/apps/api/tests/unit/routes/ai-proxy.test.ts +++ b/apps/api/tests/unit/routes/ai-proxy.test.ts @@ -1,13 +1,11 @@ /** - * Unit tests for the AI proxy route (AI Gateway pass-through). + * Unit tests for the AI proxy route (AI Gateway pass-through + Anthropic translation). * - * Tests model ID resolution/normalization and allowlist parsing. - * Response format tests are no longer needed — the Gateway returns - * standard OpenAI format and we pass it through transparently. + * Tests model ID resolution/normalization, allowlist parsing, and Anthropic model detection. 
*/ import { describe, expect, it } from 'vitest'; -import { resolveModelId } from '../../../src/routes/ai-proxy'; +import { isAnthropicModel, resolveModelId } from '../../../src/routes/ai-proxy'; // ============================================================================= // Model Allowlist Parsing (extracted logic test) @@ -20,6 +18,8 @@ describe('model allowlist parsing', () => { raw.split(',').map((m) => m.trim()).filter(Boolean).map((m) => { let resolved = m; if (resolved.startsWith('workers-ai/')) resolved = resolved.slice('workers-ai/'.length); + // Anthropic models don't get @cf/ prefix + if (resolved.startsWith('claude-')) return resolved; if (!resolved.startsWith('@cf/') && !resolved.startsWith('@hf/')) resolved = `@cf/${resolved}`; return resolved; }), @@ -43,6 +43,12 @@ describe('model allowlist parsing', () => { const models = parseAndNormalizeModels('@cf/model-a,,@cf/model-b,'); expect(models.size).toBe(2); }); + + it('preserves Anthropic model IDs without @cf/ prefix', () => { + const models = parseAndNormalizeModels('claude-haiku-4-5-20251001,@cf/meta/llama-4-scout-17b-16e-instruct'); + expect(models.has('claude-haiku-4-5-20251001')).toBe(true); + expect(models.has('@cf/meta/llama-4-scout-17b-16e-instruct')).toBe(true); + }); }); // ============================================================================= @@ -50,31 +56,94 @@ describe('model allowlist parsing', () => { // ============================================================================= describe('resolveModelId', () => { - const mockEnv = { + /** Mock env with a KV stub that always returns null (no admin override). 
*/ + const mockKV = { get: async () => null } as unknown as KVNamespace; + + const mockEnvWorkersAI = { AI_PROXY_DEFAULT_MODEL: '@cf/meta/llama-4-scout-17b-16e-instruct', + KV: mockKV, } as Parameters[1]; - it('returns default when model is undefined', () => { - expect(resolveModelId(undefined, mockEnv)).toBe('@cf/meta/llama-4-scout-17b-16e-instruct'); + const mockEnvAnthropic = { + AI_PROXY_DEFAULT_MODEL: 'claude-haiku-4-5-20251001', + KV: mockKV, + } as Parameters[1]; + + it('returns default when model is undefined (Workers AI default)', async () => { + expect(await resolveModelId(undefined, mockEnvWorkersAI)).toBe('@cf/meta/llama-4-scout-17b-16e-instruct'); + }); + + it('returns default when model is undefined (Anthropic default)', async () => { + expect(await resolveModelId(undefined, mockEnvAnthropic)).toBe('claude-haiku-4-5-20251001'); }); - it('returns model as-is when @cf/ prefix present', () => { - expect(resolveModelId('@cf/qwen/qwen3-30b-a3b-fp8', mockEnv)) + it('returns model as-is when @cf/ prefix present', async () => { + expect(await resolveModelId('@cf/qwen/qwen3-30b-a3b-fp8', mockEnvWorkersAI)) .toBe('@cf/qwen/qwen3-30b-a3b-fp8'); }); - it('strips workers-ai/ prefix', () => { - expect(resolveModelId('workers-ai/@cf/qwen/qwen3-30b-a3b-fp8', mockEnv)) + it('strips workers-ai/ prefix', async () => { + expect(await resolveModelId('workers-ai/@cf/qwen/qwen3-30b-a3b-fp8', mockEnvWorkersAI)) .toBe('@cf/qwen/qwen3-30b-a3b-fp8'); }); - it('adds @cf/ prefix when missing (OpenCode strips it)', () => { - expect(resolveModelId('meta/llama-4-scout-17b-16e-instruct', mockEnv)) + it('adds @cf/ prefix when missing (OpenCode strips it)', async () => { + expect(await resolveModelId('meta/llama-4-scout-17b-16e-instruct', mockEnvWorkersAI)) .toBe('@cf/meta/llama-4-scout-17b-16e-instruct'); }); - it('preserves @hf/ prefix for HuggingFace models', () => { - expect(resolveModelId('@hf/some/model', mockEnv)) + it('preserves @hf/ prefix for HuggingFace models', async () => 
{ + expect(await resolveModelId('@hf/some/model', mockEnvWorkersAI)) .toBe('@hf/some/model'); }); + + it('preserves Anthropic model IDs without adding @cf/ prefix', async () => { + expect(await resolveModelId('claude-haiku-4-5-20251001', mockEnvWorkersAI)) + .toBe('claude-haiku-4-5-20251001'); + }); + + it('preserves full Anthropic model IDs with date suffix', async () => { + expect(await resolveModelId('claude-sonnet-4-5-20250514', mockEnvWorkersAI)) + .toBe('claude-sonnet-4-5-20250514'); + }); + + it('reads admin override from KV when no model specified', async () => { + const kvWithOverride = { + get: async () => JSON.stringify({ defaultModel: 'claude-haiku-4-5-20251001', updatedAt: '2026-04-20T00:00:00Z' }), + } as unknown as KVNamespace; + + const envWithKV = { ...mockEnvWorkersAI, KV: kvWithOverride }; + expect(await resolveModelId(undefined, envWithKV)).toBe('claude-haiku-4-5-20251001'); + }); + + it('explicit model overrides KV admin setting', async () => { + const kvWithOverride = { + get: async () => JSON.stringify({ defaultModel: 'claude-haiku-4-5-20251001', updatedAt: '2026-04-20T00:00:00Z' }), + } as unknown as KVNamespace; + + const envWithKV = { ...mockEnvWorkersAI, KV: kvWithOverride }; + expect(await resolveModelId('@cf/qwen/qwen3-30b-a3b-fp8', envWithKV)).toBe('@cf/qwen/qwen3-30b-a3b-fp8'); + }); +}); + +// ============================================================================= +// Anthropic Model Detection +// ============================================================================= + +describe('isAnthropicModel', () => { + it('identifies Claude models', () => { + expect(isAnthropicModel('claude-haiku-4-5-20251001')).toBe(true); + expect(isAnthropicModel('claude-sonnet-4-5-20250514')).toBe(true); + expect(isAnthropicModel('claude-opus-4-6')).toBe(true); + }); + + it('does not match Workers AI models', () => { + expect(isAnthropicModel('@cf/meta/llama-4-scout-17b-16e-instruct')).toBe(false); + 
expect(isAnthropicModel('@cf/qwen/qwen3-30b-a3b-fp8')).toBe(false); + }); + + it('does not match other providers', () => { + expect(isAnthropicModel('gpt-4o')).toBe(false); + expect(isAnthropicModel('gemini-pro')).toBe(false); + }); }); diff --git a/apps/api/tests/unit/routes/trial-claim.test.ts b/apps/api/tests/unit/routes/trial-claim.test.ts new file mode 100644 index 000000000..3be724647 --- /dev/null +++ b/apps/api/tests/unit/routes/trial-claim.test.ts @@ -0,0 +1,258 @@ +/** + * Unit tests for POST /api/trial/claim. + * + * Covers: + * - 401 when unauthenticated (handled via the auth middleware; tested by not + * installing the userId mock) + * - 500 when TRIAL_CLAIM_TOKEN_SECRET is unset + * - 400 when claim cookie missing + * - 400 when claim cookie signature fails / malformed / expired + * - 404 when trial record is absent + * - 400 when payload.projectId disagrees with stored record + * - 409 when trial already marked claimed in KV + * - 409 when D1 UPDATE affects zero rows (already claimed by someone else) + * - 200 happy path — D1 re-parented, KV markTrialClaimed called, + * Set-Cookie clears sam_trial_claim + */ +import { drizzle } from 'drizzle-orm/d1'; +import { Hono } from 'hono'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +// Auth — always return a user so the claim route runs its own logic. 
+vi.mock('../../../src/middleware/auth', () => ({ + requireAuth: () => vi.fn((_: unknown, next: () => Promise) => next()), + getUserId: () => 'user_claim_1', +})); + +// drizzle-orm/d1 — stub to control UPDATE outcome +vi.mock('drizzle-orm/d1'); + +// Trial-store — control the record returned +const { readTrialMock, markTrialClaimedMock } = vi.hoisted(() => ({ + readTrialMock: vi.fn(), + markTrialClaimedMock: vi.fn(), +})); +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrial: readTrialMock, + markTrialClaimed: markTrialClaimedMock, +})); + +// Silence logs +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +import { errors } from '../../../src/middleware/error'; +import { claimRoutes } from '../../../src/routes/trial/claim'; +import { signClaimToken } from '../../../src/services/trial/cookies'; + +const SECRET = 'test-secret-for-claim-route-32-bytes-long-minimum'; + +function makeApp(): Hono<{ Bindings: Env }> { + const app = new Hono<{ Bindings: Env }>(); + app.onError((err, c) => { + // Surface AppError as JSON (mirrors production onError) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const e = err as any; + if (typeof e.statusCode === 'number' && typeof e.toJSON === 'function') { + return c.json(e.toJSON(), e.statusCode); + } + return c.json({ error: 'INTERNAL', message: err.message }, 500); + }); + app.route('/api/trial', claimRoutes); + return app; +} + +function setDrizzleUpdateChanges(changes: number) { + const run = vi.fn().mockResolvedValue({ meta: { changes } }); + const chain = { + set: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + run, + }; + (drizzle as unknown as ReturnType).mockReturnValue({ + update: vi.fn().mockReturnValue(chain), + }); + return { run, chain }; +} + +function makeEnv(overrides: Partial = {}): Env { + return { + DATABASE: {} as D1Database, + TRIAL_CLAIM_TOKEN_SECRET: SECRET, + ...overrides, + } as unknown as Env; +} 
+ +async function postClaim( + app: Hono<{ Bindings: Env }>, + cookie: string | null, + env: Env = makeEnv() +): Promise { + const headers: Record = { 'content-type': 'application/json' }; + if (cookie) headers['cookie'] = `sam_trial_claim=${encodeURIComponent(cookie)}`; + return app.request( + '/api/trial/claim', + { method: 'POST', headers, body: '{}' }, + env + ); +} + +function futurePayload( + overrides: Partial<{ + trialId: string; + projectId: string; + issuedAt: number; + expiresAt: number; + }> = {} +) { + const now = Date.now(); + return { + trialId: 'trial_good', + projectId: 'proj_good', + issuedAt: now, + expiresAt: now + 3600_000, + ...overrides, + }; +} + +describe('POST /api/trial/claim', () => { + beforeEach(() => { + readTrialMock.mockReset(); + markTrialClaimedMock.mockReset(); + vi.mocked(drizzle).mockReset(); + }); + + it('returns 500 when TRIAL_CLAIM_TOKEN_SECRET is unset', async () => { + const app = makeApp(); + const env = makeEnv({ TRIAL_CLAIM_TOKEN_SECRET: undefined } as Partial); + const resp = await postClaim(app, 'anything', env); + expect(resp.status).toBe(500); + }); + + it('returns 400 when claim cookie is missing', async () => { + const app = makeApp(); + const resp = await postClaim(app, null); + expect(resp.status).toBe(400); + const body = (await resp.json()) as { message?: string }; + expect(body.message).toMatch(/claim cookie/i); + }); + + it('returns 400 when claim cookie signature is invalid', async () => { + const app = makeApp(); + const resp = await postClaim(app, 'not-a-real-token.sig'); + expect(resp.status).toBe(400); + }); + + it('returns 400 when claim cookie is expired', async () => { + const app = makeApp(); + const token = await signClaimToken( + { ...futurePayload(), expiresAt: Date.now() - 1000 }, + SECRET + ); + const resp = await postClaim(app, token); + expect(resp.status).toBe(400); + }); + + it('returns 404 when trial record does not exist', async () => { + const app = makeApp(); + const token = await 
signClaimToken(futurePayload(), SECRET); + readTrialMock.mockResolvedValueOnce(null); + const resp = await postClaim(app, token); + expect(resp.status).toBe(404); + }); + + it('returns 400 when cookie projectId disagrees with record projectId', async () => { + const app = makeApp(); + const token = await signClaimToken( + futurePayload({ projectId: 'proj_cookie' }), + SECRET + ); + readTrialMock.mockResolvedValueOnce({ + trialId: 'trial_good', + projectId: 'proj_record_different', + claimed: false, + }); + const resp = await postClaim(app, token); + expect(resp.status).toBe(400); + }); + + it('returns 409 when trial is already claimed in KV', async () => { + const app = makeApp(); + const token = await signClaimToken(futurePayload(), SECRET); + readTrialMock.mockResolvedValueOnce({ + trialId: 'trial_good', + projectId: 'proj_good', + claimed: true, + }); + const resp = await postClaim(app, token); + expect(resp.status).toBe(409); + }); + + it('returns 409 when D1 UPDATE affects zero rows (race)', async () => { + const app = makeApp(); + const token = await signClaimToken(futurePayload(), SECRET); + readTrialMock.mockResolvedValueOnce({ + trialId: 'trial_good', + projectId: 'proj_good', + claimed: false, + }); + setDrizzleUpdateChanges(0); + const resp = await postClaim(app, token); + expect(resp.status).toBe(409); + }); + + it('returns 200 and clears claim cookie on successful re-parent', async () => { + const app = makeApp(); + const token = await signClaimToken(futurePayload(), SECRET); + readTrialMock.mockResolvedValueOnce({ + trialId: 'trial_good', + projectId: 'proj_good', + claimed: false, + }); + markTrialClaimedMock.mockResolvedValueOnce({ + trialId: 'trial_good', + projectId: 'proj_good', + claimed: true, + }); + const { run, chain } = setDrizzleUpdateChanges(1); + + const resp = await postClaim(app, token); + expect(resp.status).toBe(200); + const body = (await resp.json()) as { projectId: string; claimedAt: number }; + 
expect(body.projectId).toBe('proj_good'); + expect(body.claimedAt).toBeGreaterThan(0); + + // Drizzle was called: update with set + where + run + expect(chain.set).toHaveBeenCalledWith( + expect.objectContaining({ userId: 'user_claim_1' }) + ); + expect(run).toHaveBeenCalledTimes(1); + expect(markTrialClaimedMock).toHaveBeenCalledWith(expect.anything(), 'trial_good'); + + // Set-Cookie should clear the claim cookie + const setCookie = resp.headers.get('Set-Cookie'); + expect(setCookie).toContain('sam_trial_claim=;'); + expect(setCookie).toContain('Max-Age=0'); + }); + + it('still returns 200 when markTrialClaimed fails (best-effort)', async () => { + const app = makeApp(); + const token = await signClaimToken(futurePayload(), SECRET); + readTrialMock.mockResolvedValueOnce({ + trialId: 'trial_good', + projectId: 'proj_good', + claimed: false, + }); + markTrialClaimedMock.mockRejectedValueOnce(new Error('KV outage')); + setDrizzleUpdateChanges(1); + + const resp = await postClaim(app, token); + expect(resp.status).toBe(200); + }); +}); + +// Keep reference so the import isn't flagged as unused. +errors.badRequest; diff --git a/apps/api/tests/unit/routes/trial-create.ts.test.ts b/apps/api/tests/unit/routes/trial-create.ts.test.ts new file mode 100644 index 000000000..a6aa1747f --- /dev/null +++ b/apps/api/tests/unit/routes/trial-create.ts.test.ts @@ -0,0 +1,578 @@ +/** + * Behavioral tests for POST /api/trial/create. 
+ * + * Verifies end-to-end integration across: + * - Valibot body validation (bad JSON → 400 invalid_url) + * - Kill-switch off → 503 trials_disabled + * - Missing TRIAL_CLAIM_TOKEN_SECRET → 503 trials_disabled + * - probeGithubRepo 404 → 404 repo_not_found + * - probeGithubRepo private → 403 repo_private + * - probeGithubRepo oversized → 413 repo_too_large + * - TrialCounter at cap → 429 cap_exceeded with waitlistResetsAt + * - TrialCounter RPC failure → 503 trials_disabled + * - Happy path → 201 with Set-Cookie headers and expected response shape + * - D1 insert failure → decrements counter slot (no burn) + * - Existing fingerprint cookie is reused for a second create + */ +import { Hono } from 'hono'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +const { isTrialsEnabledMock } = vi.hoisted(() => ({ + isTrialsEnabledMock: vi.fn(), +})); +vi.mock('../../../src/services/trial/kill-switch', () => ({ + isTrialsEnabled: isTrialsEnabledMock, +})); + +const { insertMock, valuesMock } = vi.hoisted(() => { + const valuesMock = vi.fn().mockResolvedValue(undefined); + const insertMock = vi.fn().mockReturnValue({ values: valuesMock }); + return { insertMock, valuesMock }; +}); + +vi.mock('drizzle-orm/d1', () => ({ + drizzle: vi.fn().mockReturnValue({ insert: insertMock }), +})); + +vi.mock('../../../src/db/schema', () => ({ + trials: { + id: 'id', + fingerprint: 'fingerprint', + repoUrl: 'repo_url', + monthKey: 'month_key', + status: 'status', + }, +})); + +const { createRoutes, probeGithubRepo } = await import( + '../../../src/routes/trial/create' +); + +const SECRET = 'test-secret-at-least-32-bytes-long-for-hmac-xx'; + +type Slot = { allowed: boolean; count: number }; + +function makeEnv(options: { + tryIncrementFn?: (monthKey: string, cap: number) => Promise; + decrementFn?: (monthKey: string) => 
Promise; + secret?: string | undefined; + cap?: string; + orchestratorStart?: (input: unknown) => Promise; +}): Env { + const orchestratorStartFn = + options.orchestratorStart ?? vi.fn(async () => {}); + const orchestratorStub = { start: orchestratorStartFn }; + return { + TRIAL_CLAIM_TOKEN_SECRET: 'secret' in options ? options.secret : SECRET, + TRIAL_MONTHLY_CAP: options.cap, + BASE_DOMAIN: 'example.com', + TRIAL_COUNTER: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => ({ + tryIncrement: + options.tryIncrementFn ?? + vi.fn(async () => ({ allowed: true, count: 1 })), + decrement: options.decrementFn ?? vi.fn(async () => 0), + })), + }, + TRIAL_ORCHESTRATOR: { + idFromName: vi.fn(() => 'orchestrator-do-id'), + get: vi.fn(() => orchestratorStub), + }, + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'event-bus-do-id'), + get: vi.fn(() => ({ + fetch: vi.fn(async () => new Response('ok')), + })), + }, + // Trial-store KV mirror — Track B readers (SSE events, claim) look trials + // up by trialId in KV. Minimal stub that accepts puts and returns null on + // reads, which is enough for the create-path tests. + KV: { + put: vi.fn(async () => undefined), + get: vi.fn(async () => null), + delete: vi.fn(async () => undefined), + }, + } as unknown as Env; +} + +function makeApp(env: Env, fetchFn?: typeof fetch) { + const app = new Hono<{ Bindings: Env }>(); + app.route('/api/trial', createRoutes); + // Inject a stub global.fetch used by probeGithubRepo inside the handler. + // (The route calls probeGithubRepo() with the default fetch.) 
+ if (fetchFn) vi.stubGlobal('fetch', fetchFn); + return { app, env }; +} + +async function callCreate( + app: Hono<{ Bindings: Env }>, + env: Env, + body: unknown, + cookie?: string, + executionCtx?: ExecutionContext +) { + const headers: Record = { 'content-type': 'application/json' }; + if (cookie) headers.cookie = cookie; + return app.fetch( + new Request('https://api.test/api/trial/create', { + method: 'POST', + headers, + body: typeof body === 'string' ? body : JSON.stringify(body), + }), + env, + executionCtx + ); +} + +/** Minimal ExecutionContext stub that synchronously awaits every waitUntil + * promise so assertions can observe side effects after `callCreate` resolves. */ +function makeExecutionCtx(): ExecutionContext & { waitUntilPromises: Promise[] } { + const promises: Promise[] = []; + const ctx: ExecutionContext & { waitUntilPromises: Promise[] } = { + waitUntil: (p: Promise) => { + promises.push(p); + }, + passThroughOnException: () => {}, + props: {}, + waitUntilPromises: promises, + } as unknown as ExecutionContext & { waitUntilPromises: Promise[] }; + return ctx; +} + +function okGithubFetch(overrides: Partial<{ size: number; private: boolean }> = {}) { + return vi.fn().mockResolvedValue( + new Response( + JSON.stringify({ size: 50, private: false, ...overrides }), + { status: 200, headers: { 'content-type': 'application/json' } } + ) + ); +} + +describe('probeGithubRepo (unit)', () => { + it('returns ok for a public, small repo', async () => { + const fetchFn = okGithubFetch({ size: 100 }); + const res = await probeGithubRepo('a', 'b', { + maxKb: 500, + timeoutMs: 1000, + fetchFn, + }); + expect(res).toEqual({ ok: true, sizeKb: 100, private: false }); + }); + + it('returns repo_not_found on 404', async () => { + const fetchFn = vi + .fn() + .mockResolvedValue(new Response('', { status: 404 })); + const res = await probeGithubRepo('a', 'b', { + maxKb: 500, + timeoutMs: 1000, + fetchFn, + }); + expect(res).toEqual({ ok: false, reason: 
'repo_not_found' }); + }); + + it('returns repo_private for private repos', async () => { + const fetchFn = vi + .fn() + .mockResolvedValue( + new Response(JSON.stringify({ private: true, size: 1 }), { status: 200 }) + ); + const res = await probeGithubRepo('a', 'b', { + maxKb: 500, + timeoutMs: 1000, + fetchFn, + }); + expect(res).toEqual({ ok: false, reason: 'repo_private' }); + }); + + it('returns repo_too_large when size exceeds maxKb', async () => { + const fetchFn = vi + .fn() + .mockResolvedValue( + new Response(JSON.stringify({ private: false, size: 1000 }), { + status: 200, + }) + ); + const res = await probeGithubRepo('a', 'b', { + maxKb: 500, + timeoutMs: 1000, + fetchFn, + }); + expect(res).toEqual({ ok: false, reason: 'repo_too_large' }); + }); + + it('returns repo_not_found on 5xx / network errors', async () => { + const fetchFn = vi.fn().mockRejectedValue(new Error('network down')); + const res = await probeGithubRepo('a', 'b', { + maxKb: 500, + timeoutMs: 1000, + fetchFn, + }); + expect(res).toEqual({ ok: false, reason: 'repo_not_found' }); + }); +}); + +describe('POST /api/trial/create', () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.unstubAllGlobals(); + isTrialsEnabledMock.mockResolvedValue(true); + valuesMock.mockResolvedValue(undefined); + }); + + it('rejects non-JSON body with 400 invalid_url', async () => { + const { app, env } = makeApp(makeEnv({})); + const res = await callCreate(app, env, 'not json'); + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('invalid_url'); + }); + + it('rejects non-GitHub URLs with 400 invalid_url', async () => { + const { app, env } = makeApp(makeEnv({})); + const res = await callCreate(app, env, { + repoUrl: 'https://gitlab.com/foo/bar', + }); + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('invalid_url'); + }); + + it('returns 503 trials_disabled when kill-switch is off', 
async () => { + isTrialsEnabledMock.mockResolvedValue(false); + const { app, env } = makeApp(makeEnv({})); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/repo', + }); + expect(res.status).toBe(503); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('trials_disabled'); + }); + + it('returns 503 trials_disabled when TRIAL_CLAIM_TOKEN_SECRET is missing', async () => { + const { app, env } = makeApp(makeEnv({ secret: undefined })); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/repo', + }); + expect(res.status).toBe(503); + }); + + it('returns 404 repo_not_found when GitHub 404s', async () => { + vi.stubGlobal( + 'fetch', + vi.fn().mockResolvedValue(new Response('', { status: 404 })) + ); + const { app, env } = makeApp(makeEnv({})); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/missing', + }); + expect(res.status).toBe(404); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('repo_not_found'); + }); + + it('returns 403 repo_private when probe reports private', async () => { + vi.stubGlobal( + 'fetch', + vi + .fn() + .mockResolvedValue( + new Response(JSON.stringify({ private: true, size: 1 }), { + status: 200, + }) + ) + ); + const { app, env } = makeApp(makeEnv({})); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/private', + }); + expect(res.status).toBe(403); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('repo_private'); + }); + + it('returns 413 repo_too_large when probe reports oversized', async () => { + vi.stubGlobal( + 'fetch', + vi + .fn() + .mockResolvedValue( + new Response(JSON.stringify({ private: false, size: 9_999_999 }), { + status: 200, + }) + ) + ); + const { app, env } = makeApp(makeEnv({})); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/big', + }); + expect(res.status).toBe(413); + const 
body = (await res.json()) as { error: string }; + expect(body.error).toBe('repo_too_large'); + }); + + it('returns 429 cap_exceeded with waitlistResetsAt when counter rejects', async () => { + vi.stubGlobal('fetch', okGithubFetch()); + const tryIncrementFn = vi.fn(async () => ({ allowed: false, count: 1500 })); + const { app, env } = makeApp(makeEnv({ tryIncrementFn })); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/repo', + }); + expect(res.status).toBe(429); + const body = (await res.json()) as { + error: string; + waitlistResetsAt?: string; + }; + expect(body.error).toBe('cap_exceeded'); + expect(body.waitlistResetsAt).toMatch(/^\d{4}-\d{2}-01$/); + }); + + it('returns 503 trials_disabled when counter DO throws', async () => { + vi.stubGlobal('fetch', okGithubFetch()); + const tryIncrementFn = vi.fn().mockRejectedValue(new Error('DO down')); + const { app, env } = makeApp(makeEnv({ tryIncrementFn })); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/repo', + }); + expect(res.status).toBe(503); + }); + + it('happy path: 201 with Set-Cookie headers and response body', async () => { + vi.stubGlobal('fetch', okGithubFetch()); + const { app, env } = makeApp(makeEnv({})); + const res = await callCreate(app, env, { + repoUrl: 'https://github.com/alice/repo', + }); + expect(res.status).toBe(201); + + // Validate response body shape + const body = (await res.json()) as { + trialId: string; + projectId: string; + eventsUrl: string; + expiresAt: number; + }; + expect(body.trialId).toMatch(/^trial_/); + expect(body.projectId).toBe(''); // populated by Track B later + // Events URL must match the actual SSE route (path segment, not query + // param). Frontend clients that rely on the response field would break if + // this drifts from the Hono route shape. 
+ expect(body.eventsUrl).toBe(`/api/trial/${body.trialId}/events`); + expect(body.expiresAt).toBeGreaterThan(Date.now()); + + // Two Set-Cookie headers should be present (fingerprint + claim). + // Hono/Workers uses Headers which exposes multiple Set-Cookie via getSetCookie(). + const setCookie = (res.headers as Headers & { + getSetCookie?: () => string[]; + }).getSetCookie?.(); + expect(setCookie).toBeTruthy(); + expect(setCookie!.length).toBeGreaterThanOrEqual(2); + + // D1 insert should have been called exactly once with the expected shape. + expect(insertMock).toHaveBeenCalledTimes(1); + const inserted = valuesMock.mock.calls[0]?.[0] as Record; + expect(inserted.status).toBe('pending'); + expect(inserted.repoUrl).toBe('https://github.com/alice/repo'); + expect(inserted.repoOwner).toBe('alice'); + expect(inserted.repoName).toBe('repo'); + + // KV mirror: trial record MUST be written to KV so Track B readers + // (SSE /events, /claim) can resolve it by trialId. Regression guard — + // skipping this write caused every SSE connection to 404 (see commit + // history). 
+    const kvPut = env.KV.put as unknown as ReturnType<typeof vi.fn>;
+    const trialKeyPut = kvPut.mock.calls.find(
+      (call) => typeof call[0] === 'string' && call[0].startsWith('trial:')
+    );
+    expect(trialKeyPut).toBeTruthy();
+    const stored = JSON.parse(trialKeyPut![1] as string) as {
+      trialId: string;
+      fingerprint: string;
+      projectId: string;
+    };
+    expect(stored.trialId).toBe(body.trialId);
+    expect(stored.projectId).toBe(''); // populated by Track B orchestrator later
+    expect(stored.fingerprint).toMatch(/^[0-9a-f-]{36}$/);
+  });
+
+  it('decrements counter slot when D1 insert fails (no burn)', async () => {
+    vi.stubGlobal('fetch', okGithubFetch());
+    valuesMock.mockRejectedValueOnce(new Error('UNIQUE constraint failed'));
+    const decrementFn = vi.fn().mockResolvedValue(0);
+    const { app, env } = makeApp(makeEnv({ decrementFn }));
+
+    const res = await callCreate(app, env, {
+      repoUrl: 'https://github.com/alice/repo',
+    });
+    expect(res.status).toBe(500);
+    expect(decrementFn).toHaveBeenCalledTimes(1);
+  });
+
+  it('dispatches TrialOrchestrator.start() with repo owner/name/canonical URL via waitUntil', async () => {
+    vi.stubGlobal('fetch', okGithubFetch());
+    const startFn = vi.fn(async () => {});
+    const env = makeEnv({ orchestratorStart: startFn });
+    const app = new Hono<{ Bindings: Env }>();
+    app.route('/api/trial', createRoutes);
+    const ctx = makeExecutionCtx();
+
+    const res = await callCreate(
+      app,
+      env,
+      { repoUrl: 'https://github.com/alice/repo' },
+      undefined,
+      ctx
+    );
+    expect(res.status).toBe(201);
+
+    // waitUntil should have been called at least twice (orchestrator + knowledge).
+    expect(ctx.waitUntilPromises.length).toBeGreaterThanOrEqual(2);
+    // Drain waitUntil promises so the orchestrator.start() call is observable.
+    await Promise.all(ctx.waitUntilPromises.map((p) => p.catch(() => {})));
+
+    expect(startFn).toHaveBeenCalledTimes(1);
+    const arg = startFn.mock.calls[0]?.[0] as {
+      trialId: string;
+      repoUrl: string;
+      repoOwner: string;
+      repoName: string;
+    };
+    expect(arg.trialId).toMatch(/^trial_/);
+    expect(arg.repoOwner).toBe('alice');
+    expect(arg.repoName).toBe('repo');
+    expect(arg.repoUrl).toBe('https://github.com/alice/repo');
+  });
+
+  it('still returns 201 when TrialOrchestrator.start() rejects (fire-and-forget)', async () => {
+    vi.stubGlobal('fetch', okGithubFetch());
+    const startFn = vi.fn().mockRejectedValue(new Error('DO unavailable'));
+    const env = makeEnv({ orchestratorStart: startFn });
+    const app = new Hono<{ Bindings: Env }>();
+    app.route('/api/trial', createRoutes);
+    const ctx = makeExecutionCtx();
+
+    const res = await callCreate(
+      app,
+      env,
+      { repoUrl: 'https://github.com/alice/repo' },
+      undefined,
+      ctx
+    );
+    // Response status must remain 201 — orchestrator failure is non-blocking.
+    expect(res.status).toBe(201);
+    // Drain waitUntil promises so the rejection is observable but swallowed.
+    await Promise.all(ctx.waitUntilPromises.map((p) => p.catch(() => {})));
+    expect(startFn).toHaveBeenCalledTimes(1);
+  });
+
+  it('rate-limits per IP when KV shows the window count is at the limit', async () => {
+    // Regression guard for the HIGH security-auditor finding: `POST /create`
+    // must enforce a per-IP rate limit before the kill-switch + DO dispatch.
+    // We override KV.get so the middleware sees an already-exhausted bucket
+    // and returns 429 without touching the orchestrator.
+    vi.stubGlobal('fetch', okGithubFetch());
+    const env = makeEnv({});
+    const exhaustedWindow = {
+      count: 999_999,
+      windowStart: Math.floor(Date.now() / 1000 / 3600) * 3600,
+    };
+    (env.KV.get as unknown as ReturnType<typeof vi.fn>).mockImplementation(
+      async (key: string) =>
+        key.startsWith('ratelimit:trial-create:') ? exhaustedWindow : null
+    );
+    const app = new Hono<{ Bindings: Env }>();
+    // Mirror the global error handler from apps/api/src/index.ts so AppError
+    // (thrown by the rate-limit middleware) serializes to its real statusCode.
+    const { AppError } = await import('../../../src/middleware/error');
+    type AppErrorJson = { error: string; message: string; details?: unknown };
+    app.onError((err, c) => {
+      if (err instanceof AppError) {
+        return c.json(err.toJSON() as AppErrorJson, err.statusCode as 429);
+      }
+      return c.json({ error: 'INTERNAL_ERROR', message: 'err' }, 500);
+    });
+    app.route('/api/trial', createRoutes);
+
+    const res = await app.fetch(
+      new Request('https://api.test/api/trial/create', {
+        method: 'POST',
+        headers: {
+          'content-type': 'application/json',
+          'CF-Connecting-IP': '203.0.113.7',
+        },
+        body: JSON.stringify({ repoUrl: 'https://github.com/alice/repo' }),
+      }),
+      env
+    );
+    expect(res.status).toBe(429);
+    // DO must NOT have been called when rate limit blocks the request.
+    expect(insertMock).not.toHaveBeenCalled();
+  });
+
+  it('reuses existing fingerprint UUID from a validly-signed cookie', async () => {
+    vi.stubGlobal('fetch', okGithubFetch());
+    const { app, env } = makeApp(makeEnv({}));
+
+    // Mint a fingerprint cookie whose HMAC was produced with the same SECRET
+    // the route uses to verify it. Only validly-signed cookies are trusted —
+    // see the "forged cookie" test below.
+    const { signFingerprint } = await import(
+      '../../../src/services/trial/cookies'
+    );
+    const existingUuid = '11111111-2222-3333-4444-555555555555';
+    const signed = await signFingerprint(existingUuid, SECRET);
+    const cookie = `sam_trial_fingerprint=${signed}`;
+
+    const res = await callCreate(
+      app,
+      env,
+      { repoUrl: 'https://github.com/alice/repo' },
+      cookie
+    );
+    expect(res.status).toBe(201);
+
+    // The inserted row should carry the reused UUID.
+    const inserted = valuesMock.mock.calls[0]?.[0] as { fingerprint: string };
+    expect(inserted.fingerprint).toBe(existingUuid);
+  });
+
+  // SECURITY regression: a forged fingerprint cookie (no signature, invalid
+  // signature, or signature minted with a different secret) MUST NOT be
+  // trusted. If an attacker learns a victim's fingerprint UUID (from logs,
+  // a captured cookie, or a prior trial row) and submits `<uuid>.abc`
+  // to POST /api/trial/create, the route would — if it only split on `.` —
+  // overwrite the `trial-by-fingerprint:<uuid>` KV index to point at
+  // the attacker's trial. The OAuth hook would then redirect the victim to
+  // the attacker-chosen repo. See the 2026-04-19 security review HIGH #1.
+  it('rejects a forged fingerprint cookie and mints a fresh UUID', async () => {
+    vi.stubGlobal('fetch', okGithubFetch());
+    const { app, env } = makeApp(makeEnv({}));
+
+    const victimUuid = '11111111-2222-3333-4444-555555555555';
+    // Attacker crafts a cookie with the victim's UUID but an invalid HMAC.
+    const cookie = `sam_trial_fingerprint=${victimUuid}.invalidSignature`;
+
+    const res = await callCreate(
+      app,
+      env,
+      { repoUrl: 'https://github.com/alice/repo' },
+      cookie
+    );
+    expect(res.status).toBe(201);
+
+    // The inserted row MUST NOT reuse the victim's UUID.
+    const inserted = valuesMock.mock.calls[0]?.[0] as { fingerprint: string };
+    expect(inserted.fingerprint).not.toBe(victimUuid);
+    // And it should be a fresh, well-formed UUID (crypto.randomUUID format).
+    expect(inserted.fingerprint).toMatch(
+      /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i
+    );
+  });
+});
diff --git a/apps/api/tests/unit/routes/trial-events-format.test.ts b/apps/api/tests/unit/routes/trial-events-format.test.ts
new file mode 100644
index 000000000..9d3324129
--- /dev/null
+++ b/apps/api/tests/unit/routes/trial-events-format.test.ts
@@ -0,0 +1,54 @@
+/**
+ * Regression tests for `formatSse()` — SSE frame serialization helper.
+ * + * Two invariants are guarded here: + * + * 1. The frame MUST be an unnamed ("message") SSE event — i.e. no `event:` + * line. Browser EventSource consumers only invoke `onmessage` for the + * default event; named events require per-type `addEventListener()` + * registrations. Emitting named events caused the "zero trial.* events + * on staging" incident (2026-04-19) where bytes arrived on the wire but + * the frontend's `onmessage` handler never fired. See + * `docs/notes/2026-04-19-trial-sse-named-events-postmortem.md`. + * + * 2. Data is JSON-encoded, so embedded newlines in payload strings are + * escaped — an SSE-frame-injection defence. The `TrialEvent` discriminated + * union exposed at the call site additionally narrows the `type` field + * to a safe set of literals. + */ +import { describe, expect, it } from 'vitest'; + +import { formatSse } from '../../../src/routes/trial/events'; + +describe('formatSse()', () => { + it('produces an unnamed SSE frame so EventSource.onmessage fires', () => { + const frame = formatSse({ type: 'trial.ready', at: 123 }); + expect(frame).toBe('data: {"type":"trial.ready","at":123}\n\n'); + // No `event:` line — otherwise the client has to register listeners + // per event type and `onmessage` silently never fires. + expect(frame).not.toContain('event:'); + }); + + it('json-encodes the data payload (newlines inside strings are escaped)', () => { + const frame = formatSse({ msg: 'line1\nline2' }); + // JSON.stringify escapes the embedded newline to \n (literal backslash n), + // preventing the payload from terminating the SSE frame early. 
+ expect(frame).toBe('data: {"msg":"line1\\nline2"}\n\n'); + }); + + it('frame shape is exactly one data line plus the blank terminator', () => { + const frame = formatSse({ + type: 'trial.knowledge', + key: 'description', + value: 'hello', + }); + const lines = frame.split('\n'); + // ['data: {...}', '', ''] + expect(lines).toHaveLength(3); + expect(lines[0]?.startsWith('data: ')).toBe(true); + expect(lines[1]).toBe(''); + expect(lines[2]).toBe(''); + const dataLines = frame.match(/^data: /gm) ?? []; + expect(dataLines).toHaveLength(1); + }); +}); diff --git a/apps/api/tests/unit/routes/trial-events.test.ts b/apps/api/tests/unit/routes/trial-events.test.ts new file mode 100644 index 000000000..39ba4449e --- /dev/null +++ b/apps/api/tests/unit/routes/trial-events.test.ts @@ -0,0 +1,210 @@ +/** + * Unit tests for GET /api/trial/:trialId/events (SSE). + * + * Covers: + * - 404 when no trial record exists + * - 500 when TRIAL_CLAIM_TOKEN_SECRET is unset + * - 401 when fingerprint cookie is missing + * - 401 when fingerprint signature fails + * - 401 when fingerprint UUID doesn't match record.fingerprint + * - 200 + text/event-stream headers on happy path with poll returning + * a terminal event (trial.ready) that closes the stream + */ +import { Hono } from 'hono'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +// Trial-store mock +const { readTrialMock } = vi.hoisted(() => ({ + readTrialMock: vi.fn(), +})); +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrial: readTrialMock, +})); + +// Silence logs +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +import { eventsRoutes } from '../../../src/routes/trial/events'; +import { signFingerprint } from '../../../src/services/trial/cookies'; + +const SECRET = 'test-secret-for-sse-route-32-bytes-long-minimum'; + +function makeApp(): Hono<{ Bindings: Env }> { + const app = new 
Hono<{ Bindings: Env }>(); + app.onError((err, c) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const e = err as any; + if (typeof e.statusCode === 'number' && typeof e.toJSON === 'function') { + return c.json(e.toJSON(), e.statusCode); + } + return c.json({ error: 'INTERNAL', message: err.message }, 500); + }); + app.route('/api/trial', eventsRoutes); + return app; +} + +/** Build an env with a TRIAL_EVENT_BUS DO stub whose `poll` returns the given events */ +function makeEnvWithDO( + pollResponse: { events: { cursor: number; event: unknown }[]; cursor: number; closed: boolean }, + overrides: Partial = {} +): Env { + const stub = { + fetch: vi.fn(async (url: string | URL) => { + if (String(url).includes('/poll')) { + return new Response(JSON.stringify(pollResponse), { + status: 200, + headers: { 'content-type': 'application/json' }, + }); + } + return new Response('not found', { status: 404 }); + }), + }; + const kvStore = new Map(); + return { + TRIAL_CLAIM_TOKEN_SECRET: SECRET, + TRIAL_SSE_HEARTBEAT_MS: '60000', + TRIAL_SSE_POLL_TIMEOUT_MS: '100', + TRIAL_SSE_MAX_DURATION_MS: '5000', + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => stub), + }, + KV: { + get: vi.fn(async (key: string) => kvStore.get(key) ?? 
null), + put: vi.fn(async (key: string, value: string) => { kvStore.set(key, value); }), + }, + ...overrides, + } as unknown as Env; +} + +async function getEvents( + app: Hono<{ Bindings: Env }>, + trialId: string, + cookie: string | null, + env: Env +): Promise { + const headers: Record = { + 'CF-Connecting-IP': '203.0.113.1', // test IP for rate limiting + }; + if (cookie) headers['cookie'] = `sam_trial_fingerprint=${encodeURIComponent(cookie)}`; + return app.request( + `/api/trial/${trialId}/events`, + { method: 'GET', headers }, + env + ); +} + +describe('GET /api/trial/:trialId/events — auth + bail-out', () => { + beforeEach(() => { + readTrialMock.mockReset(); + }); + + it('returns 404 when no trial record exists', async () => { + const app = makeApp(); + const env = makeEnvWithDO({ events: [], cursor: 0, closed: false }); + readTrialMock.mockResolvedValueOnce(null); + + const resp = await getEvents(app, 'trial_missing', 'whatever', env); + expect(resp.status).toBe(404); + }); + + it('returns 500 when TRIAL_CLAIM_TOKEN_SECRET is unset', async () => { + const app = makeApp(); + readTrialMock.mockResolvedValueOnce({ + trialId: 't', + fingerprint: 'fp', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + const env = makeEnvWithDO( + { events: [], cursor: 0, closed: false }, + { TRIAL_CLAIM_TOKEN_SECRET: undefined } as Partial + ); + const resp = await getEvents(app, 't', 'anything', env); + expect(resp.status).toBe(500); + }); + + it('returns 401 when fingerprint cookie is missing', async () => { + const app = makeApp(); + readTrialMock.mockResolvedValueOnce({ + trialId: 't', + fingerprint: 'fp', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + const env = makeEnvWithDO({ events: [], cursor: 0, closed: false }); + const resp = await getEvents(app, 't', null, env); + expect(resp.status).toBe(401); + }); + + it('returns 401 when fingerprint signature is invalid', async () => { + const app = makeApp(); + readTrialMock.mockResolvedValueOnce({ + 
trialId: 't', + fingerprint: 'fp', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + const env = makeEnvWithDO({ events: [], cursor: 0, closed: false }); + const resp = await getEvents(app, 't', 'bad.sig', env); + expect(resp.status).toBe(401); + }); + + it("returns 401 when fingerprint UUID doesn't match record.fingerprint", async () => { + const app = makeApp(); + readTrialMock.mockResolvedValueOnce({ + trialId: 't', + fingerprint: 'fp-expected', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + const signed = await signFingerprint('fp-attacker', SECRET); + const env = makeEnvWithDO({ events: [], cursor: 0, closed: false }); + const resp = await getEvents(app, 't', signed, env); + expect(resp.status).toBe(401); + }); +}); + +describe('GET /api/trial/:trialId/events — happy path', () => { + beforeEach(() => { + readTrialMock.mockReset(); + }); + + it('returns 200 + text/event-stream + streams events then closes on terminal', async () => { + const app = makeApp(); + readTrialMock.mockResolvedValueOnce({ + trialId: 'trial_good', + fingerprint: 'fp-good', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + + const terminalEvent = { type: 'trial.ready', projectId: 'proj', at: Date.now() }; + const env = makeEnvWithDO({ + events: [{ cursor: 1, event: terminalEvent }], + cursor: 1, + closed: true, + }); + + const signed = await signFingerprint('fp-good', SECRET); + const resp = await getEvents(app, 'trial_good', signed, env); + + expect(resp.status).toBe(200); + expect(resp.headers.get('content-type')).toContain('text/event-stream'); + expect(resp.headers.get('cache-control')).toContain('no-cache'); + expect(resp.headers.get('x-accel-buffering')).toBe('no'); + + // Drain the stream and assert the SSE frame shape. + const body = await resp.text(); + expect(body).toContain(': connected'); + // SSE frames are UNNAMED ("message") events so EventSource.onmessage + // fires on the client. 
See + // docs/notes/2026-04-19-trial-sse-named-events-postmortem.md. + expect(body).not.toContain('event: trial.ready'); + expect(body).toContain('"type":"trial.ready"'); + }); +}); diff --git a/apps/api/tests/unit/routes/trial-status.test.ts b/apps/api/tests/unit/routes/trial-status.test.ts new file mode 100644 index 000000000..da5d47ee5 --- /dev/null +++ b/apps/api/tests/unit/routes/trial-status.test.ts @@ -0,0 +1,124 @@ +/** + * Behavioral tests for GET /api/trial/status. + * + * Verifies: + * - Happy path: enabled=true, remaining = cap - count, resetsAt = next month + * - Cap exhausted: remaining = 0 (clamped), enabled still reflects kill-switch + * - Cap = 0 (unlimited): remaining = Number.MAX_SAFE_INTEGER + * - Kill-switch off: enabled=false + * - DO failure: enabled=false, remaining=0 (fail-closed fallback) + */ +import { Hono } from 'hono'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +const { isTrialsEnabledMock } = vi.hoisted(() => ({ + isTrialsEnabledMock: vi.fn(), +})); + +vi.mock('../../../src/services/trial/kill-switch', () => ({ + isTrialsEnabled: isTrialsEnabledMock, +})); + +const { statusRoutes } = await import('../../../src/routes/trial/status'); + +function makeEnv(options: { + doGet?: (key: string) => Promise<{ monthKey: string; count: number }>; + cap?: string; +}): Env { + return { + TRIAL_MONTHLY_CAP: options.cap, + TRIAL_COUNTER: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => ({ + get: + options.doGet ?? 
+ vi.fn(async (key: string) => ({ monthKey: key, count: 0 })), + })), + }, + } as unknown as Env; +} + +function makeApp(env: Env) { + const app = new Hono<{ Bindings: Env }>(); + app.route('/api/trial', statusRoutes); + return { app, env }; +} + +async function getStatus(app: Hono<{ Bindings: Env }>, env: Env) { + return app.fetch(new Request('https://api.test/api/trial/status'), env); +} + +describe('GET /api/trial/status', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('returns enabled + remaining computed from cap and DO count', async () => { + isTrialsEnabledMock.mockResolvedValue(true); + const doGet = vi.fn().mockResolvedValue({ monthKey: '2026-04', count: 500 }); + const { app, env } = makeApp(makeEnv({ doGet, cap: '1500' })); + + const res = await getStatus(app, env); + expect(res.status).toBe(200); + const body = (await res.json()) as { + enabled: boolean; + remaining: number; + resetsAt: string; + }; + expect(body.enabled).toBe(true); + expect(body.remaining).toBe(1000); + expect(body.resetsAt).toMatch(/^\d{4}-\d{2}-01$/); + }); + + it('clamps remaining to 0 when cap is exhausted', async () => { + isTrialsEnabledMock.mockResolvedValue(true); + const doGet = vi.fn().mockResolvedValue({ monthKey: '2026-04', count: 1500 }); + const { app, env } = makeApp(makeEnv({ doGet, cap: '1500' })); + + const res = await getStatus(app, env); + const body = (await res.json()) as { remaining: number }; + expect(body.remaining).toBe(0); + }); + + it('returns MAX_SAFE_INTEGER when cap=0 (unlimited)', async () => { + isTrialsEnabledMock.mockResolvedValue(true); + const doGet = vi.fn().mockResolvedValue({ monthKey: '2026-04', count: 42 }); + const { app, env } = makeApp(makeEnv({ doGet, cap: '0' })); + + const res = await getStatus(app, env); + const body = (await res.json()) as { remaining: number }; + expect(body.remaining).toBe(Number.MAX_SAFE_INTEGER); + }); + + it('returns enabled=false when kill-switch is off', async () => { + 
isTrialsEnabledMock.mockResolvedValue(false); + const { app, env } = makeApp(makeEnv({})); + + const res = await getStatus(app, env); + const body = (await res.json()) as { enabled: boolean }; + expect(body.enabled).toBe(false); + }); + + it('fails closed (enabled=false, remaining=0) when DO throws', async () => { + isTrialsEnabledMock.mockResolvedValue(true); + const doGet = vi.fn().mockRejectedValue(new Error('DO unreachable')); + const { app, env } = makeApp(makeEnv({ doGet })); + + const res = await getStatus(app, env); + expect(res.status).toBe(200); + const body = (await res.json()) as { + enabled: boolean; + remaining: number; + resetsAt: string; + }; + expect(body.enabled).toBe(false); + expect(body.remaining).toBe(0); + expect(body.resetsAt).toMatch(/^\d{4}-\d{2}-01$/); + }); +}); diff --git a/apps/api/tests/unit/routes/trial-waitlist.test.ts b/apps/api/tests/unit/routes/trial-waitlist.test.ts new file mode 100644 index 000000000..dd937dd4d --- /dev/null +++ b/apps/api/tests/unit/routes/trial-waitlist.test.ts @@ -0,0 +1,118 @@ +/** + * Behavioral tests for POST /api/trial/waitlist. 
+ * + * Verifies: + * - Valid email → 200 { queued: true, resetsAt } + * - Invalid email → 400 BAD_REQUEST + * - Non-JSON body → 400 BAD_REQUEST + * - Email is lowercased before insert (case-insensitive dedupe) + * - Uses onConflictDoNothing on (email, resetDate) for idempotent re-submits + * - D1 error → 500 INTERNAL_ERROR + */ +import { Hono } from 'hono'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +const { insertMock, valuesMock, onConflictMock } = vi.hoisted(() => { + const onConflictMock = vi.fn().mockResolvedValue(undefined); + const valuesMock = vi + .fn() + .mockReturnValue({ onConflictDoNothing: onConflictMock }); + const insertMock = vi.fn().mockReturnValue({ values: valuesMock }); + return { insertMock, valuesMock, onConflictMock }; +}); + +vi.mock('drizzle-orm/d1', () => ({ + drizzle: vi.fn().mockReturnValue({ + insert: insertMock, + }), +})); + +vi.mock('../../../src/db/schema', () => ({ + trialWaitlist: { + email: 'email', + resetDate: 'reset_date', + }, +})); + +const { waitlistRoutes } = await import('../../../src/routes/trial/waitlist'); + +function makeApp() { + const app = new Hono<{ Bindings: Env }>(); + app.route('/api/trial', waitlistRoutes); + const env = { DATABASE: {} } as unknown as Env; + return { app, env }; +} + +async function post(app: Hono<{ Bindings: Env }>, env: Env, body: unknown) { + return app.fetch( + new Request('https://api.test/api/trial/waitlist', { + method: 'POST', + headers: { 'content-type': 'application/json' }, + body: typeof body === 'string' ? 
body : JSON.stringify(body), + }), + env + ); +} + +describe('POST /api/trial/waitlist', () => { + beforeEach(() => { + vi.clearAllMocks(); + onConflictMock.mockResolvedValue(undefined); + }); + + it('queues a valid email and returns resetsAt', async () => { + const { app, env } = makeApp(); + const res = await post(app, env, { email: 'alice@example.com' }); + expect(res.status).toBe(200); + const body = (await res.json()) as { queued: boolean; resetsAt: string }; + expect(body.queued).toBe(true); + expect(body.resetsAt).toMatch(/^\d{4}-\d{2}-01$/); + }); + + it('lowercases the email before insert (case-insensitive dedupe)', async () => { + const { app, env } = makeApp(); + await post(app, env, { email: 'ALICE@Example.COM' }); + expect(valuesMock).toHaveBeenCalledWith( + expect.objectContaining({ email: 'alice@example.com' }) + ); + }); + + it('uses onConflictDoNothing on (email, resetDate) for idempotency', async () => { + const { app, env } = makeApp(); + await post(app, env, { email: 'bob@example.com' }); + expect(onConflictMock).toHaveBeenCalledWith( + expect.objectContaining({ + target: expect.arrayContaining(['email', 'reset_date']), + }) + ); + }); + + it('rejects invalid email with 400', async () => { + const { app, env } = makeApp(); + const res = await post(app, env, { email: 'not-an-email' }); + expect(res.status).toBe(400); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('BAD_REQUEST'); + }); + + it('rejects non-JSON body with 400', async () => { + const { app, env } = makeApp(); + const res = await post(app, env, 'not json at all'); + expect(res.status).toBe(400); + }); + + it('returns 500 when D1 insert fails', async () => { + onConflictMock.mockRejectedValue(new Error('D1 exploded')); + const { app, env } = makeApp(); + const res = await post(app, env, { email: 'alice@example.com' }); + expect(res.status).toBe(500); + const body = (await res.json()) as { error: string }; + expect(body.error).toBe('INTERNAL_ERROR'); + 
}); +}); diff --git a/apps/api/tests/unit/scheduled/trial-expire.test.ts b/apps/api/tests/unit/scheduled/trial-expire.test.ts new file mode 100644 index 000000000..7c023621b --- /dev/null +++ b/apps/api/tests/unit/scheduled/trial-expire.test.ts @@ -0,0 +1,115 @@ +/** + * Unit tests for `scheduled/trial-expire.ts`. + * + * Verifies: + * - Rows with status ∈ {pending, ready} AND expires_at < now are selected + * - Selected rows are updated to status='expired' + * - When no candidates exist, no update is issued and `expired=0` + * - Counter DO is NOT called (slot is legitimately consumed) + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +const { mockSelect, mockSelectFrom, mockUpdate, mockUpdateSet, mockUpdateWhere } = + vi.hoisted(() => { + const mockUpdateWhere = vi.fn().mockResolvedValue(undefined); + const mockUpdateSet = vi.fn().mockReturnValue({ where: mockUpdateWhere }); + const mockUpdate = vi.fn().mockReturnValue({ set: mockUpdateSet }); + + const mockSelectFrom = vi.fn(); + const mockSelect = vi.fn().mockReturnValue({ from: mockSelectFrom }); + + return { + mockSelect, + mockSelectFrom, + mockUpdate, + mockUpdateSet, + mockUpdateWhere, + }; + }); + +vi.mock('drizzle-orm/d1', () => ({ + drizzle: vi.fn().mockReturnValue({ + select: mockSelect, + update: mockUpdate, + }), +})); + +vi.mock('drizzle-orm', () => ({ + and: vi.fn((...args: unknown[]) => ({ type: 'and', args })), + inArray: vi.fn((_col: unknown, vals: unknown) => ({ type: 'inArray', vals })), + lt: vi.fn((_col: unknown, val: unknown) => ({ type: 'lt', val })), +})); + +vi.mock('../../../src/db/schema', () => ({ + trials: { + id: 'id', + status: 'status', + expiresAt: 'expires_at', + }, +})); + +const { runTrialExpireSweep } = await import( + '../../../src/scheduled/trial-expire' +); + +function makeEnv(): Env { + return { DATABASE: {} } as unknown as Env; +} + +/** + * Drizzle chain: `.select(...).from(...).where(...).limit(...)`. 
The final + * `.limit()` resolves to the row array. + */ +function buildSelectChain(rows: unknown[]) { + return { + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue(rows), + }), + }; +} + +describe('runTrialExpireSweep', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('expires rows older than now with status pending/ready', async () => { + const candidates = [{ id: 'trial_a' }, { id: 'trial_b' }, { id: 'trial_c' }]; + mockSelectFrom.mockReturnValueOnce(buildSelectChain(candidates)); + + const res = await runTrialExpireSweep(makeEnv(), 1_700_000_000_000); + + expect(res).toEqual({ expired: 3 }); + // update() was called once + expect(mockUpdate).toHaveBeenCalledTimes(1); + // The set() call should be { status: 'expired' } + expect(mockUpdateSet).toHaveBeenCalledWith({ status: 'expired' }); + // where() should use inArray of the captured ids + const whereArg = mockUpdateWhere.mock.calls[0]?.[0] as { + type: string; + vals: unknown; + }; + expect(whereArg.type).toBe('inArray'); + expect(whereArg.vals).toEqual(['trial_a', 'trial_b', 'trial_c']); + }); + + it('returns early with expired=0 when no candidates exist', async () => { + mockSelectFrom.mockReturnValueOnce(buildSelectChain([])); + + const res = await runTrialExpireSweep(makeEnv(), 1_700_000_000_000); + + expect(res).toEqual({ expired: 0 }); + expect(mockUpdate).not.toHaveBeenCalled(); + }); + + it('does NOT call the TrialCounter DO (slot is consumed legitimately)', async () => { + // Env has no TRIAL_COUNTER binding — this would throw if the code touched it. 
+ const env = makeEnv(); + mockSelectFrom.mockReturnValueOnce(buildSelectChain([{ id: 'trial_x' }])); + await expect(runTrialExpireSweep(env, 1_700_000_000_000)).resolves.toEqual({ + expired: 1, + }); + }); +}); diff --git a/apps/api/tests/unit/scheduled/trial-rollover.test.ts b/apps/api/tests/unit/scheduled/trial-rollover.test.ts new file mode 100644 index 000000000..c8ed4940e --- /dev/null +++ b/apps/api/tests/unit/scheduled/trial-rollover.test.ts @@ -0,0 +1,99 @@ +/** + * Unit tests for `scheduled/trial-rollover.ts`. + * + * Verifies: + * - prune() is called with the month key `keepMonths - 1` earlier than current + * - Returned result carries monthKey + pruned + * - DO failure is swallowed (logged) and returns pruned=0 so the cron continues + * - monthKey drift is logged but does not throw + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +const { mockLogWarn, mockLogError } = vi.hoisted(() => ({ + mockLogWarn: vi.fn(), + mockLogError: vi.fn(), +})); + +vi.mock('../../../src/lib/logger', () => ({ + log: { + info: vi.fn(), + warn: mockLogWarn, + error: mockLogError, + }, +})); + +import { runTrialRolloverAudit } from '../../../src/scheduled/trial-rollover'; + +function makeEnv(options: { + pruneFn: (keep: string) => Promise; + getFn?: (key: string) => Promise<{ monthKey: string; count: number }>; + keepMonths?: string; +}): Env { + return { + TRIAL_COUNTER_KEEP_MONTHS: options.keepMonths, + TRIAL_COUNTER: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => ({ + prune: options.pruneFn, + get: + options.getFn ?? 
+ vi.fn(async (key: string) => ({ monthKey: key, count: 0 })), + })), + }, + } as unknown as Env; +} + +describe('runTrialRolloverAudit', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('prunes rows older than keepMonths and returns pruned count', async () => { + const pruneFn = vi.fn().mockResolvedValue(7); + // now = 2026-04-18 → monthKey 2026-04 → keep 3 months → oldestKept 2026-02 + const now = Date.UTC(2026, 3, 18); + const env = makeEnv({ pruneFn, keepMonths: '3' }); + + const res = await runTrialRolloverAudit(env, now); + expect(res).toEqual({ monthKey: '2026-04', pruned: 7 }); + expect(pruneFn).toHaveBeenCalledWith('2026-02'); + }); + + it('falls back to DEFAULT_TRIAL_COUNTER_KEEP_MONTHS (3) when env is missing', async () => { + const pruneFn = vi.fn().mockResolvedValue(0); + const now = Date.UTC(2026, 5, 1); // 2026-06 + const env = makeEnv({ pruneFn }); + + await runTrialRolloverAudit(env, now); + expect(pruneFn).toHaveBeenCalledWith('2026-04'); + }); + + it('logs a warning on monthKey drift but returns normally', async () => { + const pruneFn = vi.fn().mockResolvedValue(0); + const getFn = vi.fn().mockResolvedValue({ monthKey: '2026-03', count: 1 }); + const now = Date.UTC(2026, 3, 18); // 2026-04 + const env = makeEnv({ pruneFn, getFn }); + + const res = await runTrialRolloverAudit(env, now); + expect(res).toEqual({ monthKey: '2026-04', pruned: 0 }); + expect(mockLogWarn).toHaveBeenCalledWith( + 'trial.rollover.monthKey_drift', + expect.objectContaining({ expected: '2026-04', actual: '2026-03' }) + ); + }); + + it('swallows DO errors and returns pruned=0', async () => { + const pruneFn = vi.fn().mockRejectedValue(new Error('do unavailable')); + const now = Date.UTC(2026, 3, 18); + const env = makeEnv({ pruneFn }); + + const res = await runTrialRolloverAudit(env, now); + expect(res).toEqual({ monthKey: '2026-04', pruned: 0 }); + expect(mockLogError).toHaveBeenCalledWith( + 'trial.rollover.failed', + expect.any(Object) + ); + }); +}); diff 
--git a/apps/api/tests/unit/scheduled/trial-waitlist-cleanup.test.ts b/apps/api/tests/unit/scheduled/trial-waitlist-cleanup.test.ts new file mode 100644 index 000000000..cc0d12c6f --- /dev/null +++ b/apps/api/tests/unit/scheduled/trial-waitlist-cleanup.test.ts @@ -0,0 +1,112 @@ +/** + * Unit tests for `scheduled/trial-waitlist-cleanup.ts`. + * + * Verifies: + * - Only rows with notified_at IS NOT NULL AND notified_at < (now - purgeDays) are candidates + * - When there are candidates, a DELETE is issued + * - When there are no candidates, no DELETE is issued and `purged=0` + * - purgeDays falls back to the default when env is missing + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +const { mockSelect, mockSelectFrom, mockDelete, mockLt } = vi.hoisted(() => { + const mockDeleteWhere = vi.fn().mockResolvedValue(undefined); + const mockDelete = vi.fn().mockReturnValue({ where: mockDeleteWhere }); + const mockSelectFrom = vi.fn(); + const mockSelect = vi.fn().mockReturnValue({ from: mockSelectFrom }); + const mockLt = vi.fn((_col: unknown, val: unknown) => ({ type: 'lt', val })); + return { + mockSelect, + mockSelectFrom, + mockDelete, + mockLt, + }; +}); + +vi.mock('drizzle-orm/d1', () => ({ + drizzle: vi.fn().mockReturnValue({ + select: mockSelect, + delete: mockDelete, + }), +})); + +vi.mock('drizzle-orm', () => ({ + and: vi.fn((...args: unknown[]) => ({ type: 'and', args })), + isNotNull: vi.fn((col: unknown) => ({ type: 'isNotNull', col })), + lt: mockLt, +})); + +vi.mock('../../../src/db/schema', () => ({ + trialWaitlist: { + id: 'id', + notifiedAt: 'notified_at', + }, +})); + +const { runTrialWaitlistCleanup } = await import( + '../../../src/scheduled/trial-waitlist-cleanup' +); + +function makeEnv(overrides: Partial = {}): Env { + return { DATABASE: {}, ...overrides } as unknown as Env; +} + +function buildSelectChain(rows: unknown[]) { + return { + where: vi.fn().mockReturnValue({ + limit: 
vi.fn().mockResolvedValue(rows), + }), + }; +} + +describe('runTrialWaitlistCleanup', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('deletes rows when candidates exist and returns the count', async () => { + const candidates = [{ id: 'wl_a' }, { id: 'wl_b' }]; + mockSelectFrom.mockReturnValueOnce(buildSelectChain(candidates)); + + const res = await runTrialWaitlistCleanup(makeEnv(), 1_700_000_000_000); + expect(res).toEqual({ purged: 2 }); + expect(mockDelete).toHaveBeenCalledTimes(1); + }); + + it('returns purged=0 without issuing a DELETE when nothing matches', async () => { + mockSelectFrom.mockReturnValueOnce(buildSelectChain([])); + const res = await runTrialWaitlistCleanup(makeEnv(), 1_700_000_000_000); + expect(res).toEqual({ purged: 0 }); + expect(mockDelete).not.toHaveBeenCalled(); + }); + + it('uses the env override TRIAL_WAITLIST_PURGE_DAYS when present', async () => { + mockSelectFrom.mockReturnValueOnce(buildSelectChain([])); + const env = makeEnv({ TRIAL_WAITLIST_PURGE_DAYS: '60' }); + const now = 2_000_000_000_000; // arbitrary + await runTrialWaitlistCleanup(env, now); + + // The `lt` helper was called with the threshold — verify it was derived + // from the env override (60 days) not the default (30 days). 
+ const expectedThreshold60 = now - 60 * 24 * 60 * 60 * 1000; + expect(mockLt).toHaveBeenCalledWith( + expect.anything(), + expectedThreshold60 + ); + }); + + it('falls back to default 30 days when env is missing', async () => { + mockSelectFrom.mockReturnValueOnce(buildSelectChain([])); + const env = makeEnv(); + const now = 2_000_000_000_000; + await runTrialWaitlistCleanup(env, now); + + const expectedThreshold30 = now - 30 * 24 * 60 * 60 * 1000; + expect(mockLt).toHaveBeenCalledWith( + expect.anything(), + expectedThreshold30 + ); + }); +}); diff --git a/apps/api/tests/unit/services/ai-anthropic-translate.test.ts b/apps/api/tests/unit/services/ai-anthropic-translate.test.ts new file mode 100644 index 000000000..4583d8d3f --- /dev/null +++ b/apps/api/tests/unit/services/ai-anthropic-translate.test.ts @@ -0,0 +1,247 @@ +/** + * Unit tests for the OpenAI ↔ Anthropic format translation layer. + */ +import { describe, expect, it } from 'vitest'; + +import { + translateRequestToAnthropic, + translateResponseToOpenAI, +} from '../../../src/services/ai-anthropic-translate'; + +describe('translateRequestToAnthropic', () => { + it('translates a simple user message', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [{ role: 'user', content: 'Hello' }], + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + expect(result.model).toBe('claude-haiku-4-5-20251001'); + expect(result.messages).toHaveLength(1); + expect(result.messages[0].role).toBe('user'); + expect(result.messages[0].content).toEqual([{ type: 'text', text: 'Hello' }]); + expect(result.max_tokens).toBe(4096); + }); + + it('extracts system messages to top-level system field', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [ + { role: 'system', content: 'You are helpful.' 
}, + { role: 'user', content: 'Hi' }, + ], + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + expect(result.system).toBe('You are helpful.'); + expect(result.messages).toHaveLength(1); + expect(result.messages[0].role).toBe('user'); + }); + + it('joins multiple system messages', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [ + { role: 'system', content: 'Rule 1' }, + { role: 'system', content: 'Rule 2' }, + { role: 'user', content: 'Hi' }, + ], + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + expect(result.system).toBe('Rule 1\n\nRule 2'); + }); + + it('translates assistant messages with tool calls', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [ + { role: 'user', content: 'What time is it?' }, + { + role: 'assistant', + content: 'Let me check.', + tool_calls: [{ + id: 'call_123', + type: 'function', + function: { name: 'get_time', arguments: '{}' }, + }], + }, + { + role: 'tool', + tool_call_id: 'call_123', + content: '14:30 UTC', + }, + ], + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + // Assistant message with text + tool_use + expect(result.messages[1].role).toBe('assistant'); + expect(result.messages[1].content).toHaveLength(2); + expect(result.messages[1].content[0]).toEqual({ type: 'text', text: 'Let me check.' 
}); + expect(result.messages[1].content[1]).toEqual({ + type: 'tool_use', + id: 'call_123', + name: 'get_time', + input: {}, + }); + + // Tool result in a user message + expect(result.messages[2].role).toBe('user'); + expect(result.messages[2].content[0]).toEqual({ + type: 'tool_result', + tool_use_id: 'call_123', + content: '14:30 UTC', + }); + }); + + it('translates OpenAI tools to Anthropic tools format', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [{ role: 'user', content: 'Hi' }], + tools: [{ + type: 'function', + function: { + name: 'get_weather', + description: 'Get the weather', + parameters: { type: 'object', properties: { city: { type: 'string' } } }, + }, + }], + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + expect(result.tools).toHaveLength(1); + expect(result.tools![0]).toEqual({ + name: 'get_weather', + description: 'Get the weather', + input_schema: { type: 'object', properties: { city: { type: 'string' } } }, + }); + }); + + it('passes through temperature and top_p', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [{ role: 'user', content: 'Hi' }], + temperature: 0.7, + top_p: 0.9, + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + expect(result.temperature).toBe(0.7); + expect(result.top_p).toBe(0.9); + }); + + it('uses max_tokens from request body', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [{ role: 'user', content: 'Hi' }], + max_tokens: 1024, + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + expect(result.max_tokens).toBe(1024); + }); + + it('merges consecutive same-role messages', () => { + const body = { + model: 'claude-haiku-4-5-20251001', + messages: [ + { role: 'user', content: 'Hello' }, + { role: 'user', content: 'Are you there?' 
}, + ], + }; + + const result = translateRequestToAnthropic(body, 'claude-haiku-4-5-20251001'); + + // Should be merged into one user message + expect(result.messages).toHaveLength(1); + expect(result.messages[0].content).toHaveLength(2); + }); +}); + +describe('translateResponseToOpenAI', () => { + it('translates a simple text response', () => { + const response = { + id: 'msg_123', + type: 'message' as const, + role: 'assistant' as const, + content: [{ type: 'text' as const, text: 'Hello!' }], + model: 'claude-haiku-4-5-20251001', + stop_reason: 'end_turn', + usage: { input_tokens: 10, output_tokens: 5 }, + }; + + const result = translateResponseToOpenAI(response); + + expect(result.id).toBe('chatcmpl-msg_123'); + expect(result.object).toBe('chat.completion'); + expect(result.model).toBe('claude-haiku-4-5-20251001'); + const choices = result.choices as Array>; + expect(choices[0].finish_reason).toBe('stop'); + const message = choices[0].message as Record; + expect(message.role).toBe('assistant'); + expect(message.content).toBe('Hello!'); + const usage = result.usage as Record; + expect(usage.prompt_tokens).toBe(10); + expect(usage.completion_tokens).toBe(5); + expect(usage.total_tokens).toBe(15); + }); + + it('translates tool_use blocks into tool_calls', () => { + const response = { + id: 'msg_456', + type: 'message' as const, + role: 'assistant' as const, + content: [ + { type: 'text' as const, text: 'Let me check.' 
}, + { + type: 'tool_use' as const, + id: 'toolu_123', + name: 'get_time', + input: { timezone: 'UTC' }, + }, + ], + model: 'claude-haiku-4-5-20251001', + stop_reason: 'tool_use', + usage: { input_tokens: 20, output_tokens: 15 }, + }; + + const result = translateResponseToOpenAI(response); + + const choices = result.choices as Array>; + expect(choices[0].finish_reason).toBe('tool_calls'); + const message = choices[0].message as Record; + expect(message.content).toBe('Let me check.'); + const toolCalls = message.tool_calls as Array>; + expect(toolCalls).toHaveLength(1); + expect(toolCalls[0].id).toBe('toolu_123'); + expect(toolCalls[0].type).toBe('function'); + const fn = toolCalls[0].function as Record; + expect(fn.name).toBe('get_time'); + expect(fn.arguments).toBe('{"timezone":"UTC"}'); + }); + + it('maps max_tokens stop reason to length', () => { + const response = { + id: 'msg_789', + type: 'message' as const, + role: 'assistant' as const, + content: [{ type: 'text' as const, text: 'Truncated...' }], + model: 'claude-haiku-4-5-20251001', + stop_reason: 'max_tokens', + usage: { input_tokens: 100, output_tokens: 4096 }, + }; + + const result = translateResponseToOpenAI(response); + + const choices = result.choices as Array>; + expect(choices[0].finish_reason).toBe('length'); + }); +}); diff --git a/apps/api/tests/unit/services/trial-bridge.test.ts b/apps/api/tests/unit/services/trial-bridge.test.ts new file mode 100644 index 000000000..8395bb93a --- /dev/null +++ b/apps/api/tests/unit/services/trial-bridge.test.ts @@ -0,0 +1,153 @@ +/** + * Unit tests for the trial event bridge helpers. + * + * Covers: + * - bridgeAcpSessionTransition: emits trial.ready on 'running' transition, + * trial.error on 'failed', no-op on non-trial projects or other states. + * - bridgeKnowledgeAdded / bridgeIdeaCreated: no-op on non-trial projects, + * emit the correct event shape when a trial record is found. + * - All bridges swallow errors from the emitter. 
+ */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +const { readTrialByProjectMock, emitTrialEventForProjectMock } = vi.hoisted(() => ({ + readTrialByProjectMock: vi.fn(), + emitTrialEventForProjectMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrialByProject: readTrialByProjectMock, +})); +vi.mock('../../../src/services/trial/trial-runner', () => ({ + emitTrialEventForProject: emitTrialEventForProjectMock, +})); + +const bridge = await import('../../../src/services/trial/bridge'); + +function makeEnv(): Env { + return { BASE_DOMAIN: 'example.com' } as unknown as Env; +} + +describe('bridgeAcpSessionTransition', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('no-ops on non-trial projects (readTrialByProject returns null)', async () => { + readTrialByProjectMock.mockResolvedValueOnce(null); + await bridge.bridgeAcpSessionTransition(makeEnv(), 'proj_nope', 'running'); + expect(emitTrialEventForProjectMock).not.toHaveBeenCalled(); + }); + + it('emits trial.ready with workspaceUrl on running transition', async () => { + readTrialByProjectMock.mockResolvedValueOnce({ + trialId: 'trial_1', + projectId: 'proj_1', + workspaceId: 'ws_1', + }); + await bridge.bridgeAcpSessionTransition(makeEnv(), 'proj_1', 'running'); + expect(emitTrialEventForProjectMock).toHaveBeenCalledTimes(1); + const [, , event] = emitTrialEventForProjectMock.mock.calls[0]; + expect(event.type).toBe('trial.ready'); + expect(event.workspaceUrl).toBe('https://ws-ws_1.example.com'); + expect(event.trialId).toBe('trial_1'); + }); + + it('emits trial.error on failed transition', async () => { + readTrialByProjectMock.mockResolvedValueOnce({ + trialId: 'trial_fail', + projectId: 'proj_fail', + workspaceId: null, + }); + await 
bridge.bridgeAcpSessionTransition(makeEnv(), 'proj_fail', 'failed', { + errorMessage: 'agent crashed', + }); + expect(emitTrialEventForProjectMock).toHaveBeenCalledTimes(1); + const [, , event] = emitTrialEventForProjectMock.mock.calls[0]; + expect(event.type).toBe('trial.error'); + expect(event.message).toBe('agent crashed'); + }); + + it('no-ops on unrelated transitions (e.g. pending → assigned)', async () => { + readTrialByProjectMock.mockResolvedValueOnce({ + trialId: 'trial_x', + projectId: 'proj_x', + workspaceId: null, + }); + await bridge.bridgeAcpSessionTransition(makeEnv(), 'proj_x', 'assigned'); + expect(emitTrialEventForProjectMock).not.toHaveBeenCalled(); + }); + + it('swallows errors from the emitter', async () => { + readTrialByProjectMock.mockRejectedValueOnce(new Error('KV down')); + await expect( + bridge.bridgeAcpSessionTransition(makeEnv(), 'proj_1', 'running') + ).resolves.toBeUndefined(); + }); +}); + +describe('bridgeKnowledgeAdded / bridgeIdeaCreated', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('no-ops on non-trial projects', async () => { + readTrialByProjectMock.mockResolvedValue(null); + await bridge.bridgeKnowledgeAdded(makeEnv(), 'proj_nope', 'repo', 'obs'); + await bridge.bridgeIdeaCreated(makeEnv(), 'proj_nope', 'idea_1', 't', 's'); + expect(emitTrialEventForProjectMock).not.toHaveBeenCalled(); + }); + + it('emits trial.knowledge when a trial record exists', async () => { + readTrialByProjectMock.mockResolvedValue({ + trialId: 'trial_k', + projectId: 'proj_k', + workspaceId: null, + }); + await bridge.bridgeKnowledgeAdded( + makeEnv(), + 'proj_k', + 'repository', + 'uses TypeScript' + ); + const [, , event] = emitTrialEventForProjectMock.mock.calls[0]; + expect(event.type).toBe('trial.knowledge'); + expect(event.entity).toBe('repository'); + expect(event.observation).toBe('uses TypeScript'); + }); + + it('emits trial.idea when a trial record exists', async () => { + readTrialByProjectMock.mockResolvedValue({ + 
trialId: 'trial_i', + projectId: 'proj_i', + workspaceId: null, + }); + await bridge.bridgeIdeaCreated( + makeEnv(), + 'proj_i', + 'idea_42', + 'Add CLI', + 'Short summary' + ); + const [, , event] = emitTrialEventForProjectMock.mock.calls[0]; + expect(event.type).toBe('trial.idea'); + expect(event.ideaId).toBe('idea_42'); + expect(event.title).toBe('Add CLI'); + expect(event.summary).toBe('Short summary'); + }); + + it('swallows errors from the emitter', async () => { + readTrialByProjectMock.mockRejectedValue(new Error('KV blew up')); + await expect( + bridge.bridgeKnowledgeAdded(makeEnv(), 'proj_e', 'e', 'o') + ).resolves.toBeUndefined(); + await expect( + bridge.bridgeIdeaCreated(makeEnv(), 'proj_e', 'i', 't', 's') + ).resolves.toBeUndefined(); + }); +}); diff --git a/apps/api/tests/unit/services/trial-cookies.test.ts b/apps/api/tests/unit/services/trial-cookies.test.ts new file mode 100644 index 000000000..3f6d65525 --- /dev/null +++ b/apps/api/tests/unit/services/trial-cookies.test.ts @@ -0,0 +1,314 @@ +/** + * Unit tests for trial cookie signing / verification. 
+ * + * Covers: + * - Fingerprint sign → verify round trip + tamper rejection + * - Claim token sign → verify round trip + * - Tampered signature → `bad_signature` + * - Mutated payload → `bad_signature` (sig no longer matches) + * - Malformed token (no dot, non-JSON body, missing fields) → `malformed` + * - Expired token → `expired` + * - Constant-time comparison returns for equal-length and unequal-length strings + * - Cookie string builders include HttpOnly / Secure / SameSite / Max-Age + */ +import { describe, expect, it } from 'vitest'; + +import { + buildClaimCookie, + buildFingerprintCookie, + clearClaimCookie, + DEFAULT_TRIAL_CLAIM_TTL_MS, + DEFAULT_TRIAL_FINGERPRINT_TTL_SEC, + signClaimToken, + signFingerprint, + type TrialClaimPayload, + verifyClaimToken, + verifyFingerprint, +} from '../../../src/services/trial/cookies'; + +const SECRET = 'test-secret-at-least-32-bytes-long-for-hmac'; + +// --------------------------------------------------------------------------- +// Fingerprint +// --------------------------------------------------------------------------- + +describe('trial cookies — fingerprint', () => { + it('signs and verifies a UUID round trip', async () => { + const uuid = '11111111-2222-3333-4444-555555555555'; + const signed = await signFingerprint(uuid, SECRET); + expect(signed).toContain('.'); + const recovered = await verifyFingerprint(signed, SECRET); + expect(recovered).toBe(uuid); + }); + + it('rejects a fingerprint signed with a different secret', async () => { + const signed = await signFingerprint('u1', SECRET); + const result = await verifyFingerprint(signed, `${SECRET}-other`); + expect(result).toBeNull(); + }); + + it('rejects a tampered UUID body', async () => { + const signed = await signFingerprint('uuid-a', SECRET); + const [, sig] = signed.split('.'); + const tampered = `uuid-b.${sig}`; + expect(await verifyFingerprint(tampered, SECRET)).toBeNull(); + }); + + it('rejects a tampered signature', async () => { + const signed = 
await signFingerprint('uuid-a', SECRET); + const [body] = signed.split('.'); + const tampered = `${body}.AAAAAAAAAAAA`; + expect(await verifyFingerprint(tampered, SECRET)).toBeNull(); + }); + + it('rejects a value with no dot', async () => { + expect(await verifyFingerprint('no-dot-here', SECRET)).toBeNull(); + }); + + it('rejects a value with empty signature', async () => { + expect(await verifyFingerprint('uuid.', SECRET)).toBeNull(); + }); + + it('rejects a value with empty body', async () => { + expect(await verifyFingerprint('.sig', SECRET)).toBeNull(); + }); +}); + +// --------------------------------------------------------------------------- +// Claim token +// --------------------------------------------------------------------------- + +describe('trial cookies — claim token', () => { + const basePayload: TrialClaimPayload = { + trialId: 'trial_abc', + projectId: 'proj_xyz', + issuedAt: 1_000_000, + expiresAt: 2_000_000, + }; + + it('signs and verifies a claim payload round trip', async () => { + const token = await signClaimToken(basePayload, SECRET); + const result = await verifyClaimToken(token, SECRET, 1_500_000); + expect(result.ok).toBe(true); + if (result.ok) { + expect(result.payload).toEqual(basePayload); + } + }); + + it('rejects a claim verified with the wrong secret', async () => { + const token = await signClaimToken(basePayload, SECRET); + const result = await verifyClaimToken(token, `${SECRET}-other`, 1_500_000); + expect(result).toEqual({ ok: false, reason: 'bad_signature' }); + }); + + it('rejects a tampered signature as bad_signature', async () => { + const token = await signClaimToken(basePayload, SECRET); + const [body] = token.split('.'); + const tampered = `${body}.AAAAAAAAAAAA`; + const result = await verifyClaimToken(tampered, SECRET, 1_500_000); + expect(result).toEqual({ ok: false, reason: 'bad_signature' }); + }); + + it('rejects a mutated body (sig no longer matches) as bad_signature', async () => { + const token = await 
signClaimToken(basePayload, SECRET); + const [, sig] = token.split('.'); + // Swap in a different but validly-base64url-encoded body; signature won't match. + const otherToken = await signClaimToken( + { ...basePayload, trialId: 'trial_different' }, + SECRET + ); + const [otherBody] = otherToken.split('.'); + const mutated = `${otherBody}.${sig}`; + const result = await verifyClaimToken(mutated, SECRET, 1_500_000); + expect(result).toEqual({ ok: false, reason: 'bad_signature' }); + }); + + it('reports malformed for token without a dot', async () => { + const result = await verifyClaimToken('nodothere', SECRET, 1_500_000); + expect(result).toEqual({ ok: false, reason: 'malformed' }); + }); + + it('reports malformed for token with empty signature', async () => { + const result = await verifyClaimToken('body.', SECRET, 1_500_000); + expect(result).toEqual({ ok: false, reason: 'malformed' }); + }); + + it('reports malformed for token with empty body', async () => { + const result = await verifyClaimToken('.sig', SECRET, 1_500_000); + expect(result).toEqual({ ok: false, reason: 'malformed' }); + }); + + it('reports malformed when the body decodes to non-JSON', async () => { + // base64url("not-json") = "bm90LWpzb24" + const badBody = 'bm90LWpzb24'; + // Sign the bad body with the real secret so signature passes, then we hit + // the JSON.parse branch. + const realToken = await signClaimToken(basePayload, SECRET); + const [, realSig] = realToken.split('.'); + // Since we need a matching signature, manually sign the bad body. + // We do this by signing a body that decodes to non-JSON then using its sig. + // Simpler approach: import a fresh sign by calling signClaimToken with a + // payload we can't stringify — not worth it. Instead, assert that a body + // that is NOT valid base64url JSON falls through to the parse catch. + // Our signClaimToken always encodes JSON, so we need to construct the token + // manually: body = non-JSON base64url + matching HMAC(secret, body). 
+ const crypto_ = globalThis.crypto; + const key = await crypto_.subtle.importKey( + 'raw', + new TextEncoder().encode(SECRET), + { name: 'HMAC', hash: 'SHA-256' }, + false, + ['sign'] + ); + const sigBytes = new Uint8Array( + await crypto_.subtle.sign('HMAC', key, new TextEncoder().encode(badBody)) + ); + let bin = ''; + for (let i = 0; i < sigBytes.length; i++) bin += String.fromCharCode(sigBytes[i]!); + const sig = btoa(bin).replace(/\+/g, '-').replace(/\//g, '_').replace(/=+$/, ''); + const forged = `${badBody}.${sig}`; + + const result = await verifyClaimToken(forged, SECRET, 1_500_000); + expect(result).toEqual({ ok: false, reason: 'malformed' }); + + // Silence the "unused variable" lint by referencing realSig. + expect(realSig).toBeTypeOf('string'); + }); + + it('reports malformed when required fields are missing', async () => { + // Sign a payload missing projectId. + const badPayload = { trialId: 't', issuedAt: 1, expiresAt: 2 } as unknown as TrialClaimPayload; + const token = await signClaimToken(badPayload, SECRET); + const result = await verifyClaimToken(token, SECRET, 0); + expect(result).toEqual({ ok: false, reason: 'malformed' }); + }); + + it('reports expired when now >= expiresAt', async () => { + const token = await signClaimToken(basePayload, SECRET); + // now = expiresAt — boundary must reject. 
+ const result = await verifyClaimToken(token, SECRET, basePayload.expiresAt); + expect(result).toEqual({ ok: false, reason: 'expired' }); + }); + + it('reports expired when now is well past expiresAt', async () => { + const token = await signClaimToken(basePayload, SECRET); + const result = await verifyClaimToken(token, SECRET, basePayload.expiresAt + 1_000_000); + expect(result).toEqual({ ok: false, reason: 'expired' }); + }); + + it('accepts when now is just before expiresAt', async () => { + const token = await signClaimToken(basePayload, SECRET); + const result = await verifyClaimToken(token, SECRET, basePayload.expiresAt - 1); + expect(result.ok).toBe(true); + }); +}); + +// --------------------------------------------------------------------------- +// Cookie string builders +// --------------------------------------------------------------------------- + +describe('trial cookies — cookie string builders', () => { + it('builds a fingerprint cookie with secure, httponly, samesite=lax', () => { + const c = buildFingerprintCookie('uuid.sig'); + expect(c).toContain('sam_trial_fingerprint=uuid.sig'); + expect(c).toContain('Path=/'); + expect(c).toContain('HttpOnly'); + expect(c).toContain('SameSite=Lax'); + expect(c).toContain('Secure'); + expect(c).toContain(`Max-Age=${DEFAULT_TRIAL_FINGERPRINT_TTL_SEC}`); + }); + + it('supports suppressing Secure for local dev', () => { + const c = buildFingerprintCookie('x.y', { secure: false }); + expect(c).not.toContain('Secure'); + }); + + it('adds Domain when provided', () => { + const c = buildFingerprintCookie('x.y', { domain: 'example.com' }); + expect(c).toContain('Domain=example.com'); + }); + + it('builds a claim cookie with 48h default Max-Age', () => { + const c = buildClaimCookie('tok'); + expect(c).toContain('sam_trial_claim=tok'); + expect(c).toContain(`Max-Age=${Math.floor(DEFAULT_TRIAL_CLAIM_TTL_MS / 1000)}`); + expect(c).toContain('HttpOnly'); + }); + + it('clearClaimCookie sets Max-Age=0', () => { + const c 
= clearClaimCookie(); + expect(c).toContain('sam_trial_claim=;'); + expect(c).toContain('Max-Age=0'); + expect(c).toContain('Secure'); + }); +}); + +// --------------------------------------------------------------------------- +// Cookie domain consistency (regression test for CRITICAL cookie domain bug) +// --------------------------------------------------------------------------- +// The browser treats cookies with different Domain attributes as distinct. +// If create.ts sets `Domain=.example.com` but claim.ts clears without a Domain, +// the original cookie is never deleted — enabling replay attacks. This test +// asserts the invariant that all three cookie call sites produce matching Domain +// attributes. +// See: clearClaimCookie domain mismatch fix in claim.ts / oauth-hook.ts. + +describe('trial cookies — domain consistency invariant', () => { + const domain = '.example.com'; + + it('buildClaimCookie includes Domain when provided', () => { + const c = buildClaimCookie('tok', { domain }); + expect(c).toContain(`Domain=${domain}`); + }); + + it('clearClaimCookie includes the same Domain when provided', () => { + const c = clearClaimCookie({ domain }); + expect(c).toContain(`Domain=${domain}`); + }); + + it('buildFingerprintCookie includes Domain when provided', () => { + const c = buildFingerprintCookie('uuid.sig', { domain }); + expect(c).toContain(`Domain=${domain}`); + }); + + it('clearClaimCookie WITHOUT domain does NOT match a domain-scoped cookie', () => { + // This is the exact bug: if you set with Domain=.example.com but clear + // without Domain, the browser sees them as different cookies. 
+ const setCookie = buildClaimCookie('tok', { domain }); + const clearCookie = clearClaimCookie(); // no domain — the old buggy path + + // Extract Domain= from each + const setDomain = setCookie.match(/Domain=[^;]+/)?.[0]; + const clearDomain = clearCookie.match(/Domain=[^;]+/)?.[0]; + + expect(setDomain).toBe(`Domain=${domain}`); + expect(clearDomain).toBeUndefined(); // no Domain in the clear cookie + // Therefore these are DIFFERENT cookies from the browser's perspective. + // The fix: always pass the same domain to clearClaimCookie. + expect(setDomain).not.toBe(clearDomain); + }); + + it('clearClaimCookie WITH domain matches the set cookie domain', () => { + // This is the fixed path: both set and clear use the same domain. + const setCookie = buildClaimCookie('tok', { domain }); + const clearCookie = clearClaimCookie({ domain }); + + const setDomain = setCookie.match(/Domain=[^;]+/)?.[0]; + const clearDomain = clearCookie.match(/Domain=[^;]+/)?.[0]; + + expect(setDomain).toBe(`Domain=${domain}`); + expect(clearDomain).toBe(`Domain=${domain}`); + // Same Domain → browser treats them as the same cookie → clear works. + }); + + it('all three builders omit Domain when no domain is provided', () => { + // Without a domain, all cookies are host-only — consistent by omission. + const claim = buildClaimCookie('tok'); + const clear = clearClaimCookie(); + const fp = buildFingerprintCookie('uuid.sig'); + + expect(claim).not.toContain('Domain='); + expect(clear).not.toContain('Domain='); + expect(fp).not.toContain('Domain='); + }); +}); diff --git a/apps/api/tests/unit/services/trial-github-knowledge.test.ts b/apps/api/tests/unit/services/trial-github-knowledge.test.ts new file mode 100644 index 000000000..3981b8887 --- /dev/null +++ b/apps/api/tests/unit/services/trial-github-knowledge.test.ts @@ -0,0 +1,188 @@ +/** + * Behavioral tests for emitGithubKnowledgeEvents. 
+ * + * Covers: + * - Happy path: emits events for description, language, stars, topics, + * license, language breakdown, and README first paragraph. + * - Per-request timeout: a hanging response aborts without blocking others. + * - Max-events cap: never emits more than `TRIAL_KNOWLEDGE_MAX_EVENTS`. + * - Errors are swallowed: a total-network-failure produces 0 events without + * throwing. + * - env.TRIAL_EVENT_BUS errors don't bubble (bridge already covers this + * but the helper's emit wrapper is the last line of defense). + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; + +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +const { emitTrialEventMock } = vi.hoisted(() => ({ + emitTrialEventMock: vi.fn(async () => {}), +})); +vi.mock('../../../src/services/trial/trial-runner', () => ({ + emitTrialEvent: emitTrialEventMock, +})); + +const { emitGithubKnowledgeEvents } = await import( + '../../../src/services/trial/github-knowledge' +); + +function makeEnv(overrides: Partial = {}): Env { + return { + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'stub-id'), + get: vi.fn(() => ({ fetch: vi.fn(async () => new Response('ok')) })), + }, + ...overrides, + } as unknown as Env; +} + +function jsonResp(body: unknown): Response { + return new Response(JSON.stringify(body), { + status: 200, + headers: { 'content-type': 'application/json' }, + }); +} + +function textResp(body: string): Response { + return new Response(body, { status: 200 }); +} + +/** + * Build a fetch mock that returns different responses per URL pattern. + * Matches on path suffix to support both `/repos/:o/:n` and `/repos/:o/:n/languages`. 
+ */ +function routedFetch(routes: { + repo?: () => Response | Promise; + languages?: () => Response | Promise; + readme?: () => Response | Promise; +}): ReturnType { + return vi.fn(async (url: string) => { + if (url.endsWith('/languages') && routes.languages) return routes.languages(); + if (url.endsWith('/readme') && routes.readme) return routes.readme(); + if (routes.repo) return routes.repo(); + return new Response('', { status: 404 }); + }); +} + +describe('emitGithubKnowledgeEvents', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('emits description, language, stars, topics, license, language breakdown, and README', async () => { + const fetchFn = routedFetch({ + repo: () => + jsonResp({ + description: 'A delightful CLI for widgets', + language: 'TypeScript', + stargazers_count: 1234, + topics: ['cli', 'widgets', 'typescript'], + license: { spdx_id: 'MIT' }, + }), + languages: () => + jsonResp({ TypeScript: 9000, JavaScript: 500, CSS: 100 }), + readme: () => + textResp( + '# Widgets\n\nA delightful CLI for widgets — does amazing things when you need them most.\n\n## Installation\n\nRun npm install.' 
+ ), + }); + + await emitGithubKnowledgeEvents( + makeEnv(), + 'trial_abc', + { owner: 'alice', name: 'widgets' }, + { fetchFn: fetchFn as unknown as typeof fetch } + ); + + const calls = emitTrialEventMock.mock.calls; + const observations = calls.map( + (c) => (c[2] as { observation: string }).observation + ); + expect(observations.some((o) => o.includes('Description:'))).toBe(true); + expect(observations.some((o) => o.includes('Primary language: TypeScript'))).toBe( + true + ); + expect(observations.some((o) => o.includes('Stars: 1234'))).toBe(true); + expect(observations.some((o) => o.includes('Topics:'))).toBe(true); + expect(observations.some((o) => o.includes('License: MIT'))).toBe(true); + expect(observations.some((o) => o.includes('Languages (by bytes):'))).toBe( + true + ); + expect(observations.some((o) => o.startsWith('README:'))).toBe(true); + }); + + it('caps total events at TRIAL_KNOWLEDGE_MAX_EVENTS', async () => { + const fetchFn = routedFetch({ + repo: () => + jsonResp({ + description: 'desc', + language: 'Go', + stargazers_count: 1, + topics: ['a', 'b'], + license: { spdx_id: 'Apache-2.0' }, + }), + languages: () => jsonResp({ Go: 100, Shell: 50 }), + readme: () => textResp('A decent description paragraph that exceeds twenty characters by enough.'), + }); + // Cap at 2 — must NOT emit more than 2 events even though the probes + // would naturally produce ~7. 
+ const env = makeEnv({ TRIAL_KNOWLEDGE_MAX_EVENTS: '2' }); + await emitGithubKnowledgeEvents( + env, + 'trial_xyz', + { owner: 'a', name: 'b' }, + { fetchFn: fetchFn as unknown as typeof fetch } + ); + expect(emitTrialEventMock).toHaveBeenCalledTimes(2); + }); + + it('swallows total-network-failure and emits 0 events', async () => { + const fetchFn = vi.fn().mockRejectedValue(new Error('ECONNREFUSED')); + await expect( + emitGithubKnowledgeEvents( + makeEnv(), + 'trial_err', + { owner: 'a', name: 'b' }, + { fetchFn: fetchFn as unknown as typeof fetch } + ) + ).resolves.toBeUndefined(); + expect(emitTrialEventMock).not.toHaveBeenCalled(); + }); + + it('survives a non-2xx repo metadata response without throwing', async () => { + const fetchFn = vi + .fn() + .mockResolvedValue(new Response('rate limited', { status: 403 })); + await expect( + emitGithubKnowledgeEvents( + makeEnv(), + 'trial_403', + { owner: 'a', name: 'b' }, + { fetchFn: fetchFn as unknown as typeof fetch } + ) + ).resolves.toBeUndefined(); + // No observations emitted when all probes fail. + expect(emitTrialEventMock).not.toHaveBeenCalled(); + }); + + it('does not throw when emitTrialEvent rejects (last line of defense)', async () => { + emitTrialEventMock.mockRejectedValueOnce(new Error('bus closed')); + const fetchFn = routedFetch({ + repo: () => jsonResp({ description: 'hello' }), + languages: () => jsonResp({}), + readme: () => new Response('', { status: 404 }), + }); + await expect( + emitGithubKnowledgeEvents( + makeEnv(), + 'trial_ok', + { owner: 'a', name: 'b' }, + { fetchFn: fetchFn as unknown as typeof fetch } + ) + ).resolves.toBeUndefined(); + }); +}); diff --git a/apps/api/tests/unit/services/trial-helpers.test.ts b/apps/api/tests/unit/services/trial-helpers.test.ts new file mode 100644 index 000000000..eb0e0a88c --- /dev/null +++ b/apps/api/tests/unit/services/trial-helpers.test.ts @@ -0,0 +1,221 @@ +/** + * Unit tests for `services/trial/helpers.ts`. 
+ * + * Covers: + * - env-var resolvers (missing / empty / invalid → DEFAULT_*) + * - currentMonthKey / nextMonthResetDate (UTC math across month/year borders) + * - shiftMonthKey (positive and negative deltas, cross-year) + * - parseGithubRepoUrl (happy path, .git suffix, trailing slash, rejects SSH/http/non-github) + */ +import { describe, expect, it } from 'vitest'; + +import type { Env } from '../../../src/env'; +import { + currentMonthKey, + DEFAULT_TRIAL_COUNTER_KEEP_MONTHS, + DEFAULT_TRIAL_GITHUB_TIMEOUT_MS, + DEFAULT_TRIAL_MONTHLY_CAP, + DEFAULT_TRIAL_REPO_MAX_KB, + DEFAULT_TRIAL_WAITLIST_PURGE_DAYS, + DEFAULT_TRIAL_WORKSPACE_TTL_MS, + nextMonthResetDate, + parseGithubRepoUrl, + resolveCounterKeepMonths, + resolveGithubTimeoutMs, + resolveMonthlyCap, + resolveRepoMaxKb, + resolveWaitlistPurgeDays, + resolveWorkspaceTtlMs, + shiftMonthKey, +} from '../../../src/services/trial/helpers'; + +function envWith(overrides: Partial = {}): Env { + return { ...overrides } as Env; +} + +describe('trial helpers — env-var resolvers', () => { + it('resolveMonthlyCap: missing → DEFAULT', () => { + expect(resolveMonthlyCap(envWith())).toBe(DEFAULT_TRIAL_MONTHLY_CAP); + }); + + it('resolveMonthlyCap: empty string → DEFAULT', () => { + expect(resolveMonthlyCap(envWith({ TRIAL_MONTHLY_CAP: '' }))).toBe( + DEFAULT_TRIAL_MONTHLY_CAP + ); + }); + + it('resolveMonthlyCap: valid numeric string → parsed', () => { + expect(resolveMonthlyCap(envWith({ TRIAL_MONTHLY_CAP: '25' }))).toBe(25); + }); + + it('resolveMonthlyCap: 0 is respected (disables cap)', () => { + expect(resolveMonthlyCap(envWith({ TRIAL_MONTHLY_CAP: '0' }))).toBe(0); + }); + + it('resolveMonthlyCap: junk → DEFAULT', () => { + expect(resolveMonthlyCap(envWith({ TRIAL_MONTHLY_CAP: 'abc' }))).toBe( + DEFAULT_TRIAL_MONTHLY_CAP + ); + }); + + it('resolveWorkspaceTtlMs: 0 is REJECTED → DEFAULT (TTL must be > 0)', () => { + expect(resolveWorkspaceTtlMs(envWith({ TRIAL_WORKSPACE_TTL_MS: '0' }))).toBe( + 
DEFAULT_TRIAL_WORKSPACE_TTL_MS + ); + }); + + it('resolveWorkspaceTtlMs: positive override is used', () => { + expect( + resolveWorkspaceTtlMs(envWith({ TRIAL_WORKSPACE_TTL_MS: '60000' })) + ).toBe(60000); + }); + + it('resolveRepoMaxKb: DEFAULT when missing', () => { + expect(resolveRepoMaxKb(envWith())).toBe(DEFAULT_TRIAL_REPO_MAX_KB); + }); + + it('resolveRepoMaxKb: override when valid', () => { + expect(resolveRepoMaxKb(envWith({ TRIAL_REPO_MAX_KB: '1024' }))).toBe(1024); + }); + + it('resolveGithubTimeoutMs: DEFAULT when missing', () => { + expect(resolveGithubTimeoutMs(envWith())).toBe( + DEFAULT_TRIAL_GITHUB_TIMEOUT_MS + ); + }); + + it('resolveCounterKeepMonths: floors fractional overrides', () => { + expect( + resolveCounterKeepMonths(envWith({ TRIAL_COUNTER_KEEP_MONTHS: '3.9' })) + ).toBe(3); + }); + + it('resolveCounterKeepMonths: DEFAULT when missing', () => { + expect(resolveCounterKeepMonths(envWith())).toBe( + DEFAULT_TRIAL_COUNTER_KEEP_MONTHS + ); + }); + + it('resolveWaitlistPurgeDays: DEFAULT when missing', () => { + expect(resolveWaitlistPurgeDays(envWith())).toBe( + DEFAULT_TRIAL_WAITLIST_PURGE_DAYS + ); + }); + + it('resolveWaitlistPurgeDays: negative or zero → DEFAULT', () => { + expect( + resolveWaitlistPurgeDays(envWith({ TRIAL_WAITLIST_PURGE_DAYS: '0' })) + ).toBe(DEFAULT_TRIAL_WAITLIST_PURGE_DAYS); + expect( + resolveWaitlistPurgeDays(envWith({ TRIAL_WAITLIST_PURGE_DAYS: '-1' })) + ).toBe(DEFAULT_TRIAL_WAITLIST_PURGE_DAYS); + }); +}); + +describe('trial helpers — month key math (UTC)', () => { + it('currentMonthKey returns YYYY-MM for a given timestamp', () => { + // 2026-04-18T12:00:00Z + const t = Date.UTC(2026, 3, 18, 12, 0, 0); + expect(currentMonthKey(t)).toBe('2026-04'); + }); + + it('currentMonthKey zero-pads single-digit months', () => { + const t = Date.UTC(2026, 0, 1, 0, 0, 0); // Jan + expect(currentMonthKey(t)).toBe('2026-01'); + }); + + it('currentMonthKey uses UTC (not local timezone)', () => { + // 2026-04-30T23:30:00 UTC is 
still April. + const t = Date.UTC(2026, 3, 30, 23, 30, 0); + expect(currentMonthKey(t)).toBe('2026-04'); + }); + + it('nextMonthResetDate rolls into next month', () => { + const t = Date.UTC(2026, 3, 18, 12, 0, 0); + expect(nextMonthResetDate(t)).toBe('2026-05-01'); + }); + + it('nextMonthResetDate rolls into next year on December', () => { + const t = Date.UTC(2026, 11, 25, 0, 0, 0); // 2026-12-25 + expect(nextMonthResetDate(t)).toBe('2027-01-01'); + }); + + it('shiftMonthKey handles positive delta', () => { + expect(shiftMonthKey('2026-04', 2)).toBe('2026-06'); + }); + + it('shiftMonthKey handles negative delta across year boundary', () => { + expect(shiftMonthKey('2026-02', -3)).toBe('2025-11'); + }); + + it('shiftMonthKey delta=0 is identity', () => { + expect(shiftMonthKey('2026-04', 0)).toBe('2026-04'); + }); + + it('shiftMonthKey rejects malformed keys', () => { + expect(() => shiftMonthKey('2026/04', 0)).toThrow(); + expect(() => shiftMonthKey('abc', 0)).toThrow(); + }); +}); + +describe('trial helpers — parseGithubRepoUrl', () => { + it('parses a canonical URL', () => { + expect(parseGithubRepoUrl('https://github.com/owner/repo')).toEqual({ + owner: 'owner', + name: 'repo', + canonical: 'https://github.com/owner/repo', + }); + }); + + it('strips .git suffix', () => { + expect(parseGithubRepoUrl('https://github.com/owner/repo.git')).toEqual({ + owner: 'owner', + name: 'repo', + canonical: 'https://github.com/owner/repo', + }); + }); + + it('strips trailing slash', () => { + expect(parseGithubRepoUrl('https://github.com/owner/repo/')).toEqual({ + owner: 'owner', + name: 'repo', + canonical: 'https://github.com/owner/repo', + }); + }); + + it('handles names with dots and dashes', () => { + const parsed = parseGithubRepoUrl('https://github.com/owner/my.repo-name_v2'); + expect(parsed).toEqual({ + owner: 'owner', + name: 'my.repo-name_v2', + canonical: 'https://github.com/owner/my.repo-name_v2', + }); + }); + + it('rejects SSH URLs', () => { + 
expect(parseGithubRepoUrl('git@github.com:owner/repo.git')).toBeNull(); + }); + + it('rejects HTTP (non-HTTPS) URLs', () => { + expect(parseGithubRepoUrl('http://github.com/owner/repo')).toBeNull(); + }); + + it('rejects non-github hosts', () => { + expect(parseGithubRepoUrl('https://gitlab.com/owner/repo')).toBeNull(); + }); + + it('rejects paths deeper than owner/repo', () => { + expect( + parseGithubRepoUrl('https://github.com/owner/repo/tree/main') + ).toBeNull(); + }); + + it('rejects empty strings', () => { + expect(parseGithubRepoUrl('')).toBeNull(); + }); + + it('rejects owner with leading or trailing hyphen', () => { + expect(parseGithubRepoUrl('https://github.com/-owner/repo')).toBeNull(); + expect(parseGithubRepoUrl('https://github.com/owner-/repo')).toBeNull(); + }); +}); diff --git a/apps/api/tests/unit/services/trial-kill-switch.test.ts b/apps/api/tests/unit/services/trial-kill-switch.test.ts new file mode 100644 index 000000000..293ce2a75 --- /dev/null +++ b/apps/api/tests/unit/services/trial-kill-switch.test.ts @@ -0,0 +1,131 @@ +/** + * Unit tests for the trial kill-switch. + * + * Covers: + * - Default (no KV value) → disabled + * - KV returns "true" → enabled + * - KV returns anything else → disabled + * - Cache short-circuits subsequent reads within TTL + * - Cache entry expires after TTL and KV is re-read + * - KV read errors → fails CLOSED (disabled) + * - Custom KV key via env.TRIALS_ENABLED_KV_KEY + * - Custom TTL via env.TRIAL_KILL_SWITCH_CACHE_MS + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +// Silence the structured logger so KV-failure tests don't pollute output. +// `vi.hoisted` runs before `vi.mock` calls are hoisted so we can safely reference +// the spy inside the mock factory AND the test bodies. 
+const { mockLogError } = vi.hoisted(() => ({ mockLogError: vi.fn() })); +vi.mock('../../../src/lib/logger', () => ({ + log: { + info: vi.fn(), + warn: vi.fn(), + error: mockLogError, + }, +})); + +import type { Env } from '../../../src/env'; +import { + __resetKillSwitchCacheForTest, + isTrialsEnabled, +} from '../../../src/services/trial/kill-switch'; + +function makeEnv( + kv: { get: (key: string) => Promise }, + overrides: Partial = {} +): Env { + return { + KV: kv as unknown as KVNamespace, + ...overrides, + } as unknown as Env; +} + +describe('trial kill-switch', () => { + beforeEach(() => { + __resetKillSwitchCacheForTest(); + }); + + it('defaults to disabled when the KV key is unset (null)', async () => { + const get = vi.fn().mockResolvedValue(null); + const env = makeEnv({ get }); + expect(await isTrialsEnabled(env, 0)).toBe(false); + expect(get).toHaveBeenCalledWith('trials:enabled'); + }); + + it('returns enabled when KV returns the exact string "true"', async () => { + const get = vi.fn().mockResolvedValue('true'); + expect(await isTrialsEnabled(makeEnv({ get }), 0)).toBe(true); + }); + + it('returns disabled when KV returns any other string', async () => { + for (const value of ['false', '1', 'TRUE', 'enabled', '']) { + __resetKillSwitchCacheForTest(); + const get = vi.fn().mockResolvedValue(value); + expect(await isTrialsEnabled(makeEnv({ get }), 0)).toBe(false); + } + }); + + it('caches the value within TTL and skips subsequent KV reads', async () => { + const get = vi.fn().mockResolvedValue('true'); + const env = makeEnv({ get }); + + expect(await isTrialsEnabled(env, 0)).toBe(true); + // Well inside the default 30s TTL. 
+ expect(await isTrialsEnabled(env, 10_000)).toBe(true); + expect(await isTrialsEnabled(env, 29_999)).toBe(true); + expect(get).toHaveBeenCalledTimes(1); + }); + + it('re-reads KV when the cache entry has expired', async () => { + const get = vi.fn().mockResolvedValueOnce('true').mockResolvedValueOnce('false'); + const env = makeEnv({ get }); + + expect(await isTrialsEnabled(env, 0)).toBe(true); + // Past the default 30s TTL. + expect(await isTrialsEnabled(env, 31_000)).toBe(false); + expect(get).toHaveBeenCalledTimes(2); + }); + + it('fails closed (disabled) when KV.get throws', async () => { + mockLogError.mockClear(); + const get = vi.fn().mockRejectedValue(new Error('kv-outage')); + expect(await isTrialsEnabled(makeEnv({ get }), 0)).toBe(false); + expect(mockLogError).toHaveBeenCalledWith( + 'trial.kill_switch.kv_read_failed', + expect.objectContaining({ error: 'kv-outage' }) + ); + }); + + it('caches the fail-closed value so a KV outage does not DoS the hot path', async () => { + const get = vi.fn().mockRejectedValue(new Error('kv-outage')); + const env = makeEnv({ get }); + expect(await isTrialsEnabled(env, 0)).toBe(false); + expect(await isTrialsEnabled(env, 1_000)).toBe(false); + expect(get).toHaveBeenCalledTimes(1); + }); + + it('honours a custom KV key from env.TRIALS_ENABLED_KV_KEY', async () => { + const get = vi.fn().mockResolvedValue('true'); + const env = makeEnv({ get }, { + TRIALS_ENABLED_KV_KEY: 'custom:trials:flag', + } as unknown as Partial); + + expect(await isTrialsEnabled(env, 0)).toBe(true); + expect(get).toHaveBeenCalledWith('custom:trials:flag'); + }); + + it('honours a custom TTL from env.TRIAL_KILL_SWITCH_CACHE_MS', async () => { + const get = vi.fn().mockResolvedValueOnce('true').mockResolvedValueOnce('false'); + const env = makeEnv({ get }, { + TRIAL_KILL_SWITCH_CACHE_MS: '1000', + } as unknown as Partial); + + expect(await isTrialsEnabled(env, 0)).toBe(true); + // Within 1s TTL → cached + expect(await isTrialsEnabled(env, 
500)).toBe(true); + // Past 1s TTL → re-read + expect(await isTrialsEnabled(env, 1_500)).toBe(false); + expect(get).toHaveBeenCalledTimes(2); + }); +}); diff --git a/apps/api/tests/unit/services/trial-oauth-hook.test.ts b/apps/api/tests/unit/services/trial-oauth-hook.test.ts new file mode 100644 index 000000000..0141e2fb2 --- /dev/null +++ b/apps/api/tests/unit/services/trial-oauth-hook.test.ts @@ -0,0 +1,261 @@ +/** + * Unit tests for the OAuth-callback trial-claim hook. + * + * Covers: + * - Non-callback URL → response returned untouched + * - Non-2xx/3xx response → untouched + * - Missing fingerprint cookie → untouched + * - Fingerprint cookie with bad signature → untouched + * - No trial record matches fingerprint → untouched + * - Claimed trial → untouched + * - Expired trial → untouched + * - Secret unset (TRIAL_CLAIM_TOKEN_SECRET missing) → untouched + * - Happy path → sets claim cookie, rewrites Location to + * https://app.${BASE_DOMAIN}/try/${trialId}?claim=1 + */ +import { describe, expect, it, vi } from 'vitest'; + +// Silence the logger +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +// Mock the trial-store read by fingerprint +const { readTrialByFingerprintMock } = vi.hoisted(() => ({ + readTrialByFingerprintMock: vi.fn(), +})); +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrialByFingerprint: readTrialByFingerprintMock, +})); + +import type { Env } from '../../../src/env'; +import { + buildFingerprintCookie, + signClaimToken, + signFingerprint, + verifyClaimToken, +} from '../../../src/services/trial/cookies'; +import { maybeAttachTrialClaimCookie } from '../../../src/services/trial/oauth-hook'; + +const SECRET = 'test-secret-at-least-32-bytes-long-for-hmac-verification'; + +function envWith(overrides: Partial = {}): Env { + return { + TRIAL_CLAIM_TOKEN_SECRET: SECRET, + BASE_DOMAIN: 'example.com', + ...overrides, + } as unknown as Env; +} + +function 
makeRedirectResponse(location = 'https://app.example.com/'): Response { + return new Response(null, { + status: 302, + headers: { Location: location }, + }); +} + +function makeRequestWithFingerprint( + path: string, + signedCookie: string | null +): Request { + const headers = new Headers(); + if (signedCookie) { + headers.set('cookie', `sam_trial_fingerprint=${encodeURIComponent(signedCookie)}`); + } + return new Request(`https://api.example.com${path}`, { headers }); +} + +describe('maybeAttachTrialClaimCookie — bail-out cases', () => { + it('returns response untouched when URL is not /callback/github', async () => { + const env = envWith(); + const req = new Request('https://api.example.com/api/auth/session'); + const resp = makeRedirectResponse(); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when status is 4xx', async () => { + const env = envWith(); + const req = new Request('https://api.example.com/api/auth/callback/github'); + const resp = new Response('bad', { status: 400 }); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when status is 5xx', async () => { + const env = envWith(); + const req = new Request('https://api.example.com/api/auth/callback/github'); + const resp = new Response('err', { status: 500 }); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when TRIAL_CLAIM_TOKEN_SECRET is unset', async () => { + const env = envWith({ TRIAL_CLAIM_TOKEN_SECRET: undefined } as Partial); + const req = makeRequestWithFingerprint('/api/auth/callback/github', 'fake.sig'); + const resp = makeRedirectResponse(); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when no fingerprint cookie present', async () => 
{ + const env = envWith(); + const req = makeRequestWithFingerprint('/api/auth/callback/github', null); + const resp = makeRedirectResponse(); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when fingerprint signature is invalid', async () => { + const env = envWith(); + const req = makeRequestWithFingerprint( + '/api/auth/callback/github', + 'bogus-uuid.not-a-real-signature' + ); + const resp = makeRedirectResponse(); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when no trial record matches fingerprint', async () => { + const env = envWith(); + const signed = await signFingerprint('fp-uuid-1', SECRET); + const req = makeRequestWithFingerprint('/api/auth/callback/github', signed); + const resp = makeRedirectResponse(); + + readTrialByFingerprintMock.mockResolvedValueOnce(null); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when trial is already claimed', async () => { + const env = envWith(); + const signed = await signFingerprint('fp-uuid-2', SECRET); + const req = makeRequestWithFingerprint('/api/auth/callback/github', signed); + const resp = makeRedirectResponse(); + + readTrialByFingerprintMock.mockResolvedValueOnce({ + trialId: 'trial_x', + projectId: 'proj_x', + fingerprint: 'fp-uuid-2', + claimed: true, + expiresAt: Date.now() + 60_000, + }); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); + + it('returns response untouched when trial has expired', async () => { + const env = envWith(); + const signed = await signFingerprint('fp-uuid-3', SECRET); + const req = makeRequestWithFingerprint('/api/auth/callback/github', signed); + const resp = makeRedirectResponse(); + + readTrialByFingerprintMock.mockResolvedValueOnce({ + trialId: 
'trial_y', + projectId: 'proj_y', + fingerprint: 'fp-uuid-3', + claimed: false, + expiresAt: Date.now() - 1000, // already expired + }); + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).toBe(resp); + }); +}); + +describe('maybeAttachTrialClaimCookie — happy path', () => { + it('sets claim cookie and rewrites Location to app.${BASE_DOMAIN}/try/:trialId?claim=1', async () => { + const env = envWith({ BASE_DOMAIN: 'sammy.party' } as Partial); + const signed = await signFingerprint('fp-uuid-99', SECRET); + const req = makeRequestWithFingerprint('/api/auth/callback/github', signed); + const resp = makeRedirectResponse('https://api.sammy.party/callback-result'); + + readTrialByFingerprintMock.mockResolvedValueOnce({ + trialId: 'trial_zzz', + projectId: 'proj_zzz', + fingerprint: 'fp-uuid-99', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result).not.toBe(resp); + expect(result.status).toBe(302); + expect(result.headers.get('Location')).toBe( + 'https://app.sammy.party/try/trial_zzz?claim=1' + ); + + // Set-Cookie must contain a claim cookie (sam_trial_claim=...) + const setCookie = result.headers.get('Set-Cookie'); + expect(setCookie).toContain('sam_trial_claim='); + + // Extract token from cookie to verify it's a valid signed claim + const match = /sam_trial_claim=([^;]+)/.exec(setCookie ?? 
''); + expect(match).not.toBeNull(); + const token = decodeURIComponent(match![1]!); + const verified = await verifyClaimToken(token, SECRET); + expect(verified.ok).toBe(true); + if (verified.ok) { + expect(verified.payload.trialId).toBe('trial_zzz'); + expect(verified.payload.projectId).toBe('proj_zzz'); + } + }); + + it('preserves other response headers when rewriting the response', async () => { + const env = envWith({ BASE_DOMAIN: 'sammy.party' } as Partial); + const signed = await signFingerprint('fp-uuid-100', SECRET); + const req = makeRequestWithFingerprint('/api/auth/callback/github', signed); + const resp = new Response(null, { + status: 302, + headers: { + Location: '/somewhere', + 'X-Custom-Header': 'preserved', + }, + }); + + readTrialByFingerprintMock.mockResolvedValueOnce({ + trialId: 'trial_h', + projectId: 'proj_h', + fingerprint: 'fp-uuid-100', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result.headers.get('X-Custom-Header')).toBe('preserved'); + }); +}); + +// Sanity: confirm `signFingerprint + buildFingerprintCookie` produce a value +// that survives our test helper's `decodeURIComponent` trip through a cookie. +describe('maybeAttachTrialClaimCookie — cookie parsing', () => { + it('parses a full Set-Cookie-style fingerprint header correctly', async () => { + const env = envWith(); + const signed = await signFingerprint('fp-uuid-cookie', SECRET); + const cookieHeader = buildFingerprintCookie(signed, { secure: false }); + + // Build a request whose Cookie header is the just-built Set-Cookie value's + // name=value portion only (simulating browser behavior). 
+ const value = cookieHeader.split(';')[0]!; // "sam_trial_fingerprint=" + const req = new Request('https://api.example.com/api/auth/callback/github', { + headers: { cookie: value }, + }); + const resp = makeRedirectResponse('https://api.example.com/'); + + readTrialByFingerprintMock.mockResolvedValueOnce({ + trialId: 'trial_cookie', + projectId: 'proj_cookie', + fingerprint: 'fp-uuid-cookie', + claimed: false, + expiresAt: Date.now() + 60_000, + }); + + const result = await maybeAttachTrialClaimCookie(env, req, resp); + expect(result.headers.get('Location')).toContain('/try/trial_cookie?claim=1'); + // Ensure the trial-store was queried with the DECODED UUID + expect(readTrialByFingerprintMock).toHaveBeenCalledWith(env, 'fp-uuid-cookie'); + }); +}); + +// Also sanity-check the signClaimToken import so lint doesn't complain. +signClaimToken; diff --git a/apps/api/tests/unit/services/trial-runner.test.ts b/apps/api/tests/unit/services/trial-runner.test.ts new file mode 100644 index 000000000..b302e6fe3 --- /dev/null +++ b/apps/api/tests/unit/services/trial-runner.test.ts @@ -0,0 +1,300 @@ +/** + * Unit tests for trial-runner. 
+ * + * Covers: + * - resolveTrialRunnerConfig — mode resolution (staging/production), + * agentType default per mode, provider override, model override + * - startDiscoveryAgent — creates chat + ACP session with discovery prompt + * - Production + anthropic provider requires ANTHROPIC_API_KEY_TRIAL + * - emitTrialEvent — appends to TrialEventBus DO stub, no-ops on error + * - emitTrialEventForProject — looks up trial by project then appends + */ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +// Mock project-data service functions used by startDiscoveryAgent +const { createSessionMock, createAcpSessionMock } = vi.hoisted(() => ({ + createSessionMock: vi.fn(), + createAcpSessionMock: vi.fn(), +})); +vi.mock('../../../src/services/project-data', () => ({ + createSession: createSessionMock, + createAcpSession: createAcpSessionMock, +})); + +// Mock trial-store's readTrialByProject used by emitTrialEventForProject +const { readTrialByProjectMock, readTrialMock } = vi.hoisted(() => ({ + readTrialByProjectMock: vi.fn(), + readTrialMock: vi.fn(), +})); +vi.mock('../../../src/services/trial/trial-store', () => ({ + readTrialByProject: readTrialByProjectMock, + readTrial: readTrialMock, +})); + +// Silence logs +vi.mock('../../../src/lib/logger', () => ({ + log: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +import type { Env } from '../../../src/env'; +import { + emitTrialEvent, + emitTrialEventForProject, + resolveTrialRunnerConfig, + startDiscoveryAgent, +} from '../../../src/services/trial/trial-runner'; + +function envBase(overrides: Partial = {}): Env { + return { + ENVIRONMENT: undefined, + ...overrides, + } as unknown as Env; +} + +describe('trial-runner — resolveTrialRunnerConfig', () => { + it('defaults to staging mode when ENVIRONMENT is unset', () => { + const cfg = resolveTrialRunnerConfig(envBase()); + expect(cfg.mode).toBe('staging'); + expect(cfg.agentType).toBe('opencode'); + expect(cfg.provider).toBe('workers-ai'); + }); + + 
it('detects production mode from ENVIRONMENT=production', () => { + const cfg = resolveTrialRunnerConfig(envBase({ ENVIRONMENT: 'production' } as Partial)); + expect(cfg.mode).toBe('production'); + expect(cfg.agentType).toBe('claude-code'); + expect(cfg.provider).toBe('anthropic'); + }); + + it('detects production mode from ENVIRONMENT=prod', () => { + const cfg = resolveTrialRunnerConfig(envBase({ ENVIRONMENT: 'prod' } as Partial)); + expect(cfg.mode).toBe('production'); + }); + + it('honours TRIAL_AGENT_TYPE_STAGING override', () => { + const cfg = resolveTrialRunnerConfig( + envBase({ TRIAL_AGENT_TYPE_STAGING: 'codex' } as Partial) + ); + expect(cfg.agentType).toBe('codex'); + }); + + it('honours TRIAL_AGENT_TYPE_PRODUCTION override', () => { + const cfg = resolveTrialRunnerConfig( + envBase({ + ENVIRONMENT: 'production', + TRIAL_AGENT_TYPE_PRODUCTION: 'gemini', + } as Partial) + ); + expect(cfg.agentType).toBe('gemini'); + }); + + it('honours TRIAL_LLM_PROVIDER override', () => { + const cfg = resolveTrialRunnerConfig( + envBase({ TRIAL_LLM_PROVIDER: 'anthropic' } as Partial) + ); + expect(cfg.provider).toBe('anthropic'); + }); + + it('honours TRIAL_MODEL override', () => { + const cfg = resolveTrialRunnerConfig( + envBase({ + ENVIRONMENT: 'production', + TRIAL_MODEL: 'claude-opus-4-6', + } as Partial) + ); + expect(cfg.model).toBe('claude-opus-4-6'); + }); + + it('falls back to workers-ai model default when provider=workers-ai', () => { + const cfg = resolveTrialRunnerConfig(envBase()); + expect(cfg.model).toContain('llama'); + }); + + it('falls back to anthropic model default when provider=anthropic', () => { + const cfg = resolveTrialRunnerConfig( + envBase({ ENVIRONMENT: 'production' } as Partial) + ); + expect(cfg.model).toContain('claude'); + }); + + it('ignores invalid provider override and uses mode default', () => { + const cfg = resolveTrialRunnerConfig( + envBase({ TRIAL_LLM_PROVIDER: 'garbage-provider' } as Partial) + ); + 
expect(cfg.provider).toBe('workers-ai'); + }); +}); + +describe('trial-runner — startDiscoveryAgent', () => { + beforeEach(() => { + createSessionMock.mockReset(); + createAcpSessionMock.mockReset(); + createSessionMock.mockResolvedValue('chat_session_1'); + createAcpSessionMock.mockResolvedValue({ id: 'acp_session_1' }); + }); + + it('creates chat + ACP session with discovery prompt and returns resolved config', async () => { + const env = envBase(); + const result = await startDiscoveryAgent(env, { + projectId: 'proj_1', + workspaceId: 'ws_1', + sessionTopic: 'acme/repo', + }); + + expect(createSessionMock).toHaveBeenCalledWith( + env, + 'proj_1', + 'ws_1', + 'acme/repo', + null + ); + expect(createAcpSessionMock).toHaveBeenCalledTimes(1); + const [, projectId, chatSessionId, prompt, agentType] = createAcpSessionMock.mock.calls[0]; + expect(projectId).toBe('proj_1'); + expect(chatSessionId).toBe('chat_session_1'); + expect(typeof prompt).toBe('string'); + expect(prompt.length).toBeGreaterThan(0); + expect(agentType).toBe('opencode'); // staging default + + expect(result.chatSessionId).toBe('chat_session_1'); + expect(result.acpSessionId).toBe('acp_session_1'); + expect(result.agentType).toBe('opencode'); + expect(result.provider).toBe('workers-ai'); + expect(result.promptVersion).toBeTruthy(); + }); + + it('defaults sessionTopic to "Exploring repository" when not provided', async () => { + await startDiscoveryAgent(envBase(), { + projectId: 'p', + workspaceId: 'w', + }); + expect(createSessionMock).toHaveBeenCalledWith( + expect.anything(), + 'p', + 'w', + 'Exploring repository', + null + ); + }); + + it('throws when production + anthropic provider but ANTHROPIC_API_KEY_TRIAL unset', async () => { + const env = envBase({ ENVIRONMENT: 'production' } as Partial); + await expect( + startDiscoveryAgent(env, { projectId: 'p', workspaceId: 'w' }) + ).rejects.toThrow(/ANTHROPIC_API_KEY_TRIAL/); + }); + + it('succeeds when production + anthropic + key is present', async 
() => { + const env = envBase({ + ENVIRONMENT: 'production', + ANTHROPIC_API_KEY_TRIAL: 'sk-test', + } as Partial); + const result = await startDiscoveryAgent(env, { + projectId: 'p', + workspaceId: 'w', + }); + expect(result.agentType).toBe('claude-code'); + expect(result.provider).toBe('anthropic'); + }); +}); + +describe('trial-runner — emitTrialEvent', () => { + it('appends to TrialEventBus DO via /append', async () => { + const fetchStub = vi.fn(async () => new Response(JSON.stringify({ cursor: 1 }), { + status: 200, + headers: { 'content-type': 'application/json' }, + })); + const stub = { fetch: fetchStub }; + const env = { + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => stub), + }, + } as unknown as Env; + + await emitTrialEvent(env, 'trial_abc', { + type: 'trial.progress', + message: 'cloning', + at: Date.now(), + } as never); + + expect(fetchStub).toHaveBeenCalledTimes(1); + const [url, init] = fetchStub.mock.calls[0]; + expect(String(url)).toContain('/append'); + expect((init as RequestInit).method).toBe('POST'); + }); + + it('silently swallows append errors (best-effort)', async () => { + const stub = { + fetch: vi.fn(async () => { + throw new Error('network'); + }), + }; + const env = { + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => stub), + }, + } as unknown as Env; + + await expect( + emitTrialEvent(env, 'trial_abc', { type: 'trial.progress', message: 'x', at: 0 } as never) + ).resolves.toBeUndefined(); + }); + + it('ignores 409 closed responses from DO', async () => { + const fetchStub = vi.fn(async () => new Response('closed', { status: 409 })); + const env = { + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => ({ fetch: fetchStub })), + }, + } as unknown as Env; + + await expect( + emitTrialEvent(env, 'trial_abc', { type: 'trial.ready', at: 0 } as never) + ).resolves.toBeUndefined(); + }); +}); + +describe('trial-runner — emitTrialEventForProject', () => { + 
beforeEach(() => { + readTrialByProjectMock.mockReset(); + }); + + it('no-ops when no trial record exists for the project', async () => { + readTrialByProjectMock.mockResolvedValue(null); + const env = {} as Env; + await emitTrialEventForProject(env, 'proj_nope', { + type: 'trial.progress', + message: 'x', + at: 0, + } as never); + // Nothing to assert — should just not throw. + }); + + it('emits to trialId when a record exists for the project', async () => { + readTrialByProjectMock.mockResolvedValue({ + trialId: 'trial_abc', + projectId: 'proj_1', + }); + const fetchStub = vi.fn(async () => + new Response(JSON.stringify({ cursor: 1 }), { status: 200 }) + ); + const env = { + TRIAL_EVENT_BUS: { + idFromName: vi.fn(() => 'do-id'), + get: vi.fn(() => ({ fetch: fetchStub })), + }, + } as unknown as Env; + + await emitTrialEventForProject(env, 'proj_1', { + type: 'trial.progress', + message: 'x', + at: 0, + } as never); + + expect(readTrialByProjectMock).toHaveBeenCalledWith(env, 'proj_1'); + expect(fetchStub).toHaveBeenCalled(); + }); +}); diff --git a/apps/api/tests/unit/services/trial-store.test.ts b/apps/api/tests/unit/services/trial-store.test.ts new file mode 100644 index 000000000..9cd5f2eed --- /dev/null +++ b/apps/api/tests/unit/services/trial-store.test.ts @@ -0,0 +1,162 @@ +/** + * Unit tests for the trial KV store. 
+ * + * Covers: + * - writeTrial persists record + two index keys with derived TTL + * - readTrial / readTrialByProject / readTrialByFingerprint round-trips + * - markTrialClaimed flips claimed=true idempotently + * - TTL floor of 60s when expiresAt is in the past + * - Malformed JSON in KV → readTrial returns null + */ +import { describe, expect, it, vi } from 'vitest'; + +import type { Env } from '../../../src/env'; +import { + markTrialClaimed, + readTrial, + readTrialByFingerprint, + readTrialByProject, + trialByFingerprintKey, + trialByProjectKey, + trialKey, + type TrialRecord, + writeTrial, +} from '../../../src/services/trial/trial-store'; + +function makeEnv() { + const store = new Map(); + const kv = { + get: vi.fn(async (key: string) => store.get(key)?.value ?? null), + put: vi.fn(async (key: string, value: string, opts?: { expirationTtl?: number }) => { + store.set(key, { value, ttl: opts?.expirationTtl }); + }), + delete: vi.fn(async (key: string) => { + store.delete(key); + }), + }; + return { + env: { KV: kv as unknown as KVNamespace } as unknown as Env, + kv, + store, + }; +} + +function baseRecord(overrides: Partial = {}): TrialRecord { + return { + trialId: 'trial_abc', + projectId: 'proj_xyz', + fingerprint: 'fp-uuid-1', + workspaceId: null, + repoUrl: 'https://github.com/foo/bar', + createdAt: 1_000_000, + expiresAt: Date.now() + 60 * 60 * 1000, // 1h from now + claimed: false, + ...overrides, + }; +} + +describe('trial-store — writeTrial', () => { + it('writes the record and two index keys', async () => { + const { env, kv, store } = makeEnv(); + const record = baseRecord(); + + await writeTrial(env, record); + + expect(kv.put).toHaveBeenCalledTimes(3); + expect(store.get(trialKey(record.trialId))!.value).toBe(JSON.stringify(record)); + expect(store.get(trialByProjectKey(record.projectId))!.value).toBe(record.trialId); + expect(store.get(trialByFingerprintKey(record.fingerprint))!.value).toBe(record.trialId); + }); + + it('sets TTL derived from 
expiresAt', async () => { + const { env, store } = makeEnv(); + const expiresAt = Date.now() + 10 * 60 * 1000; // 10 min + await writeTrial(env, baseRecord({ expiresAt })); + + const ttl = store.get(trialKey('trial_abc'))!.ttl; + expect(ttl).toBeGreaterThanOrEqual(60); + // Within ±5s of 10 min + expect(Math.abs((ttl ?? 0) - 600)).toBeLessThan(5); + }); + + it('floors TTL at 60 seconds when expiresAt is in the past', async () => { + const { env, store } = makeEnv(); + await writeTrial(env, baseRecord({ expiresAt: 0 })); + expect(store.get(trialKey('trial_abc'))!.ttl).toBe(60); + }); +}); + +describe('trial-store — reads', () => { + it('readTrial round-trips record via JSON parse', async () => { + const { env } = makeEnv(); + const record = baseRecord(); + await writeTrial(env, record); + + const recovered = await readTrial(env, record.trialId); + expect(recovered).toEqual(record); + }); + + it('readTrial returns null when key is absent', async () => { + const { env } = makeEnv(); + expect(await readTrial(env, 'missing')).toBeNull(); + }); + + it('readTrial returns null on malformed JSON', async () => { + const { env, kv } = makeEnv(); + (kv.get as ReturnType).mockResolvedValueOnce('{not json'); + expect(await readTrial(env, 'whatever')).toBeNull(); + }); + + it('readTrialByProject follows the index key', async () => { + const { env } = makeEnv(); + const record = baseRecord(); + await writeTrial(env, record); + + const recovered = await readTrialByProject(env, record.projectId); + expect(recovered?.trialId).toBe(record.trialId); + }); + + it('readTrialByProject returns null when index key missing', async () => { + const { env } = makeEnv(); + expect(await readTrialByProject(env, 'nope')).toBeNull(); + }); + + it('readTrialByFingerprint follows the index key', async () => { + const { env } = makeEnv(); + const record = baseRecord({ fingerprint: 'fp-other' }); + await writeTrial(env, record); + + const recovered = await readTrialByFingerprint(env, 'fp-other'); + 
expect(recovered?.trialId).toBe(record.trialId); + }); +}); + +describe('trial-store — markTrialClaimed', () => { + it('returns null when the trial does not exist', async () => { + const { env } = makeEnv(); + expect(await markTrialClaimed(env, 'nope')).toBeNull(); + }); + + it('flips claimed=true and rewrites the record', async () => { + const { env } = makeEnv(); + await writeTrial(env, baseRecord()); + + const updated = await markTrialClaimed(env, 'trial_abc'); + expect(updated?.claimed).toBe(true); + + const fromStore = await readTrial(env, 'trial_abc'); + expect(fromStore?.claimed).toBe(true); + }); + + it('is idempotent — calling twice leaves claimed=true', async () => { + const { env, kv } = makeEnv(); + await writeTrial(env, baseRecord()); + await markTrialClaimed(env, 'trial_abc'); + + const callsAfterFirst = (kv.put as ReturnType).mock.calls.length; + const updated = await markTrialClaimed(env, 'trial_abc'); + expect(updated?.claimed).toBe(true); + // Second call short-circuits — no additional writes. + expect((kv.put as ReturnType).mock.calls.length).toBe(callsAfterFirst); + }); +}); diff --git a/apps/api/tests/workers/trial-event-bus-sse.test.ts b/apps/api/tests/workers/trial-event-bus-sse.test.ts new file mode 100644 index 000000000..d251045ca --- /dev/null +++ b/apps/api/tests/workers/trial-event-bus-sse.test.ts @@ -0,0 +1,173 @@ +/** + * Capability + regression test: TrialEventBus DO → SSE endpoint wire-up. + * + * The "zero trial.* events on staging" incident (2026-04-19) was caused by + * the SSE endpoint emitting named events (`event: trial.knowledge\ndata:`) + * while browser EventSource consumers only fire `onmessage` for the default + * (unnamed) event. The bytes arrived on the wire — curl saw them — but no + * frontend-visible event was ever dispatched. + * + * This test exercises the full bus → SSE wire format: + * + * 1. Seed a trial record in KV (what `readTrial` reads in events.ts). + * 2. 
Append an event directly on the TrialEventBus DO (identical to what + * `emitTrialEvent()` does in the Worker for real trials). + * 3. Open the SSE stream via `SELF.fetch` with a valid fingerprint cookie. + * 4. Read the raw stream bytes and assert: + * - HTTP 200 + correct `content-type` + * - At least one `data: {...}` frame + * - No `event:` line anywhere (the regression guard for the incident) + * - The parsed JSON payload round-trips through the bus intact + * + * See `docs/notes/2026-04-19-trial-sse-named-events-postmortem.md`. + */ +import type { TrialEvent } from '@simple-agent-manager/shared'; +import { TRIAL_COOKIE_FINGERPRINT_NAME } from '@simple-agent-manager/shared'; +import { env, SELF } from 'cloudflare:test'; +import { describe, expect, it } from 'vitest'; + +import { signFingerprint } from '../../src/services/trial/cookies'; +import type { TrialRecord } from '../../src/services/trial/trial-store'; +import { writeTrial } from '../../src/services/trial/trial-store'; + +type WorkerEnv = typeof env & { + TRIAL_EVENT_BUS: DurableObjectNamespace; + TRIAL_CLAIM_TOKEN_SECRET: string; + KV: KVNamespace; +}; + +const workerEnv = env as unknown as WorkerEnv; + +function makeTrialRecord(trialId: string, fingerprint: string): TrialRecord { + const now = Date.now(); + return { + trialId, + projectId: '', + fingerprint, + workspaceId: null, + repoUrl: 'https://github.com/example/repo', + createdAt: now, + expiresAt: now + 1000 * 60 * 20, + claimed: false, + }; +} + +async function appendViaBus(trialId: string, event: TrialEvent): Promise { + const id = workerEnv.TRIAL_EVENT_BUS.idFromName(trialId); + const stub = workerEnv.TRIAL_EVENT_BUS.get(id); + const resp = await stub.fetch('https://trial-event-bus/append', { + method: 'POST', + headers: { 'content-type': 'application/json' }, + body: JSON.stringify(event), + }); + expect(resp.status, 'bus append should succeed').toBe(200); +} + +/** + * Read the SSE stream until it has seen at least `minDataFrames` data 
frames + * or the stream closes / the timeout elapses. Returns the accumulated body. + */ +async function readStreamUntil( + resp: Response, + minDataFrames: number, + timeoutMs = 3000, +): Promise { + if (!resp.body) throw new Error('response has no body'); + const reader = resp.body.getReader(); + const decoder = new TextDecoder(); + let body = ''; + const deadline = Date.now() + timeoutMs; + while (Date.now() < deadline) { + // Count `data: ` occurrences (each SSE data frame starts with it). + const frameCount = (body.match(/^data: /gm) ?? []).length; + if (frameCount >= minDataFrames) break; + const { value, done } = await Promise.race([ + reader.read(), + new Promise<{ value: undefined; done: true }>((resolve) => + setTimeout(() => resolve({ value: undefined, done: true }), 250), + ), + ]); + if (done) break; + if (value) body += decoder.decode(value, { stream: true }); + } + try { + await reader.cancel(); + } catch { + // already closed + } + return body; +} + +describe('TrialEventBus → SSE endpoint (capability)', () => { + it('events published via the bus stream as unnamed SSE frames that EventSource.onmessage can consume', async () => { + const trialId = `trial_${crypto.randomUUID().replace(/-/g, '')}`; + const fingerprintUuid = crypto.randomUUID(); + + // 1. Seed KV so readTrial() resolves the record in events.ts. + await writeTrial(workerEnv as never, makeTrialRecord(trialId, fingerprintUuid)); + + // 2. Publish a real-shaped TrialEvent via the DO before opening the stream + // — tests the immediate-return path of the DO poll loop. + const knowledgeEvent: TrialEvent = { + type: 'trial.knowledge', + key: 'description', + value: 'capability-test description', + at: Date.now(), + }; + await appendViaBus(trialId, knowledgeEvent); + + // 3. Sign the fingerprint cookie and hit the SSE endpoint. 
+ const secret = workerEnv.TRIAL_CLAIM_TOKEN_SECRET; + const signed = await signFingerprint(fingerprintUuid, secret); + const cookie = `${TRIAL_COOKIE_FINGERPRINT_NAME}=${encodeURIComponent(signed)}`; + + const resp = await SELF.fetch( + `https://api.test.example.com/api/trial/${trialId}/events`, + { + method: 'GET', + headers: { cookie, accept: 'text/event-stream' }, + }, + ); + + expect(resp.status).toBe(200); + expect(resp.headers.get('content-type')).toMatch(/text\/event-stream/); + + // 4. Publish a second event AFTER the stream is open to also exercise + // the long-poll wake path. + const progressEvent: TrialEvent = { + type: 'trial.progress', + step: 'capability-check', + at: Date.now(), + }; + await appendViaBus(trialId, progressEvent); + + const body = await readStreamUntil(resp, 2, 3000); + + // --- Regression guard: no named events ------------------------------- + // If ANY `event:` line appears, the named-event bug has returned and + // browser EventSource.onmessage will silently swallow these frames. + const namedEventLines = body.match(/^event: /gm) ?? 
[]; + expect(namedEventLines).toHaveLength(0); + + // --- Capability: frames are parseable and carry the bus payload ------ + const dataLines = body + .split('\n') + .filter((line) => line.startsWith('data: ')) + .map((line) => line.slice('data: '.length)); + + expect(dataLines.length).toBeGreaterThanOrEqual(2); + + const parsed = dataLines.map((line) => JSON.parse(line) as TrialEvent); + const types = parsed.map((e) => e.type); + expect(types).toContain('trial.knowledge'); + expect(types).toContain('trial.progress'); + + const knowledgeFrame = parsed.find( + (e): e is Extract => + e.type === 'trial.knowledge', + ); + expect(knowledgeFrame).toBeDefined(); + expect(knowledgeFrame?.key).toBe('description'); + expect(knowledgeFrame?.value).toBe('capability-test description'); + }); +}); diff --git a/apps/api/vitest.workers.config.ts b/apps/api/vitest.workers.config.ts index 3c1c8f09f..dd1f8f474 100644 --- a/apps/api/vitest.workers.config.ts +++ b/apps/api/vitest.workers.config.ts @@ -31,6 +31,17 @@ export default defineConfig({ ADMIN_LOGS: { className: 'AdminLogs', }, + TRIAL_COUNTER: { + className: 'TrialCounter', + useSQLite: true, + }, + TRIAL_EVENT_BUS: { + className: 'TrialEventBus', + }, + TRIAL_ORCHESTRATOR: { + className: 'TrialOrchestrator', + useSQLite: true, + }, }, bindings: { BASE_DOMAIN: 'test.example.com', @@ -47,6 +58,10 @@ export default defineConfig({ SESSION_IDLE_TIMEOUT_MINUTES: '60', DO_SUMMARY_SYNC_DEBOUNCE_MS: '5000', NODE_HEARTBEAT_STALE_SECONDS: '180', + TRIAL_CLAIM_TOKEN_SECRET: 'test-trial-secret-do-not-use-in-production', + TRIAL_SSE_HEARTBEAT_MS: '60000', + TRIAL_SSE_POLL_TIMEOUT_MS: '500', + TRIAL_SSE_MAX_DURATION_MS: '5000', }, }, }), diff --git a/apps/api/wrangler.toml b/apps/api/wrangler.toml index 6fe121a9b..9b5ec67e8 100644 --- a/apps/api/wrangler.toml +++ b/apps/api/wrangler.toml @@ -36,13 +36,23 @@ HETZNER_BASE_IMAGE = "docker-ce" TASK_RUNNER_WORKSPACE_READY_POLL_INTERVAL_MS = "30000" # AI Inference Proxy (Cloudflare AI Gateway for 
trial/zero-config users) AI_PROXY_ENABLED = "true" -AI_PROXY_DEFAULT_MODEL = "@cf/meta/llama-4-scout-17b-16e-instruct" +AI_PROXY_DEFAULT_MODEL = "@cf/qwen/qwen3-30b-a3b-fp8" AI_GATEWAY_ID = "sam" # Triggers (Event-Driven Agent Triggers) MAX_TRIGGERS_PER_PROJECT = "10" CRON_MIN_INTERVAL_MINUTES = "15" CRON_TEMPLATE_MAX_LENGTH = "8000" TRIGGER_NAME_MAX_LENGTH = "100" +# Trial Onboarding (zero-friction URL-to-workspace) +# See docs/guides/trial-configuration.md for the full tunable list. +TRIAL_MONTHLY_CAP = "1500" +TRIAL_WORKSPACE_TTL_MS = "1200000" # 20 min +TRIAL_DATA_RETENTION_HOURS = "168" # 7 days +TRIAL_ANONYMOUS_USER_ID = "system_anonymous_trials" +TRIAL_AGENT_TYPE_STAGING = "opencode" +TRIAL_AGENT_TYPE_PRODUCTION = "claude-code" +TRIAL_DEFAULT_WORKSPACE_PROFILE = "lightweight" +TRIALS_ENABLED_KV_KEY = "trials:enabled" # D1 Database [[d1_databases]] @@ -99,6 +109,21 @@ class_name = "NotificationService" name = "CODEX_REFRESH_LOCK" class_name = "CodexRefreshLock" +# Durable Object for trial-onboarding monthly cap counter (singleton; keyed by YYYY-MM) +[[durable_objects.bindings]] +name = "TRIAL_COUNTER" +class_name = "TrialCounter" + +# Durable Object for per-trial SSE event buffering (keyed by trialId) +[[durable_objects.bindings]] +name = "TRIAL_EVENT_BUS" +class_name = "TrialEventBus" + +# Durable Object for alarm-driven trial provisioning (one DO per trialId) +[[durable_objects.bindings]] +name = "TRIAL_ORCHESTRATOR" +class_name = "TrialOrchestrator" + [[migrations]] tag = "v1" new_sqlite_classes = ["ProjectData"] @@ -123,6 +148,18 @@ new_sqlite_classes = ["NotificationService"] tag = "v6" new_classes = ["CodexRefreshLock"] # Intentionally stateless: uses D1 for credential storage, no per-instance SQLite needed +[[migrations]] +tag = "v7" +new_sqlite_classes = ["TrialCounter"] + +[[migrations]] +tag = "v8" +new_classes = ["TrialEventBus"] # In-memory only — no SQLite storage needed for short-lived trial streams + +[[migrations]] +tag = "v9" 
+new_sqlite_classes = ["TrialOrchestrator"] # Alarm-driven trial provisioning; SQLite-backed per Cloudflare recommendation for new DO classes + # Analytics Engine for usage tracking (Phase 1 — server-side analytics) [[analytics_engine_datasets]] binding = "ANALYTICS" @@ -138,13 +175,19 @@ binding = "AI" # because they break Vitest — Cloudflare issue #9343). # Cron triggers: -# - Every 5 minutes: provisioning timeout checks, warm node cleanup, stuck task recovery +# - Every 5 minutes: provisioning timeout checks, warm node cleanup, stuck task recovery, +# trial expiry sweep # - Daily at 03:00 UTC: analytics event forwarding to external platforms (Phase 4) -# NOTE: The daily cron string "0 3 * * *" is matched exactly in index.ts:scheduled() -# to distinguish the daily forward job from the 5-minute sweep. If you change it here, -# update the comparison in the scheduled handler too. +# - Daily at 04:00 UTC: trial waitlist purge (rows with notified_at older than +# TRIAL_WAITLIST_PURGE_DAYS). Configurable via TRIAL_CRON_WAITLIST_CLEANUP. +# - Monthly on the 1st at 05:00 UTC: trial counter rollover audit (prune DO rows beyond +# TRIAL_COUNTER_KEEP_MONTHS). Configurable via TRIAL_CRON_ROLLOVER_CRON. +# NOTE: Intentionally offset from the 03:00 daily job to avoid collision on the 1st. +# NOTE: Cron strings are matched exactly in index.ts:scheduled() to dispatch between +# the different jobs. If you change them here, update the comparisons in the scheduled +# handler (or the corresponding TRIAL_CRON_* env vars). 
[triggers] -crons = ["*/5 * * * *", "0 3 * * *"] +crons = ["*/5 * * * *", "0 3 * * *", "0 4 * * *", "0 5 1 * *"] # Secrets (set via wrangler secret put): # - GITHUB_CLIENT_ID @@ -168,3 +211,4 @@ crons = ["*/5 * * * *", "0 3 * * *"] # - GA4_API_SECRET (optional — Google Analytics 4 API secret, enables GA4 forwarding; NOTE: GA4 Measurement Protocol requires api_secret as a query parameter — ensure outbound request URLs are not logged) # - R2_ACCESS_KEY_ID (optional — R2 S3-compatible API token key ID, enables task attachment presigned uploads) # - R2_SECRET_ACCESS_KEY (optional — R2 S3-compatible API token secret, enables task attachment presigned uploads) +# - TRIAL_CLAIM_TOKEN_SECRET (required when trials are enabled — HMAC secret for sam_trial_claim / sam_trial_fingerprint cookies. 32+ bytes base64.) diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index de5753a39..fa0dd6f75 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -9,6 +9,7 @@ import { GlobalAudioProvider } from './contexts/GlobalAudioContext'; import { ToastProvider } from './hooks/useToast'; import { AccountMap } from './pages/AccountMap'; import { Admin } from './pages/Admin'; +import { AdminAIProxy } from './pages/AdminAIProxy'; import { AdminAnalytics } from './pages/AdminAnalytics'; import { AdminComputeQuotas } from './pages/AdminComputeQuotas'; import { AdminComputeUsage } from './pages/AdminComputeUsage'; @@ -46,6 +47,11 @@ import { SettingsGitHub } from './pages/SettingsGitHub'; import { SettingsNotifications } from './pages/SettingsNotifications'; import { SettingsSmokeTestTokens } from './pages/SettingsSmokeTestTokens'; import { TaskDetail } from './pages/TaskDetail'; +import { TrialChatGateHarness } from './pages/TrialChatGateHarness'; +import { Try } from './pages/Try'; +import { TryCapExceeded } from './pages/TryCapExceeded'; +import { TryDiscovery } from './pages/TryDiscovery'; +import { TryWaitlistThanks } from './pages/TryWaitlistThanks'; import { UiStandards } 
from './pages/UiStandards'; import { Workspace } from './pages/workspace'; import { Workspaces } from './pages/Workspaces'; @@ -71,6 +77,12 @@ export default function App() { {/* Public routes */} } /> + } /> + } /> + } /> + } /> + {/* Harness for Playwright audits — mounts trial components with mock data */} + } /> {/* Protected routes with AppShell (persistent navigation) */} }> @@ -119,6 +131,7 @@ export default function App() { } /> } /> } /> + } /> } /> } /> } /> diff --git a/apps/web/src/components/trial/ChatGate.tsx b/apps/web/src/components/trial/ChatGate.tsx new file mode 100644 index 000000000..66f7d5ff0 --- /dev/null +++ b/apps/web/src/components/trial/ChatGate.tsx @@ -0,0 +1,234 @@ +/** + * ChatGate — suggestion chips + chat input for the trial discovery flow. + * + * Rendered inside {@link TryDiscovery} once the `trial.ready` SSE event has + * fired. Authenticated visitors submit directly to the project's normal chat + * endpoint; anonymous visitors hit the {@link LoginSheet} first. After + * sign-in, the draft (persisted by {@link useTrialDraft}) survives the OAuth + * round-trip and is replayed by {@link useTrialClaim}. + * + * Layout: + * ┌──────────────────────────────────────────────┐ + * │ [chip] [chip] [chip] ... (scrolls) │ + * ├──────────────────────────────────────────────┤ + * │ ┌────────────────────────────────┐ ┌───┐ │ + * │ │ textarea │ │ → │ │ + * │ └────────────────────────────────┘ └───┘ │ + * └──────────────────────────────────────────────┘ + */ +import type { TrialIdea } from '@simple-agent-manager/shared'; +import { useCallback, useRef, useState } from 'react'; + +import { useTrialDraft } from '../../hooks/useTrialDraft'; +import { useAuth } from '../AuthProvider'; +import { LoginSheet } from './LoginSheet'; +import { SuggestionChip } from './SuggestionChip'; + +interface ChatGateProps { + trialId: string; + ideas: TrialIdea[]; + /** + * Called when an authenticated visitor sends a message. 
The caller is + * responsible for routing the message to the correct project chat endpoint + * (since the project ID is known only to {@link TryDiscovery}). + * + * Must return a promise that resolves on successful submit — the draft is + * cleared after the promise resolves so failed sends don't lose the user's + * text. + */ + onAuthenticatedSubmit: (message: string) => Promise; + /** Optional textarea placeholder. */ + placeholder?: string; + /** Injected for testing — defaults to false unless tests force it. */ + forceAnonymous?: boolean; +} + +export function ChatGate({ + trialId, + ideas, + onAuthenticatedSubmit, + placeholder = 'Ask anything about this repo…', + forceAnonymous = false, +}: ChatGateProps) { + const { isAuthenticated } = useAuth(); + const { draft, setDraft, clearDraft } = useTrialDraft(trialId); + const [showLogin, setShowLogin] = useState(false); + const [isSending, setIsSending] = useState(false); + const [errorMessage, setErrorMessage] = useState(null); + const textareaRef = useRef(null); + + const treatAsAuthenticated = isAuthenticated && !forceAnonymous; + + const handleChipSelect = useCallback( + (idea: TrialIdea) => { + setDraft(idea.prompt); + setErrorMessage(null); + // Focus textarea so the user can immediately edit or press Send. + textareaRef.current?.focus(); + }, + [setDraft], + ); + + const handleSend = useCallback(async () => { + const trimmed = draft.trim(); + if (!trimmed || isSending) return; + + if (!treatAsAuthenticated) { + setShowLogin(true); + return; + } + + setIsSending(true); + setErrorMessage(null); + try { + await onAuthenticatedSubmit(trimmed); + clearDraft(); + } catch (err) { + const message = err instanceof Error ? 
err.message : 'Failed to send message'; + setErrorMessage(message); + } finally { + setIsSending(false); + } + }, [draft, isSending, treatAsAuthenticated, onAuthenticatedSubmit, clearDraft]); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + // Cmd/Ctrl+Enter submits; plain Enter inserts a newline (matches the + // rest of the product's chat inputs). + if (e.key === 'Enter' && (e.metaKey || e.ctrlKey)) { + e.preventDefault(); + void handleSend(); + } + }, + [handleSend], + ); + + const canSend = draft.trim().length > 0 && !isSending; + + return ( +
+ {ideas.length > 0 && ( +
+

+ Try one of these to get started: +

+
+ {ideas.map((idea) => ( +
+ +
+ ))} +
+
+ )} + + {!treatAsAuthenticated && draft.trim().length === 0 ? ( +

+ Sign in with GitHub to save your trial and keep this conversation going. +

+ ) : null} + +
+ +