diff --git a/.env.example b/.env.example index af023cf..73a6a11 100644 --- a/.env.example +++ b/.env.example @@ -5,19 +5,23 @@ SERVER_HOST=0.0.0.0 SERVER_PORT=8084 -# JWT authentication (required) -JWT_SECRET=mysecret - # PostgreSQL connection (required) DATABASE_DSN=postgres://vultisig:vultisig@localhost:5432/vultisig-agent?sslmode=disable # Redis connection (required) REDIS_URI=redis://localhost:6379 -# Anthropic Claude API (required) -ANTHROPIC_API_KEY=sk-ant-your-key-here -ANTHROPIC_MODEL=claude-sonnet-4-20250514 -ANTHROPIC_SUMMARY_MODEL=claude-haiku-4-5-20251001 +# Auth cache keying (required) +AUTH_CACHE_KEY_SECRET=replace-with-strong-random-secret +AUTH_CACHE_TTL_SECONDS=180 + +# AI Provider - OpenRouter (required) +AI_API_KEY=sk-or-your-key-here +AI_MODEL=anthropic/claude-sonnet-4.5 +AI_SUMMARY_MODEL=anthropic/claude-haiku-4.5 +AI_BASE_URL=https://openrouter.ai/api/v1 +AI_APP_NAME=vultisig-agent +AI_APP_URL=https://vultisig.com # Conversation context window CONTEXT_WINDOW_SIZE=20 diff --git a/.github/workflows/deploy-dev.yaml b/.github/workflows/deploy-dev.yaml index 7172834..f1b1ef2 100644 --- a/.github/workflows/deploy-dev.yaml +++ b/.github/workflows/deploy-dev.yaml @@ -7,14 +7,21 @@ on: paths-ignore: - '**.md' - '.gitignore' + workflow_run: + workflows: ["Sync main to dev"] + types: + - completed env: REGISTRY: ghcr.io IMAGE_NAME: ${{ github.repository }} - NS: agent + NS: agent-backend jobs: build-server: + if: > + github.event_name == 'push' || + (github.event_name == 'workflow_run' && github.event.workflow_run.conclusion == 'success') uses: ./.github/workflows/build-push-image.yaml permissions: contents: read @@ -23,7 +30,7 @@ jobs: service: server registry: ghcr.io image_name: ${{ github.repository }} - tag: dev-${{ github.sha }} + tag: dev-${{ github.event_name == 'workflow_run' && github.event.workflow_run.head_sha || github.sha }} additional_tag: dev-latest deploy: @@ -34,6 +41,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'workflow_run' && 'dev' || github.ref }} - name: Set up kubeconfig run: | @@ -43,7 +52,7 @@ jobs: - name: Update deployment files run: | - TAG="dev-${{ github.sha }}" + TAG="dev-${{ github.event_name == 'workflow_run' && github.event.workflow_run.head_sha || github.sha }}" sed -i "s|image: busybox|image: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}/server:${TAG}|" deploy/01_server.yaml - name: Deploy to Kubernetes diff --git a/.github/workflows/sync-main-to-dev.yaml b/.github/workflows/sync-main-to-dev.yaml new file mode 100644 index 0000000..7ca1fe7 --- /dev/null +++ b/.github/workflows/sync-main-to-dev.yaml @@ -0,0 +1,27 @@ +name: Sync main to dev + +on: + push: + branches: + - main + +jobs: + sync-main-to-dev: + runs-on: ubuntu-latest + permissions: + contents: write + actions: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Sync main to dev + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git fetch origin dev:dev + git checkout dev + git merge origin/main --no-edit + git push origin dev diff --git a/.gitignore b/.gitignore index 89d5387..ffe9094 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ vendor/ .env .env.local .env.*.local +.devenv # Config files (use examples) config.json @@ -43,3 +44,4 @@ Thumbs.db # Debug debug __debug_bin* +agent-backend diff --git a/Dockerfile b/Dockerfile index 84eeb9c..6f8daa0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,7 @@ FROM golang:1.25-bookworm AS builder +ARG SERVICE=server + WORKDIR /app # Copy dependency files first for better caching @@ -10,7 +12,7 @@ RUN go mod download COPY . . # Build static binary -RUN CGO_ENABLED=0 GOOS=linux go build -o main ./cmd/server +RUN CGO_ENABLED=0 GOOS=linux go build -o main ./cmd/${SERVICE} # Runtime image FROM debian:bookworm-slim @@ -23,6 +25,6 @@ WORKDIR /app COPY --from=builder /app/main . -EXPOSE 8080 +EXPOSE 8084 CMD ["./main"] diff --git a/Makefile b/Makefile index 503fc9d..6744392 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,6 @@ -.PHONY: build run test docker-build migrate-up migrate-down lint clean +.PHONY: build run test docker-build migrate-up migrate-down lint clean deploy-prod deploy-configs deploy-server + +NS ?= agent-backend # Binary name BINARY=server @@ -40,3 +42,12 @@ lint: clean: rm -rf bin/ rm -f coverage.out coverage.html + +deploy-prod: deploy-configs deploy-server + +deploy-configs: + kubectl -n $(NS) apply -f deploy/prod + +deploy-server: + kubectl -n $(NS) apply -f deploy/01_server.yaml + kubectl -n $(NS) rollout status deployment/server --timeout=300s diff --git a/README.md b/README.md index 17cc23d..ffad1dd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # agent-backend -AI Chat Agent backend service for Vultisig mobile apps. This service handles natural language conversations using Anthropic Claude and coordinates with existing Vultisig plugins (app-recurring, feeplugin) via the verifier. +AI Chat Agent backend service for Vultisig mobile apps. This service handles natural language conversations using LLMs via OpenRouter and coordinates with existing Vultisig plugins (app-recurring, feeplugin) via the verifier. ## Prerequisites @@ -15,11 +15,16 @@ AI Chat Agent backend service for Vultisig mobile apps. This service handles nat |----------|----------|---------|-------------| | `SERVER_HOST` | No | `0.0.0.0` | Server bind address | | `SERVER_PORT` | No | `8080` | Server port | -| `JWT_SECRET` | Yes | - | Secret for JWT token signing | | `DATABASE_DSN` | Yes | - | PostgreSQL connection string | | `REDIS_URI` | Yes | - | Redis connection URI | -| `ANTHROPIC_API_KEY` | Yes | - | Anthropic Claude API key | -| `ANTHROPIC_MODEL` | No | `claude-sonnet-4-20250514` | Claude model to use | +| `AUTH_CACHE_KEY_SECRET` | Yes | - | HMAC secret for auth cache key derivation | +| `AUTH_CACHE_TTL_SECONDS` | No | `180` | Auth cache TTL (seconds) | +| `AI_API_KEY` | Yes | - | OpenRouter API key | +| `AI_MODEL` | No | `anthropic/claude-sonnet-4.5` | Model to use (OpenRouter format) | +| `AI_SUMMARY_MODEL` | No | `anthropic/claude-haiku-4.5` | Model for conversation summarization | +| `AI_BASE_URL` | No | `https://openrouter.ai/api/v1` | AI provider base URL | +| `AI_APP_NAME` | No | `vultisig-agent` | App name sent to OpenRouter | +| `AI_APP_URL` | No | - | App URL sent to OpenRouter | | `VERIFIER_URL` | Yes | - | Verifier service base URL | | `LOG_FORMAT` | No | `json` | Log format (`json` or `text`) | @@ -28,10 +33,10 @@ AI Chat Agent backend service for Vultisig mobile apps. This service handles nat 1. Set required environment variables: ```bash -export JWT_SECRET="your-jwt-secret" export DATABASE_DSN="postgres://user:pass@localhost:5432/agent?sslmode=disable" export REDIS_URI="redis://localhost:6379" -export ANTHROPIC_API_KEY="sk-ant-..." +export AUTH_CACHE_KEY_SECRET="replace-with-strong-random-secret" +export AI_API_KEY="sk-or-v1-..." export VERIFIER_URL="http://localhost:8080" ``` @@ -60,10 +65,9 @@ Run with Docker: ```bash docker run -p 8080:8080 \ - -e JWT_SECRET="your-jwt-secret" \ -e DATABASE_DSN="postgres://..." \ -e REDIS_URI="redis://..." \ - -e ANTHROPIC_API_KEY="sk-ant-..." \ + -e AI_API_KEY="sk-or-v1-..." \ -e VERIFIER_URL="http://verifier:8080" \ agent-backend:latest ``` @@ -117,7 +121,7 @@ internal/ service/ # Business logic layer storage/postgres/ # PostgreSQL repositories + migrations cache/redis/ # Redis caching - ai/anthropic/ # Anthropic Claude integration + ai/ # AI client (OpenRouter-compatible) config/ # Configuration loading types/ # Shared types ``` diff --git a/cmd/scheduler/main.go b/cmd/scheduler/main.go new file mode 100644 index 0000000..241449d --- /dev/null +++ b/cmd/scheduler/main.go @@ -0,0 +1,119 @@ +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + "time" + + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/ai" + "github.com/vultisig/agent-backend/internal/cache/redis" + "github.com/vultisig/agent-backend/internal/config" + "github.com/vultisig/agent-backend/internal/mcp" + "github.com/vultisig/agent-backend/internal/service/agent" + "github.com/vultisig/agent-backend/internal/service/plugin" + "github.com/vultisig/agent-backend/internal/service/scheduler" + "github.com/vultisig/agent-backend/internal/service/verifier" + "github.com/vultisig/agent-backend/internal/storage/postgres" +) + +func main() { + logger := logrus.New() + logger.SetFormatter(&logrus.JSONFormatter{}) + logger.SetOutput(os.Stdout) + + cfg, err := config.Load() + if err != nil { + logger.WithError(err).Fatal("failed to load configuration") + } + + if cfg.LogFormat == "text" { + logger.SetFormatter(&logrus.TextFormatter{}) + } + + logger.Info("starting agent-scheduler") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Connect to database + db, err := postgres.New(ctx, cfg.Database.DSN) + if err != nil { + logger.WithError(err).Fatal("failed to connect to database") + } + defer db.Close() + + // Initialize Redis client + redisClient, err := redis.New(cfg.Redis.URI) + if err != nil { + logger.WithError(err).Fatal("failed to connect to redis") + } + defer redisClient.Close() + + // Initialize AI client + aiClient := ai.NewClient(cfg.AI.APIKey, cfg.AI.Model, cfg.AI.BaseURL, cfg.AI.AppName, cfg.AI.AppURL) + + // Initialize services + pluginService := plugin.NewService(cfg.Verifier.URL, redisClient, logger) + verifierClient := verifier.NewClient(cfg.Verifier.URL) + + // Initialize MCP client (optional, provides observation tools like find_token, get_balance) + var mcpProvider agent.MCPToolProvider + if cfg.MCP.ServerURL != "" { + cacheTTL := time.Duration(cfg.MCP.ToolCacheTTLSec) * time.Second + mcpClient := mcp.NewClient(cfg.MCP.ServerURL, cacheTTL, logger) + + mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second) + defer mcpCancel() + + if err := mcpClient.Initialize(mcpCtx); err != nil { + logger.WithError(err).Warn("failed to initialize mcp client, continuing without mcp tools") + } else { + tools, err := mcpClient.ListTools(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp tools, continuing without mcp tools") + } else { + logger.WithField("tool_count", len(tools)).Info("mcp tools loaded") + mcpProvider = mcpClient + } + } + } + + // Initialize repositories + convRepo := postgres.NewConversationRepository(db.Pool()) + msgRepo := postgres.NewMessageRepository(db.Pool()) + memRepo := postgres.NewMemoryRepository(db.Pool()) + taskRepo := postgres.NewScheduledTaskRepository(db.Pool()) + + // Initialize agent service (used by scheduler for headless execution) + agentService := agent.NewAgentService( + aiClient, msgRepo, convRepo, memRepo, taskRepo, + redisClient, verifierClient, pluginService, mcpProvider, nil, + logger, cfg.AI.SummaryModel, cfg.Context, cfg.Scheduler, + ) + + // Initialize and run scheduler + sched := scheduler.New( + agentService, taskRepo, convRepo, msgRepo, mcpProvider, + cfg.Scheduler, cfg.AI.Model, logger, + ) + + // Handle graceful shutdown + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + go func() { + <-quit + logger.Info("shutting down scheduler") + cancel() + }() + + if err := sched.Run(ctx); err != nil && err != context.Canceled { + logger.WithError(err).Fatal("scheduler error") + } + + logger.Info("scheduler stopped") +} diff --git a/cmd/server/main.go b/cmd/server/main.go index cae47b6..e90099c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -13,12 +13,13 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/sirupsen/logrus" - "github.com/vultisig/agent-backend/internal/ai/anthropic" + "github.com/vultisig/agent-backend/internal/ai" "github.com/vultisig/agent-backend/internal/api" "github.com/vultisig/agent-backend/internal/cache/redis" "github.com/vultisig/agent-backend/internal/config" - "github.com/vultisig/agent-backend/internal/service" + "github.com/vultisig/agent-backend/internal/mcp" "github.com/vultisig/agent-backend/internal/service/agent" + mcpclient "github.com/vultisig/agent-backend/internal/service/mcp" "github.com/vultisig/agent-backend/internal/service/plugin" "github.com/vultisig/agent-backend/internal/service/verifier" "github.com/vultisig/agent-backend/internal/storage/postgres" @@ -58,11 +59,8 @@ func main() { } defer redisClient.Close() - // Initialize Anthropic client - anthropicClient := anthropic.NewClient(cfg.Anthropic.APIKey, cfg.Anthropic.Model) - - // Initialize services - authService := service.NewAuthService(cfg.Server.JWTSecret) + // Initialize AI client + aiClient := ai.NewClient(cfg.AI.APIKey, cfg.AI.Model, cfg.AI.BaseURL, cfg.AI.AppName, cfg.AI.AppURL) // Initialize plugin service (skills fetched dynamically on demand) pluginService := plugin.NewService(cfg.Verifier.URL, redisClient, logger) @@ -75,11 +73,62 @@ func main() { msgRepo := postgres.NewMessageRepository(db.Pool()) memRepo := postgres.NewMemoryRepository(db.Pool()) + // Initialize MCP JSON-RPC client for tool discovery and vault operations (optional) + var mcpProvider agent.MCPToolProvider + if cfg.MCP.ServerURL != "" { + cacheTTL := time.Duration(cfg.MCP.ToolCacheTTLSec) * time.Second + mcpClient := mcp.NewClient(cfg.MCP.ServerURL, cacheTTL, logger) + + mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second) + defer mcpCancel() + + if err := mcpClient.Initialize(mcpCtx); err != nil { + logger.WithError(err).Warn("failed to initialize mcp client, continuing without mcp tools") + } else { + tools, err := mcpClient.ListTools(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp tools, continuing without mcp tools") + } else { + logger.WithField("tool_count", len(tools)).Info("mcp tools loaded") + mcpProvider = mcpClient + } + + // Pre-warm skill cache (non-fatal) + skills, err := mcpClient.ListSkills(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp skills, continuing without skills") + } else { + logger.WithField("skill_count", len(skills)).Info("mcp skills loaded") + } + } + } + + // Initialize MCP swap tx builder (optional) + var swapTxBuilder agent.SwapTxBuilder + if cfg.MCP.URL != "" { + mcpCl := mcpclient.NewClient(cfg.MCP.URL) + swapTxBuilder = mcpclient.NewSwapTxAdapter(mcpCl) + logger.WithField("url", cfg.MCP.URL).Info("MCP swap tx builder enabled") + } + + // Initialize scheduled task repository + taskRepo := postgres.NewScheduledTaskRepository(db.Pool()) + // Initialize agent service - agentService := agent.NewAgentService(anthropicClient, msgRepo, convRepo, memRepo, redisClient, verifierClient, pluginService, logger, cfg.Anthropic.SummaryModel, cfg.Context) + agentService := agent.NewAgentService(aiClient, msgRepo, convRepo, memRepo, taskRepo, redisClient, verifierClient, pluginService, mcpProvider, swapTxBuilder, logger, cfg.AI.SummaryModel, cfg.Context, cfg.Scheduler) // Initialize API server - server := api.NewServer(authService, convRepo, agentService, logger) + server := api.NewServer( + verifierClient, + redisClient, + convRepo, + agentService, + logger, + api.AuthCacheConfig{ + KeySecret: cfg.AuthCache.KeySecret, + TTL: time.Duration(cfg.AuthCache.TTLSeconds) * time.Second, + }, + ) // Create Echo server e := echo.New() @@ -88,7 +137,16 @@ func main() { // Add middleware e.Use(middleware.Recover()) + e.Use(middleware.BodyLimit("2M")) e.Use(middleware.CORS()) + limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{ + Rate: 5, + Burst: 30, + ExpiresIn: 5 * time.Minute, + }, + ) + e.Use(middleware.RateLimiter(limiterStore)) e.Use(middleware.RequestID()) e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ LogURI: true, @@ -113,12 +171,14 @@ func main() { }) // Agent routes (authenticated) - agent := e.Group("/agent", server.AuthMiddleware) - agent.POST("/conversations", server.CreateConversation) - agent.POST("/conversations/list", server.ListConversations) - agent.POST("/conversations/:id", server.GetConversation) - agent.DELETE("/conversations/:id", server.DeleteConversation) - agent.POST("/conversations/:id/messages", server.SendMessage) + agentGroup := e.Group("/agent", server.AuthMiddleware) + agentGroup.POST("/conversations", server.CreateConversation) + agentGroup.POST("/conversations/list", server.ListConversations) + agentGroup.POST("/conversations/:id", server.GetConversation) + agentGroup.DELETE("/conversations/:id", server.DeleteConversation) + agentGroup.POST("/conversations/:id/messages", server.SendMessage, api.RateLimitMiddleware(redisClient)) + agentGroup.POST("/conversations/:id/tx/build", server.BuildTx) + agentGroup.POST("/starters", server.GetStarters) // Start server addr := fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port) diff --git a/deploy/01_server.yaml b/deploy/01_server.yaml index 0b0de3a..3889aa5 100644 --- a/deploy/01_server.yaml +++ b/deploy/01_server.yaml @@ -16,11 +16,9 @@ spec: annotations: prometheus.io/scrape: "false" spec: - imagePullSecrets: - - name: ghcr containers: - name: server - image: busybox + image: ghcr.io/vultisig/agent-backend/server:v1.0.2 command: ["/app/main"] ports: - containerPort: 80 @@ -33,17 +31,27 @@ spec: value: "80" - name: LOG_FORMAT value: "json" - # --- Anthropic config (from ConfigMap) --- - - name: ANTHROPIC_MODEL + # --- AI config (from ConfigMap) --- + - name: AI_MODEL valueFrom: configMapKeyRef: name: agent - key: anthropic-model - - name: ANTHROPIC_SUMMARY_MODEL + key: ai-model + - name: AI_SUMMARY_MODEL valueFrom: configMapKeyRef: name: agent - key: anthropic-summary-model + key: ai-summary-model + - name: AI_BASE_URL + valueFrom: + configMapKeyRef: + name: agent + key: ai-base-url + - name: AI_APP_NAME + valueFrom: + configMapKeyRef: + name: agent + key: ai-app-name # --- Context window config (from ConfigMap) --- - name: CONTEXT_WINDOW_SIZE valueFrom: @@ -66,12 +74,13 @@ spec: configMapKeyRef: name: verifier key: url - # --- Secrets --- - - name: JWT_SECRET + # --- MCP config (from ConfigMap) --- + - name: MCP_SERVER_URL valueFrom: - secretKeyRef: - name: jwt - key: secret + configMapKeyRef: + name: agent + key: mcp-server-url + # --- Secrets --- - name: DATABASE_DSN valueFrom: secretKeyRef: @@ -82,11 +91,16 @@ spec: secretKeyRef: name: redis key: uri - - name: ANTHROPIC_API_KEY + - name: AI_API_KEY valueFrom: secretKeyRef: - name: anthropic + name: ai-provider key: api-key + - name: AUTH_CACHE_KEY_SECRET + valueFrom: + secretKeyRef: + name: auth-cache + key: key-secret resources: requests: memory: "64Mi" diff --git a/deploy/02_scheduler.yaml b/deploy/02_scheduler.yaml new file mode 100644 index 0000000..4b19320 --- /dev/null +++ b/deploy/02_scheduler.yaml @@ -0,0 +1,110 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: scheduler + labels: + app: scheduler +spec: + replicas: 1 + selector: + matchLabels: + app: scheduler + template: + metadata: + labels: + app: scheduler + annotations: + prometheus.io/scrape: "false" + spec: + containers: + - name: scheduler + image: ghcr.io/vultisig/agent-backend/scheduler:v1.0.0 + command: ["/app/main"] + env: + # --- Logging --- + - name: LOG_FORMAT + value: "json" + # --- AI config (from ConfigMap) --- + - name: AI_MODEL + valueFrom: + configMapKeyRef: + name: agent + key: ai-model + - name: AI_SUMMARY_MODEL + valueFrom: + configMapKeyRef: + name: agent + key: ai-summary-model + - name: AI_BASE_URL + valueFrom: + configMapKeyRef: + name: agent + key: ai-base-url + - name: AI_APP_NAME + valueFrom: + configMapKeyRef: + name: agent + key: ai-app-name + # --- Context window config (from ConfigMap) --- + - name: CONTEXT_WINDOW_SIZE + valueFrom: + configMapKeyRef: + name: agent + key: context-window-size + - name: CONTEXT_SUMMARIZE_TRIGGER + valueFrom: + configMapKeyRef: + name: agent + key: context-summarize-trigger + - name: CONTEXT_SUMMARY_MAX_TOKENS + valueFrom: + configMapKeyRef: + name: agent + key: context-summary-max-tokens + # --- Verifier URL (from ConfigMap, differs dev/prod) --- + - name: VERIFIER_URL + valueFrom: + configMapKeyRef: + name: verifier + key: url + # --- MCP config (from ConfigMap) --- + - name: MCP_SERVER_URL + valueFrom: + configMapKeyRef: + name: agent + key: mcp-server-url + # --- Scheduler config --- + - name: SCHEDULER_POLL_INTERVAL_SECONDS + value: "30" + - name: SCHEDULER_MAX_ACTIVE_PER_USER + value: "10" + - name: SCHEDULER_MIN_INTERVAL_MINUTES + value: "60" + # --- Secrets --- + - name: DATABASE_DSN + valueFrom: + secretKeyRef: + name: postgres + key: dsn + - name: REDIS_URI + valueFrom: + secretKeyRef: + name: redis + key: uri + - name: AI_API_KEY + valueFrom: + secretKeyRef: + name: ai-provider + key: api-key + - name: AUTH_CACHE_KEY_SECRET + valueFrom: + secretKeyRef: + name: auth-cache + key: key-secret + resources: + requests: + memory: "64Mi" + cpu: "50m" + limits: + memory: "256Mi" + cpu: "250m" diff --git a/deploy/dev/01_agent.yaml b/deploy/dev/01_agent.yaml index 93dbfd9..abe2f36 100644 --- a/deploy/dev/01_agent.yaml +++ b/deploy/dev/01_agent.yaml @@ -3,8 +3,11 @@ kind: ConfigMap metadata: name: agent data: - anthropic-model: "claude-sonnet-4-20250514" - anthropic-summary-model: "claude-haiku-4-5-20251001" + ai-model: "anthropic/claude-sonnet-4.5" + ai-summary-model: "anthropic/claude-haiku-4.5" + ai-base-url: "https://openrouter.ai/api/v1" + ai-app-name: "vultisig-agent" context-window-size: "20" context-summarize-trigger: "30" context-summary-max-tokens: "512" + mcp-server-url: "" # TODO: set MCP server URL for dev environment diff --git a/deploy/dev/02_ingress.yaml b/deploy/dev/02_ingress.yaml index a1f9011..b18af4d 100644 --- a/deploy/dev/02_ingress.yaml +++ b/deploy/dev/02_ingress.yaml @@ -1,7 +1,7 @@ apiVersion: networking.k8s.io/v1 kind: Ingress metadata: - name: agent + name: agent-backend annotations: traefik.ingress.kubernetes.io/router.entrypoints: websecure cert-manager.io/cluster-issuer: "letsencrypt" @@ -9,10 +9,10 @@ spec: ingressClassName: traefik tls: - hosts: - - agent.dev.plugins.vultisig.com + - agent-backend.dev.plugins.vultisig.com secretName: agent-tls rules: - - host: agent.dev.plugins.vultisig.com + - host: agent-backend.dev.plugins.vultisig.com http: paths: - path: / diff --git a/deploy/dev/03_secrets.yaml b/deploy/dev/03_secrets.yaml deleted file mode 100644 index 16afc8b..0000000 --- a/deploy/dev/03_secrets.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# TEMPLATE ONLY — replace values before applying. NEVER commit real secrets. -apiVersion: v1 -kind: Secret -metadata: - name: anthropic -type: Opaque -stringData: - api-key: "sk-ant-REPLACE-ME" diff --git a/deploy/prod/01_agent.yaml b/deploy/prod/01_agent.yaml index 93dbfd9..947bc04 100644 --- a/deploy/prod/01_agent.yaml +++ b/deploy/prod/01_agent.yaml @@ -3,8 +3,11 @@ kind: ConfigMap metadata: name: agent data: - anthropic-model: "claude-sonnet-4-20250514" - anthropic-summary-model: "claude-haiku-4-5-20251001" + ai-model: "anthropic/claude-sonnet-4.5" + ai-summary-model: "anthropic/claude-haiku-4.5" + ai-base-url: "https://openrouter.ai/api/v1" + ai-app-name: "vultisig-agent" context-window-size: "20" context-summarize-trigger: "30" context-summary-max-tokens: "512" + mcp-server-url: "" # TODO: set MCP server URL for prod environment diff --git a/deploy/prod/02_ingress.yaml b/deploy/prod/02_ingress.yaml index e39e41d..cd7c060 100644 --- a/deploy/prod/02_ingress.yaml +++ b/deploy/prod/02_ingress.yaml @@ -1,18 +1,17 @@ apiVersion: networking.k8s.io/v1 kind: Ingress metadata: - name: agent + name: agent-backend annotations: traefik.ingress.kubernetes.io/router.entrypoints: websecure - cert-manager.io/cluster-issuer: "letsencrypt" spec: ingressClassName: traefik tls: - hosts: - - agent.prod.plugins.vultisig.com - secretName: agent-tls + - agent.vultisig.com + secretName: cf-agent-tls rules: - - host: agent.prod.plugins.vultisig.com + - host: agent.vultisig.com http: paths: - path: / diff --git a/deploy/prod/03_secrets.yaml b/deploy/prod/03_secrets.yaml deleted file mode 100644 index 16afc8b..0000000 --- a/deploy/prod/03_secrets.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# TEMPLATE ONLY — replace values before applying. NEVER commit real secrets. -apiVersion: v1 -kind: Secret -metadata: - name: anthropic -type: Opaque -stringData: - api-key: "sk-ant-REPLACE-ME" diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..532b695 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,76 @@ +services: + postgres: + image: postgres:17-alpine + environment: + POSTGRES_USER: vultisig + POSTGRES_PASSWORD: vultisig + POSTGRES_DB: vultisig-agent + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U vultisig -d vultisig-agent"] + interval: 3s + timeout: 3s + retries: 10 + + redis: + image: redis:7-alpine + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 3s + timeout: 3s + retries: 10 + + server: + build: + context: . + args: + SERVICE: server + ports: + - "8084:8084" + env_file: + - .env + environment: + SERVER_PORT: "8084" + DATABASE_DSN: postgres://vultisig:vultisig@postgres:5432/vultisig-agent?sslmode=disable + REDIS_URI: redis://redis:6379 + VERIFIER_URL: http://host.docker.internal:8080 + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + develop: + watch: + - action: rebuild + path: . + ignore: + - docker-compose.yaml + - .env + + scheduler: + build: + context: . + args: + SERVICE: scheduler + env_file: + - .env + environment: + DATABASE_DSN: postgres://vultisig:vultisig@postgres:5432/vultisig-agent?sslmode=disable + REDIS_URI: redis://redis:6379 + VERIFIER_URL: http://host.docker.internal:8080 + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + develop: + watch: + - action: rebuild + path: . + ignore: + - docker-compose.yaml + - .env + +volumes: + pgdata: diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..ebd3112 --- /dev/null +++ b/flake.lock @@ -0,0 +1,306 @@ +{ + "nodes": { + "cachix": { + "inputs": { + "devenv": [ + "devenv" + ], + "flake-compat": [ + "devenv", + "flake-compat" + ], + "git-hooks": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760971495, + "narHash": "sha256-IwnNtbNVrlZIHh7h4Wz6VP0Furxg9Hh0ycighvL5cZc=", + "owner": "cachix", + "repo": "cachix", + "rev": "c5bfd933d1033672f51a863c47303fc0e093c2d2", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "latest", + "repo": "cachix", + "type": "github" + } + }, + "devenv": { + "inputs": { + "cachix": "cachix", + "flake-compat": "flake-compat", + "flake-parts": "flake-parts", + "git-hooks": "git-hooks", + "nix": "nix", + "nixd": "nixd", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1771419682, + "narHash": "sha256-NAemVgEJeZjGl3+438M4rUL8ms9QdDFMYthU12F70FQ=", + "owner": "cachix", + "repo": "devenv", + "rev": "f77fc4de35c184d9ef9a32d5d7e9033351bcdfdc", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760948891, + "narHash": "sha256-TmWcdiUUaWk8J4lpjzu4gCGxWY6/Ok7mOK4fIFfBuU4=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "864599284fc7c0ba6357ed89ed5e2cd5040f0c04", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-root": { + "locked": { + "lastModified": 1723604017, + "narHash": "sha256-rBtQ8gg+Dn4Sx/s+pvjdq3CB2wQNzx9XGFq/JVGCB6k=", + "owner": "srid", + "repo": "flake-root", + "rev": "b759a56851e10cb13f6b8e5698af7b59c44be26e", + "type": "github" + }, + "original": { + "owner": "srid", + "repo": "flake-root", + "type": "github" + } + }, + "git-hooks": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1760663237, + "narHash": "sha256-BflA6U4AM1bzuRMR8QqzPXqh8sWVCNDzOdsxXEguJIc=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "ca5b894d3e3e151ffc1db040b6ce4dcc75d31c37", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "git-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nix": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-parts": [ + "devenv", + "flake-parts" + ], + "git-hooks-nix": [ + "devenv", + "git-hooks" + ], + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-23-11": [ + "devenv" + ], + "nixpkgs-regression": [ + "devenv" + ] + }, + "locked": { + "lastModified": 1770395975, + "narHash": "sha256-zg0AEZn8d4rqIIsw5XrkVL5p1y6fBj2L57awfUg+gNA=", + "owner": "cachix", + "repo": "nix", + "rev": "ccb6019ce2bd11f5de5fe4617c0079d8cb1ed057", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "devenv-2.32", + "repo": "nix", + "type": "github" + } + }, + "nixd": { + "inputs": { + "flake-parts": [ + "devenv", + "flake-parts" + ], + "flake-root": "flake-root", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "treefmt-nix": "treefmt-nix" + }, + "locked": { + "lastModified": 1763964548, + "narHash": "sha256-JTRoaEWvPsVIMFJWeS4G2isPo15wqXY/otsiHPN0zww=", + "owner": "nix-community", + "repo": "nixd", + "rev": "d4bf15e56540422e2acc7bc26b20b0a0934e3f5e", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nixd", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1761313199, + "narHash": "sha256-wCIACXbNtXAlwvQUo1Ed++loFALPjYUA3dpcUJiXO44=", + "owner": "cachix", + "repo": "devenv-nixpkgs", + "rev": "d1c30452ebecfc55185ae6d1c983c09da0c274ff", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "rolling", + "repo": "devenv-nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1771177547, + "narHash": "sha256-trTtk3WTOHz7hSw89xIIvahkgoFJYQ0G43IlqprFoMA=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ac055f38c798b0d87695240c7b761b82fc7e5bc2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "nixpkgs": "nixpkgs_2", + "systems": "systems" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "treefmt-nix": { + "inputs": { + "nixpkgs": [ + "devenv", + "nixd", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1734704479, + "narHash": "sha256-MMi74+WckoyEWBRcg/oaGRvXC9BVVxDZNRMpL+72wBI=", + "owner": "numtide", + "repo": "treefmt-nix", + "rev": "65712f5af67234dad91a5a4baee986a8b62dbf8f", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "treefmt-nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..502e258 --- /dev/null +++ b/flake.nix @@ -0,0 +1,42 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + systems.url = "github:nix-systems/default"; + devenv.url = "github:cachix/devenv"; + }; + + outputs = { self, nixpkgs, devenv, systems, ... } @ inputs: + let + forEachSystem = nixpkgs.lib.genAttrs (import systems); + in + { + devShells = forEachSystem + (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + default = devenv.lib.mkShell { + inherit inputs pkgs; + modules = [ + { + languages.go = { + enable = true; + }; + + packages = with pkgs; [ + postgresql + go-ethereum + sqlc + redis + ]; + + enterShell = '' + echo "agent-backend shell started!" + ''; + } + ]; + }; + }); + }; +} diff --git a/go.mod b/go.mod index aba97db..75faa8e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/vultisig/agent-backend go 1.25 require ( - github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.5 github.com/kelseyhightower/envconfig v1.4.0 diff --git a/go.sum b/go.sum index 82162a0..60b6d20 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= -github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= diff --git a/internal/ai/anthropic/client.go b/internal/ai/anthropic/client.go deleted file mode 100644 index a9b3e61..0000000 --- a/internal/ai/anthropic/client.go +++ /dev/null @@ -1,175 +0,0 @@ -package anthropic - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "time" -) - -const ( - defaultBaseURL = "https://api.anthropic.com/v1" - defaultMaxTokens = 4096 - apiVersion = "2023-06-01" -) - -// Client is an Anthropic Claude API client. -type Client struct { - apiKey string - model string - httpClient *http.Client - baseURL string -} - -// Message represents a simple conversation message with string content. -type Message struct { - Role string `json:"role"` // "user" or "assistant" - Content string `json:"content"` -} - -// AssistantMessage represents an assistant response with tool_use blocks. -// Used to replay assistant tool calls in the conversation history. -type AssistantMessage struct { - Role string `json:"role"` // "assistant" - Content []ContentBlock `json:"content"` -} - -// ToolResultMessage represents a user message containing tool results. -type ToolResultMessage struct { - Role string `json:"role"` // "user" - Content []ToolResultBlock `json:"content"` -} - -// ToolResultBlock is a single tool result in a ToolResultMessage. -type ToolResultBlock struct { - Type string `json:"type"` // "tool_result" - ToolUseID string `json:"tool_use_id"` - Content string `json:"content"` - IsError bool `json:"is_error,omitempty"` -} - -// Tool represents a tool that Claude can use. -type Tool struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema any `json:"input_schema"` -} - -// ToolChoice specifies how Claude should use tools. -type ToolChoice struct { - Type string `json:"type"` // "auto", "any", or "tool" - Name string `json:"name,omitempty"` // Required when type is "tool" -} - -// Request is the request body for the messages API. -// Messages accepts []any to support Message, AssistantMessage, and ToolResultMessage types. -type Request struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - System string `json:"system,omitempty"` - Messages []any `json:"messages"` - Tools []Tool `json:"tools,omitempty"` - ToolChoice *ToolChoice `json:"tool_choice,omitempty"` -} - -// Response is the response from the messages API. -type Response struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []ContentBlock `json:"content"` - StopReason string `json:"stop_reason"` - Usage Usage `json:"usage"` -} - -// ContentBlock represents a content block in the response. -type ContentBlock struct { - Type string `json:"type"` // "text" or "tool_use" - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input json.RawMessage `json:"input,omitempty"` -} - -// Usage contains token usage information. -type Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -// APIError represents an error from the Anthropic API. -type APIError struct { - Type string `json:"type"` - Message string `json:"message"` -} - -func (e *APIError) Error() string { - return fmt.Sprintf("anthropic: %s: %s", e.Type, e.Message) -} - -// NewClient creates a new Anthropic client. -func NewClient(apiKey, model string) *Client { - return &Client{ - apiKey: apiKey, - model: model, - baseURL: defaultBaseURL, - httpClient: &http.Client{ - Timeout: 60 * time.Second, - }, - } -} - -// SendMessage sends a message to Claude and returns the response. -func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, error) { - if req.Model == "" { - req.Model = c.model - } - if req.MaxTokens == 0 { - req.MaxTokens = defaultMaxTokens - } - - body, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("marshal request: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/messages", bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("x-api-key", c.apiKey) - httpReq.Header.Set("anthropic-version", apiVersion) - - resp, err := c.httpClient.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("send request: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - var apiErr struct { - Error APIError `json:"error"` - } - if err := json.Unmarshal(respBody, &apiErr); err != nil { - return nil, fmt.Errorf("anthropic: status %d: %s", resp.StatusCode, string(respBody)) - } - return nil, &apiErr.Error - } - - var result Response - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("unmarshal response: %w", err) - } - - return &result, nil -} diff --git a/internal/ai/client.go b/internal/ai/client.go new file mode 100644 index 0000000..486e84b --- /dev/null +++ b/internal/ai/client.go @@ -0,0 +1,532 @@ +package ai + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + defaultMaxTokens = 4096 + maxRetries = 3 + baseRetryDelay = 1 * time.Second +) + +// Client is an OpenRouter-compatible AI API client using OpenAI Chat Completions format. +type Client struct { + apiKey string + model string + httpClient *http.Client + baseURL string + appName string + appURL string +} + +// Message represents a simple conversation message with string content. +type Message struct { + Role string `json:"role"` // "user", "assistant", or "system" + Content string `json:"content"` +} + +// AssistantMessage represents an assistant response with optional tool calls. +// Used to replay assistant tool calls in the conversation history. +type AssistantMessage struct { + Role string `json:"role"` // "assistant" + Content string `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +// ToolMessage represents a single tool result in OpenAI format. +type ToolMessage struct { + Role string `json:"role"` // "tool" + ToolCallID string `json:"tool_call_id"` + Content string `json:"content"` +} + +// ToolCall represents a tool invocation from the assistant. +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` // "function" + Function FunctionCall `json:"function"` +} + +// FunctionCall contains the function name and arguments. +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +// Tool represents a tool that the model can use. +// InputSchema is kept as the internal field name; the client wraps it in the +// OpenAI function format ({"type":"function","function":{...,"parameters":...}}) +// when building the wire request. +type Tool struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema any `json:"input_schema"` +} + +// ToolChoice specifies how the model should use tools. +type ToolChoice struct { + Type string // "auto", "none", or "function" (specific tool) + Name string // required when Type is "function" +} + +// MarshalJSON implements custom marshaling for ToolChoice. +// "auto" and "none" marshal as the bare string; a specific tool marshals as +// {"type":"function","function":{"name":"..."}}. +func (tc ToolChoice) MarshalJSON() ([]byte, error) { + switch tc.Type { + case "auto", "none": + return json.Marshal(tc.Type) + default: + return json.Marshal(struct { + Type string `json:"type"` + Function struct { + Name string `json:"name"` + } `json:"function"` + }{ + Type: "function", + Function: struct { + Name string `json:"name"` + }{Name: tc.Name}, + }) + } +} + +// Request is the request payload for the AI API. +// The client translates this into the OpenAI Chat Completions wire format. +type Request struct { + Model string + MaxTokens int + System string // prepended as a system message + Messages []any // Message, AssistantMessage, ToolMessage + Tools []Tool + ToolChoice *ToolChoice +} + +// Response is the parsed response from the AI API. +type Response struct { + Content string + ToolCalls []ToolCall + FinishReason string + Usage Usage +} + +// Usage contains token usage information. +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// APIError represents an error from the AI API. +type APIError struct { + Type string `json:"type"` + Message string `json:"message"` + StatusCode int `json:"-"` +} + +func (e *APIError) Error() string { + return fmt.Sprintf("ai: %s: %s", e.Type, e.Message) +} + +// --- Wire format types (internal) --- + +type openAIFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` +} + +type openAITool struct { + Type string `json:"type"` // "function" + Function openAIFunction `json:"function"` +} + +type openAIRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []any `json:"messages"` + Tools []openAITool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type openAIResponse struct { + ID string `json:"id"` + Choices []openAIChoice `json:"choices"` + Usage Usage `json:"usage"` +} + +type openAIChoice struct { + Message openAIMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type openAIMessage struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +// --- Client --- + +// NewClient creates a new AI client. +func NewClient(apiKey, model, baseURL, appName, appURL string) *Client { + return &Client{ + apiKey: apiKey, + model: model, + baseURL: baseURL, + appName: appName, + appURL: appURL, + httpClient: &http.Client{ + Timeout: 90 * time.Second, + }, + } +} + +// buildWireRequest converts the public Request into the OpenAI wire format. +func (c *Client) buildWireRequest(req *Request, stream bool) ([]byte, error) { + if req.Model == "" { + req.Model = c.model + } + if req.MaxTokens == 0 { + req.MaxTokens = defaultMaxTokens + } + + messages := make([]any, 0, len(req.Messages)+1) + if req.System != "" { + messages = append(messages, Message{Role: "system", Content: req.System}) + } + messages = append(messages, req.Messages...) + + var tools []openAITool + for _, t := range req.Tools { + tools = append(tools, openAITool{ + Type: "function", + Function: openAIFunction{ + Name: t.Name, + Description: t.Description, + Parameters: t.InputSchema, + }, + }) + } + + var toolChoice any + if req.ToolChoice != nil { + toolChoice = req.ToolChoice + } + + wireReq := openAIRequest{ + Model: req.Model, + MaxTokens: req.MaxTokens, + Messages: messages, + Tools: tools, + ToolChoice: toolChoice, + Stream: stream, + } + + return json.Marshal(wireReq) +} + +// SendMessage sends a message to the model and returns the response. +func (c *Client) SendMessage(ctx context.Context, req *Request) (*Response, error) { + body, err := c.buildWireRequest(req, false) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + var lastErr error + for attempt := range maxRetries { + result, err := c.doRequest(ctx, body) + if err == nil { + return result, nil + } + + apiErr, ok := err.(*APIError) + if !ok || !isRetryable(apiErr.StatusCode) { + return nil, err + } + + lastErr = err + delay := retryDelay(apiErr.StatusCode, attempt) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + } + + return nil, fmt.Errorf("max retries exceeded: %w", lastErr) +} + +func (c *Client) doRequest(ctx context.Context, body []byte) (*Response, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) + if c.appName != "" { + httpReq.Header.Set("X-Title", c.appName) + } + if c.appURL != "" { + httpReq.Header.Set("HTTP-Referer", c.appURL) + } + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var parsed struct { + Error APIError `json:"error"` + } + if err := json.Unmarshal(respBody, &parsed); err != nil { + return nil, &APIError{ + Type: "unknown", + Message: fmt.Sprintf("status %d: %s", resp.StatusCode, string(respBody)), + StatusCode: resp.StatusCode, + } + } + parsed.Error.StatusCode = resp.StatusCode + + if resp.StatusCode == http.StatusTooManyRequests { + if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" { + parsed.Error.Message += fmt.Sprintf(" (retry-after: %s)", retryAfter) + } + } + + return nil, &parsed.Error + } + + var oaiResp openAIResponse + if err := json.Unmarshal(respBody, &oaiResp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + result := &Response{Usage: oaiResp.Usage} + if len(oaiResp.Choices) > 0 { + choice := oaiResp.Choices[0] + result.Content = choice.Message.Content + result.ToolCalls = choice.Message.ToolCalls + result.FinishReason = choice.FinishReason + } + + return result, nil +} + +func isRetryable(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || statusCode >= 500 +} + +func retryDelay(_ int, attempt int) time.Duration { + return baseRetryDelay * time.Duration(1< 0 { + callback(cbDelta) + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read stream: %w", err) + } + + // Collect accumulated tool calls in index order + for i := 0; i < len(toolCallsMap); i++ { + if tc, ok := toolCallsMap[i]; ok { + result.ToolCalls = append(result.ToolCalls, *tc) + } + } + + return &result, nil +} diff --git a/internal/ai/stream_parser.go b/internal/ai/stream_parser.go new file mode 100644 index 0000000..0f104ef --- /dev/null +++ b/internal/ai/stream_parser.go @@ -0,0 +1,113 @@ +package ai + +import ( + "strings" + "unicode/utf8" +) + +type extractorState int + +const ( + stateSearchKey extractorState = iota + stateInValue +) + +type ResponseFieldExtractor struct { + buf strings.Builder + state extractorState + esc bool +} + +func NewResponseFieldExtractor() *ResponseFieldExtractor { + return &ResponseFieldExtractor{state: stateSearchKey} +} + +func (e *ResponseFieldExtractor) Feed(delta string) string { + var out strings.Builder + + for i := 0; i < len(delta); { + r, size := utf8.DecodeRuneInString(delta[i:]) + i += size + + switch e.state { + case stateSearchKey: + e.buf.WriteRune(r) + s := e.buf.String() + const needle = `"response":"` + if strings.HasSuffix(s, needle) { + e.state = stateInValue + e.buf.Reset() + } + if len(s) > 256 { + trimmed := s[len(s)-len(needle):] + e.buf.Reset() + e.buf.WriteString(trimmed) + } + + case stateInValue: + if e.esc { + e.esc = false + switch r { + case '"': + out.WriteRune('"') + case '\\': + out.WriteRune('\\') + case 'n': + out.WriteRune('\n') + case 'r': + out.WriteRune('\r') + case 't': + out.WriteRune('\t') + case '/': + out.WriteRune('/') + case 'u': + if i+4 <= len(delta) { + hex := delta[i : i+4] + codepoint := parseHex4(hex) + if codepoint >= 0 { + out.WriteRune(rune(codepoint)) + i += 4 + } + } + default: + out.WriteRune('\\') + out.WriteRune(r) + } + continue + } + + switch r { + case '\\': + e.esc = true + case '"': + e.state = stateSearchKey + e.buf.Reset() + default: + out.WriteRune(r) + } + } + } + + return out.String() +} + +func parseHex4(s string) int { + if len(s) != 4 { + return -1 + } + var val int + for _, c := range s { + val <<= 4 + switch { + case c >= '0' && c <= '9': + val |= int(c - '0') + case c >= 'a' && c <= 'f': + val |= int(c-'a') + 10 + case c >= 'A' && c <= 'F': + val |= int(c-'A') + 10 + default: + return -1 + } + } + return val +} diff --git a/internal/ai/stream_parser_test.go b/internal/ai/stream_parser_test.go new file mode 100644 index 0000000..66bf76b --- /dev/null +++ b/internal/ai/stream_parser_test.go @@ -0,0 +1,111 @@ +package ai + +import ( + "testing" +) + +func TestResponseFieldExtractor_BasicResponse(t *testing.T) { + e := NewResponseFieldExtractor() + + chunks := []string{ + `{"intent":"action_request","conversation_title":"Hello","`, + `response":"`, + `Hello `, + `world`, + `","suggestions":[]}`, + } + + var got string + for _, c := range chunks { + got += e.Feed(c) + } + + if got != "Hello world" { + t.Errorf("got %q, want %q", got, "Hello world") + } +} + +func TestResponseFieldExtractor_Escapes(t *testing.T) { + e := NewResponseFieldExtractor() + + chunks := []string{ + `{"response":"line1\nline2\\backslash\"quote"}`, + } + + var got string + for _, c := range chunks { + got += e.Feed(c) + } + + want := "line1\nline2\\backslash\"quote" + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestResponseFieldExtractor_UnicodeEscape(t *testing.T) { + e := NewResponseFieldExtractor() + + var got string + got += e.Feed(`{"response":"hello \u0048\u0069"}`) + + if got != "hello Hi" { + t.Errorf("got %q, want %q", got, "hello Hi") + } +} + +func TestResponseFieldExtractor_SmallChunks(t *testing.T) { + e := NewResponseFieldExtractor() + + input := `{"response":"abc"}` + var got string + for _, ch := range input { + got += e.Feed(string(ch)) + } + + if got != "abc" { + t.Errorf("got %q, want %q", got, "abc") + } +} + +func TestResponseFieldExtractor_ResponseFieldNotFirst(t *testing.T) { + e := NewResponseFieldExtractor() + + var got string + got += e.Feed(`{"intent":"general","conversation_title":"Test Title","response":"the answer","suggestions":[]}`) + + if got != "the answer" { + t.Errorf("got %q, want %q", got, "the answer") + } +} + +func TestResponseFieldExtractor_EmptyResponse(t *testing.T) { + e := NewResponseFieldExtractor() + + var got string + got += e.Feed(`{"response":""}`) + + if got != "" { + t.Errorf("got %q, want %q", got, "") + } +} + +func TestResponseFieldExtractor_MarkdownContent(t *testing.T) { + e := NewResponseFieldExtractor() + + chunks := []string{ + `{"response":"# Title\n\n- item 1\n- `, + `item 2\n\n**bold** and `, + `*italic*"}`, + } + + var got string + for _, c := range chunks { + got += e.Feed(c) + } + + want := "# Title\n\n- item 1\n- item 2\n\n**bold** and *italic*" + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/internal/api/message.go b/internal/api/message.go index e65d5c9..5af38c4 100644 --- a/internal/api/message.go +++ b/internal/api/message.go @@ -1,8 +1,11 @@ package api import ( + "encoding/json" "errors" + "fmt" "net/http" + "strings" "github.com/google/uuid" "github.com/labstack/echo/v4" @@ -11,36 +14,35 @@ import ( "github.com/vultisig/agent-backend/internal/storage/postgres" ) -// SendMessage handles POST /agent/conversations/:id/messages func (s *Server) SendMessage(c echo.Context) error { - // 1. Parse conversation ID from :id param idStr := c.Param("id") convID, err := uuid.Parse(idStr) if err != nil { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid conversation id"}) } - // 2. Bind request body var req agent.SendMessageRequest - if err := c.Bind(&req); err != nil { + err = c.Bind(&req) + if err != nil { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) } - // 3. Validate request has content, suggestion selection, or action result if req.Content == "" && req.SelectedSuggestionID == nil && req.ActionResult == nil { return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "content, selected_suggestion_id, or action_result is required"}) } - // 4. Validate public_key matches JWT authPublicKey := GetPublicKey(c) if req.PublicKey != authPublicKey { return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) } - // 5. Pass access token to request for plugin installation checks req.AccessToken = GetAccessToken(c) - // 6. Call agentService.ProcessMessage + if wantsSSE(c.Request()) { + s.logger.Info("SSE streaming requested") + return s.sendMessageSSE(c, convID, &req) + } + resp, err := s.agentService.ProcessMessage(c.Request().Context(), convID, req.PublicKey, &req) if err != nil { if errors.Is(err, postgres.ErrNotFound) || err.Error() == "conversation not found" { @@ -50,6 +52,49 @@ func (s *Server) SendMessage(c echo.Context) error { return c.JSON(http.StatusInternalServerError, ErrorResponse{Error: "failed to process message"}) } - // 6. Return SendMessageResponse return c.JSON(http.StatusOK, resp) } + +func wantsSSE(r *http.Request) bool { + accept := r.Header.Get("Accept") + return strings.Contains(accept, "text/event-stream") +} + +func (s *Server) sendMessageSSE(c echo.Context, convID uuid.UUID, req *agent.SendMessageRequest) error { + w := c.Response() + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.Writer.(http.Flusher) + if !ok { + return c.JSON(http.StatusInternalServerError, ErrorResponse{Error: "streaming not supported"}) + } + flusher.Flush() + + eventCh := make(chan agent.SSEEvent, 32) + + go s.agentService.ProcessMessageStream(c.Request().Context(), convID, req.PublicKey, req, eventCh) + + for ev := range eventCh { + data, err := json.Marshal(ev.Data) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal SSE event data") + continue + } + + _, writeErr := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", ev.Event, data) + if writeErr != nil { + s.logger.WithError(writeErr).Warn("SSE write failed") + return nil + } + flusher.Flush() + if ev.Event == "text_delta" { + s.logger.WithField("len", len(data)).Debug("SSE text_delta flushed") + } + } + + return nil +} diff --git a/internal/api/middleware.go b/internal/api/middleware.go index c82e354..baeb9fd 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -1,13 +1,18 @@ package api import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "net/http" "strings" "github.com/labstack/echo/v4" + goredis "github.com/redis/go-redis/v9" ) -// AuthMiddleware validates JWT tokens and extracts the public key. +// AuthMiddleware validates JWT tokens via the verifier /auth/me endpoint. +// Results are cached in Redis for 5 minutes to avoid repeated round-trips. func (s *Server) AuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { authHeader := c.Request().Header.Get(echo.HeaderAuthorization) @@ -19,17 +24,41 @@ func (s *Server) AuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { return c.JSON(http.StatusUnauthorized, ErrorResponse{Error: "invalid authorization header"}) } + token := parts[1] - claims, err := s.authService.ValidateToken(parts[1]) - if err != nil { - return c.JSON(http.StatusUnauthorized, ErrorResponse{Error: "invalid token"}) + // Check Redis cache first. + cacheKey := authCacheKey(token, s.authCacheKeySecret) + publicKey, err := s.redis.Get(c.Request().Context(), cacheKey) + if err != nil && err != goredis.Nil { + s.logger.WithError(err).Warn("redis cache lookup failed") } - c.Set("public_key", claims.PublicKey) + // Cache miss — call verifier. + if publicKey == "" { + publicKey, err = s.verifier.GetMe(c.Request().Context(), token) + if err != nil { + s.logger.WithError(err).Debug("verifier auth/me failed") + return c.JSON(http.StatusUnauthorized, ErrorResponse{Error: "invalid token"}) + } + + // Cache the result. + if cacheErr := s.redis.Set(c.Request().Context(), cacheKey, publicKey, s.authCacheTTL); cacheErr != nil { + s.logger.WithError(cacheErr).Warn("failed to cache auth result") + } + } + + c.Set("public_key", publicKey) + c.Set("access_token", token) return next(c) } } +func authCacheKey(token string, secret []byte) string { + mac := hmac.New(sha256.New, secret) + mac.Write([]byte(token)) + return "auth_me:v1:" + hex.EncodeToString(mac.Sum(nil)) +} + // GetPublicKey extracts the public key from the echo context. func GetPublicKey(c echo.Context) string { pk, _ := c.Get("public_key").(string) diff --git a/internal/api/ratelimit.go b/internal/api/ratelimit.go new file mode 100644 index 0000000..4d1d439 --- /dev/null +++ b/internal/api/ratelimit.go @@ -0,0 +1,40 @@ +package api + +import ( + "fmt" + "net/http" + "time" + + "github.com/labstack/echo/v4" + + "github.com/vultisig/agent-backend/internal/cache/redis" +) + +const ( + rateLimitWindow = 1 * time.Minute + rateLimitMax = 10 +) + +func RateLimitMiddleware(redisClient *redis.Client) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + pk := GetPublicKey(c) + if pk == "" { + return next(c) + } + + key := fmt.Sprintf("ratelimit:%s", pk) + + count, err := redisClient.Incr(c.Request().Context(), key, rateLimitWindow) + if err != nil { + return next(c) + } + + if count > rateLimitMax { + return c.JSON(http.StatusTooManyRequests, ErrorResponse{Error: "rate limit exceeded, try again later"}) + } + + return next(c) + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index db6d7f4..d98b76e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1,27 +1,42 @@ package api import ( + "time" + "github.com/sirupsen/logrus" - "github.com/vultisig/agent-backend/internal/service" + "github.com/vultisig/agent-backend/internal/cache/redis" "github.com/vultisig/agent-backend/internal/service/agent" + "github.com/vultisig/agent-backend/internal/service/verifier" "github.com/vultisig/agent-backend/internal/storage/postgres" ) // Server holds API dependencies. type Server struct { - authService *service.AuthService + verifier *verifier.Client + redis *redis.Client convRepo *postgres.ConversationRepository agentService *agent.AgentService + authCacheKeySecret []byte + authCacheTTL time.Duration logger *logrus.Logger } +// AuthCacheConfig contains auth cache keying and TTL settings. +type AuthCacheConfig struct { + KeySecret string + TTL time.Duration +} + // NewServer creates a new API server. -func NewServer(authService *service.AuthService, convRepo *postgres.ConversationRepository, agentService *agent.AgentService, logger *logrus.Logger) *Server { +func NewServer(verifier *verifier.Client, redis *redis.Client, convRepo *postgres.ConversationRepository, agentService *agent.AgentService, logger *logrus.Logger, authCacheCfg AuthCacheConfig) *Server { return &Server{ - authService: authService, + verifier: verifier, + redis: redis, convRepo: convRepo, agentService: agentService, + authCacheKeySecret: []byte(authCacheCfg.KeySecret), + authCacheTTL: authCacheCfg.TTL, logger: logger, } } diff --git a/internal/api/starters.go b/internal/api/starters.go new file mode 100644 index 0000000..9a0fc83 --- /dev/null +++ b/internal/api/starters.go @@ -0,0 +1,25 @@ +package api + +import ( + "net/http" + + "github.com/labstack/echo/v4" + + "github.com/vultisig/agent-backend/internal/service/agent" +) + +func (s *Server) GetStarters(c echo.Context) error { + var req agent.GetStartersRequest + err := c.Bind(&req) + if err != nil { + return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) + } + + authPublicKey := GetPublicKey(c) + if req.PublicKey != authPublicKey { + return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) + } + + resp := s.agentService.GenerateStarters(c.Request().Context(), &req) + return c.JSON(http.StatusOK, resp) +} diff --git a/internal/api/swap.go b/internal/api/swap.go new file mode 100644 index 0000000..330049c --- /dev/null +++ b/internal/api/swap.go @@ -0,0 +1,37 @@ +package api + +import ( + "net/http" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + + "github.com/vultisig/agent-backend/internal/service/agent" +) + +func (s *Server) BuildTx(c echo.Context) error { + idStr := c.Param("id") + convID, err := uuid.Parse(idStr) + if err != nil { + return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid conversation id"}) + } + + var req agent.BuildTxRequest + err = c.Bind(&req) + if err != nil { + return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) + } + + authPublicKey := GetPublicKey(c) + if req.PublicKey != authPublicKey { + return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) + } + + resp, err := s.agentService.BuildTx(c.Request().Context(), convID, &req) + if err != nil { + s.logger.WithError(err).Error("failed to build tx") + return c.JSON(http.StatusInternalServerError, ErrorResponse{Error: "failed to build tx"}) + } + + return c.JSON(http.StatusOK, resp) +} diff --git a/internal/cache/redis/client.go b/internal/cache/redis/client.go index e9e7e73..5df0740 100644 --- a/internal/cache/redis/client.go +++ b/internal/cache/redis/client.go @@ -42,6 +42,18 @@ func (c *Client) Set(ctx context.Context, key string, value string, ttl time.Dur return c.rdb.Set(ctx, key, value, ttl).Err() } +// Incr increments a key and sets expiry on first increment. Returns the new count. +func (c *Client) Incr(ctx context.Context, key string, window time.Duration) (int64, error) { + val, err := c.rdb.Incr(ctx, key).Result() + if err != nil { + return 0, err + } + if val == 1 { + c.rdb.Expire(ctx, key, window) + } + return val, nil +} + // Delete removes a key. func (c *Client) Delete(ctx context.Context, key string) error { return c.rdb.Del(ctx, key).Err() diff --git a/internal/config/config.go b/internal/config/config.go index 46a3584..09bbca1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,16 +10,18 @@ type Config struct { Server ServerConfig Database DatabaseConfig Redis RedisConfig - Anthropic AnthropicConfig + AuthCache AuthCacheConfig + AI AIConfig Context ContextConfig Verifier VerifierConfig + MCP MCPConfig + Scheduler SchedulerConfig } // ServerConfig holds HTTP server configuration. type ServerConfig struct { - Host string `envconfig:"SERVER_HOST" default:"0.0.0.0"` - Port string `envconfig:"SERVER_PORT" default:"8080"` - JWTSecret string `envconfig:"JWT_SECRET" required:"true"` + Host string `envconfig:"SERVER_HOST" default:"0.0.0.0"` + Port string `envconfig:"SERVER_PORT" default:"8080"` } // DatabaseConfig holds PostgreSQL configuration. @@ -32,11 +34,20 @@ type RedisConfig struct { URI string `envconfig:"REDIS_URI" required:"true"` } -// AnthropicConfig holds Anthropic Claude API configuration. -type AnthropicConfig struct { - APIKey string `envconfig:"ANTHROPIC_API_KEY" required:"true"` - Model string `envconfig:"ANTHROPIC_MODEL" default:"claude-sonnet-4-20250514"` - SummaryModel string `envconfig:"ANTHROPIC_SUMMARY_MODEL" default:"claude-haiku-4-5-20251001"` +// AuthCacheConfig holds auth cache settings for verifier /auth/me lookups. +type AuthCacheConfig struct { + KeySecret string `envconfig:"AUTH_CACHE_KEY_SECRET" required:"true"` + TTLSeconds int `envconfig:"AUTH_CACHE_TTL_SECONDS" default:"180"` +} + +// AIConfig holds AI provider configuration (OpenRouter-compatible). +type AIConfig struct { + APIKey string `envconfig:"AI_API_KEY" required:"true"` + Model string `envconfig:"AI_MODEL" default:"anthropic/claude-sonnet-4.5"` + SummaryModel string `envconfig:"AI_SUMMARY_MODEL" default:"anthropic/claude-haiku-4.5"` + BaseURL string `envconfig:"AI_BASE_URL" default:"https://openrouter.ai/api/v1"` + AppName string `envconfig:"AI_APP_NAME" default:"vultisig-agent"` + AppURL string `envconfig:"AI_APP_URL" default:""` } // TODO: Add WhisperConfig for OpenAI Whisper voice transcription support. @@ -53,6 +64,20 @@ type VerifierConfig struct { URL string `envconfig:"VERIFIER_URL" required:"true"` } +// MCPConfig holds MCP (Model Context Protocol) server configuration. +type MCPConfig struct { + ServerURL string `envconfig:"MCP_SERVER_URL"` + ToolCacheTTLSec int `envconfig:"MCP_TOOL_CACHE_TTL_SECONDS" default:"300"` + URL string `envconfig:"MCP_URL" default:""` +} + +// SchedulerConfig holds settings for the scheduler service. +type SchedulerConfig struct { + PollIntervalSeconds int `envconfig:"SCHEDULER_POLL_INTERVAL_SECONDS" default:"30"` + MaxActivePerUser int `envconfig:"SCHEDULER_MAX_ACTIVE_PER_USER" default:"10"` + MinIntervalMinutes int `envconfig:"SCHEDULER_MIN_INTERVAL_MINUTES" default:"60"` +} + // TODO: Add MetricsConfig for Prometheus metrics when metrics are implemented. // Load reads configuration from environment variables. @@ -73,6 +98,9 @@ func (c *Config) Validate() error { if c.Server.Port == "" { c.Server.Port = "8080" } + if c.AuthCache.TTLSeconds <= 0 { + c.AuthCache.TTLSeconds = 180 + } // Add additional validation as needed (e.g., URL format, port ranges) return nil } diff --git a/internal/mcp/client.go b/internal/mcp/client.go new file mode 100644 index 0000000..63adf6a --- /dev/null +++ b/internal/mcp/client.go @@ -0,0 +1,633 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/ai" +) + +// JSON-RPC 2.0 types + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (e *jsonRPCError) Error() string { + return fmt.Sprintf("mcp rpc error %d: %s", e.Code, e.Message) +} + +// MCP-specific types + +// ToolError is returned when an MCP tool sets IsError: true. +// It carries the tool's text content so callers can still parse structured data from it. +type ToolError struct { + ToolName string + Text string +} + +func (e *ToolError) Error() string { + return fmt.Sprintf("mcp tool %s error: %s", e.ToolName, e.Text) +} + +// MCPTool represents a tool definition from the MCP server. +type MCPTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema any `json:"inputSchema"` +} + +type callToolParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` +} + +type callToolResult struct { + Content []callToolContent `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +type callToolContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// toolCache holds cached MCP tools with a TTL. +type toolCache struct { + mu sync.RWMutex + tools []MCPTool + fetchedAt time.Time + ttl time.Duration +} + +func (tc *toolCache) get() ([]MCPTool, bool) { + tc.mu.RLock() + defer tc.mu.RUnlock() + if tc.tools == nil { + return nil, false + } + fresh := time.Since(tc.fetchedAt) < tc.ttl + return tc.tools, fresh +} + +func (tc *toolCache) set(tools []MCPTool) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.tools = tools + tc.fetchedAt = time.Now() +} + +// MCP resource types (resources/list, resources/read) + +type resourceEntry struct { + URI string `json:"uri"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + MimeType string `json:"mimeType,omitempty"` +} + +type readResourceParams struct { + URI string `json:"uri"` +} + +type readResourceResult struct { + Contents []resourceContent `json:"contents"` +} + +type resourceContent struct { + URI string `json:"uri"` + MimeType string `json:"mimeType,omitempty"` + Text string `json:"text,omitempty"` +} + +// skillEntry is an MCP skill discovered via resources/list. +type skillEntry struct { + Slug string + Name string + Description string + URI string +} + +// skillCache holds cached skill metadata with a TTL. +type skillCache struct { + mu sync.RWMutex + skills []skillEntry + fetchedAt time.Time + ttl time.Duration +} + +func (sc *skillCache) get() ([]skillEntry, bool) { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.skills == nil { + return nil, false + } + fresh := time.Since(sc.fetchedAt) < sc.ttl + return sc.skills, fresh +} + +func (sc *skillCache) set(skills []skillEntry) { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.skills = skills + sc.fetchedAt = time.Now() +} + +// Client is an MCP JSON-RPC 2.0 client using Streamable HTTP transport. +type Client struct { + serverURL string + httpClient *http.Client + sessionID string + requestID atomic.Int64 + cache toolCache + skills skillCache + skillContent sync.Map // slug → string (cached skill markdown) + logger *logrus.Logger +} + +// NewClient creates a new MCP client. +func NewClient(serverURL string, cacheTTL time.Duration, logger *logrus.Logger) *Client { + return &Client{ + serverURL: strings.TrimRight(serverURL, "/"), + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + cache: toolCache{ttl: cacheTTL}, + skills: skillCache{ttl: cacheTTL}, + logger: logger, + } +} + +// call performs a JSON-RPC 2.0 call over HTTP. +func (c *Client) call(ctx context.Context, method string, params any) (json.RawMessage, error) { + id := c.requestID.Add(1) + reqBody := jsonRPCRequest{ + JSONRPC: "2.0", + ID: id, + Method: method, + Params: params, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := c.serverURL + "/mcp" + + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_url": url, + "mcp_session": c.sessionID, + }).Debug("mcp request sending") + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json") + if c.sessionID != "" { + httpReq.Header.Set("Mcp-Session-Id", c.sessionID) + } + + start := time.Now() + resp, err := c.httpClient.Do(httpReq) + elapsed := time.Since(start) + if err != nil { + c.logger.WithError(err).WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_elapsed": elapsed.String(), + }).Error("mcp request failed") + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + // Track session ID from response + if sid := resp.Header.Get("Mcp-Session-Id"); sid != "" { + if c.sessionID != sid { + c.logger.WithFields(logrus.Fields{ + "mcp_session_old": c.sessionID, + "mcp_session_new": sid, + }).Debug("mcp session id updated") + } + c.sessionID = sid + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_id": id, + "mcp_status": resp.StatusCode, + "mcp_elapsed": elapsed.String(), + "mcp_response_len": len(respBody), + }).Debug("mcp response received") + + if resp.StatusCode != http.StatusOK { + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_status": resp.StatusCode, + "mcp_body": string(respBody), + }).Error("mcp server returned non-200 status") + return nil, fmt.Errorf("mcp server returned status %d: %s", resp.StatusCode, string(respBody)) + } + + var rpcResp jsonRPCResponse + if err := json.Unmarshal(respBody, &rpcResp); err != nil { + c.logger.WithError(err).WithField("mcp_body", string(respBody)).Error("mcp response unmarshal failed") + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + if rpcResp.Error != nil { + c.logger.WithFields(logrus.Fields{ + "mcp_method": method, + "mcp_error_code": rpcResp.Error.Code, + "mcp_error_msg": rpcResp.Error.Message, + }).Error("mcp rpc error") + return nil, rpcResp.Error + } + + return rpcResp.Result, nil +} + +// Initialize performs the MCP initialize handshake. +func (c *Client) Initialize(ctx context.Context) error { + c.logger.WithField("mcp_url", c.serverURL).Info("mcp initializing") + + params := map[string]any{ + "protocolVersion": "2025-03-26", + "capabilities": map[string]any{}, + "clientInfo": map[string]string{ + "name": "vultisig-agent-backend", + "version": "1.0.0", + }, + } + + result, err := c.call(ctx, "initialize", params) + if err != nil { + return fmt.Errorf("initialize: %w", err) + } + + c.logger.WithField("mcp_init_result", string(result)).Info("mcp initialized successfully") + + // Send initialized notification — best-effort, some servers don't handle it + if _, err := c.call(ctx, "notifications/initialized", nil); err != nil { + c.logger.WithError(err).Debug("mcp notifications/initialized not supported by server (harmless)") + } + + return nil +} + +// ListTools fetches the tool list from the MCP server and updates the cache. +func (c *Client) ListTools(ctx context.Context) ([]MCPTool, error) { + c.logger.Debug("mcp listing tools") + + result, err := c.call(ctx, "tools/list", nil) + if err != nil { + // Return stale cache on error + if stale, _ := c.cache.get(); stale != nil { + c.logger.WithError(err).WithField("stale_count", len(stale)).Warn("mcp list tools failed, using stale cache") + return stale, nil + } + return nil, fmt.Errorf("list tools: %w", err) + } + + var listResult struct { + Tools []MCPTool `json:"tools"` + } + if err := json.Unmarshal(result, &listResult); err != nil { + return nil, fmt.Errorf("unmarshal tools: %w", err) + } + + names := make([]string, len(listResult.Tools)) + for i, t := range listResult.Tools { + names[i] = t.Name + } + c.logger.WithFields(logrus.Fields{ + "mcp_tool_count": len(listResult.Tools), + "mcp_tool_names": names, + }).Info("mcp tools discovered") + + c.cache.set(listResult.Tools) + return listResult.Tools, nil +} + +// CallTool invokes a tool on the MCP server. +func (c *Client) CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) { + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_arguments": string(arguments), + }).Info("mcp calling tool") + + var args map[string]any + if len(arguments) > 0 { + if err := json.Unmarshal(arguments, &args); err != nil { + return "", fmt.Errorf("unmarshal arguments: %w", err) + } + } + + params := callToolParams{ + Name: name, + Arguments: args, + } + + result, err := c.call(ctx, "tools/call", params) + if err != nil { + c.logger.WithError(err).WithField("mcp_tool", name).Error("mcp tool call failed") + return "", fmt.Errorf("call tool %s: %w", name, err) + } + + var callResult callToolResult + if err := json.Unmarshal(result, &callResult); err != nil { + return "", fmt.Errorf("unmarshal tool result: %w", err) + } + + // Collect text content from result + var texts []string + for _, c := range callResult.Content { + if c.Type == "text" && c.Text != "" { + texts = append(texts, c.Text) + } + } + + text := strings.Join(texts, "\n") + + if callResult.IsError { + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_error": text, + }).Error("mcp tool returned error") + // Return the text with a ToolError so callers can still access the content. + return text, &ToolError{ToolName: name, Text: text} + } + + c.logger.WithFields(logrus.Fields{ + "mcp_tool": name, + "mcp_result_len": len(text), + }).Info("mcp tool call succeeded") + + return text, nil +} + +// GetTools returns cached MCP tools converted to AI tool format. +// If the cache is stale, it attempts a background refresh. +func (c *Client) GetTools(ctx context.Context) []ai.Tool { + tools, fresh := c.cache.get() + + c.logger.WithFields(logrus.Fields{ + "mcp_cache_count": len(tools), + "mcp_cache_fresh": fresh, + }).Debug("mcp GetTools called") + + if !fresh && tools != nil { + c.logger.Debug("mcp cache stale, starting background refresh") + go func() { + refreshCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, _ = c.ListTools(refreshCtx) + }() + } + + if tools == nil { + c.logger.Warn("mcp cache empty, no tools available") + return nil + } + + result := make([]ai.Tool, len(tools)) + for i, t := range tools { + result[i] = ai.Tool{ + Name: t.Name, + Description: t.Description, + InputSchema: t.InputSchema, + } + } + return result +} + +// ToolNames returns the names of all cached MCP tools. +func (c *Client) ToolNames() []string { + tools, _ := c.cache.get() + names := make([]string, len(tools)) + for i, t := range tools { + names[i] = t.Name + } + return names +} + +// ToolDescriptions returns a formatted string describing all MCP tools for the system prompt. +func (c *Client) ToolDescriptions() string { + tools, _ := c.cache.get() + if len(tools) == 0 { + c.logger.Warn("mcp ToolDescriptions called but cache is empty") + return "" + } + + var b strings.Builder + b.WriteString("\n\n## Vultisig Tools\n\n") + b.WriteString("You have the following tools provided by the Vultisig platform. ") + b.WriteString("These are core capabilities you MUST include when listing your available tools or describing what you can do:\n\n") + for _, t := range tools { + b.WriteString("- **") + b.WriteString(t.Name) + b.WriteString("**") + if t.Description != "" { + b.WriteString(": ") + b.WriteString(t.Description) + } + b.WriteString("\n") + } + + desc := b.String() + c.logger.WithField("mcp_desc_len", len(desc)).Debug("mcp ToolDescriptions generated") + return desc +} + +// --------------------------------------------------------------------------- +// MCP Resources — skill discovery and loading +// --------------------------------------------------------------------------- + +// ListSkills fetches available skills from the MCP server via resources/list. +// Skills are resources with URIs matching "skills/*.md". +func (c *Client) ListSkills(ctx context.Context) ([]skillEntry, error) { + c.logger.Debug("mcp listing skills via resources/list") + + result, err := c.call(ctx, "resources/list", nil) + if err != nil { + if stale, _ := c.skills.get(); stale != nil { + c.logger.WithError(err).Warn("mcp resources/list failed, using stale skill cache") + return stale, nil + } + return nil, fmt.Errorf("list resources: %w", err) + } + + var listResult struct { + Resources []resourceEntry `json:"resources"` + } + if err := json.Unmarshal(result, &listResult); err != nil { + return nil, fmt.Errorf("unmarshal resources: %w", err) + } + + var skills []skillEntry + for _, r := range listResult.Resources { + slug := extractSkillSlug(r.URI) + if slug == "" { + continue + } + skills = append(skills, skillEntry{ + Slug: slug, + Name: r.Name, + Description: r.Description, + URI: r.URI, + }) + } + + slugs := make([]string, len(skills)) + for i, s := range skills { + slugs[i] = s.Slug + } + c.logger.WithFields(logrus.Fields{ + "skill_count": len(skills), + "skill_slugs": slugs, + }).Info("mcp skills discovered") + + c.skills.set(skills) + return skills, nil +} + +// extractSkillSlug extracts a slug from a skill resource URI. +// Handles various URI formats: +// - "skill://vultisig/evm-contract-call.md" → "evm-contract-call" +// - "skills/evm-contract-call.md" → "evm-contract-call" +// +// Returns "" if the URI doesn't end in .md. +func extractSkillSlug(uri string) string { + if !strings.HasSuffix(uri, ".md") { + return "" + } + base := strings.TrimSuffix(uri, ".md") + if idx := strings.LastIndex(base, "/"); idx >= 0 { + return base[idx+1:] + } + return base +} + +// ReadSkill fetches the content of a specific skill by slug. +func (c *Client) ReadSkill(ctx context.Context, slug string) (string, error) { + // Check in-memory content cache first + if cached, ok := c.skillContent.Load(slug); ok { + return cached.(string), nil + } + + // Look up the full URI from the skill cache + uri := c.skillURI(slug) + if uri == "" { + return "", fmt.Errorf("skill %q not found in skill list", slug) + } + + c.logger.WithFields(logrus.Fields{ + "skill": slug, + "uri": uri, + }).Debug("mcp reading skill via resources/read") + + result, err := c.call(ctx, "resources/read", readResourceParams{URI: uri}) + if err != nil { + return "", fmt.Errorf("read skill %s: %w", slug, err) + } + + var readResult readResourceResult + if err := json.Unmarshal(result, &readResult); err != nil { + return "", fmt.Errorf("unmarshal skill content: %w", err) + } + + if len(readResult.Contents) == 0 { + return "", fmt.Errorf("skill %s: empty content", slug) + } + + text := readResult.Contents[0].Text + c.skillContent.Store(slug, text) + + c.logger.WithFields(logrus.Fields{ + "skill": slug, + "content_len": len(text), + }).Info("mcp skill loaded") + + return text, nil +} + +// skillURI looks up the full resource URI for a skill slug from the cache. +func (c *Client) skillURI(slug string) string { + skills, _ := c.skills.get() + for _, s := range skills { + if s.Slug == slug { + return s.URI + } + } + return "" +} + +// SkillSummary returns a formatted list of available skills for injection into the system prompt. +// Returns "" if no skills are available. Triggers a background refresh if cache is stale. +func (c *Client) SkillSummary(ctx context.Context) string { + skills, fresh := c.skills.get() + + if !fresh && skills != nil { + go func() { + refreshCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, _ = c.ListSkills(refreshCtx) + }() + } + + if len(skills) == 0 { + return "" + } + + var b strings.Builder + b.WriteString("\n\n## Available Skills\n\n") + b.WriteString("You have access to specialized skill guides that provide detailed instructions for specific workflows. ") + b.WriteString("Use the `get_skill` tool to load a skill's full instructions when it is relevant to the user's request.\n\n") + b.WriteString("**IMPORTANT**: Only load skills that are directly relevant to what the user is asking. Do not load all skills.\n\n") + for _, s := range skills { + b.WriteString("- **") + b.WriteString(s.Slug) + b.WriteString("**") + if s.Description != "" { + b.WriteString(": ") + b.WriteString(s.Description) + } + b.WriteString("\n") + } + return b.String() +} diff --git a/internal/service/agent/actions.go b/internal/service/agent/actions.go new file mode 100644 index 0000000..9da34db --- /dev/null +++ b/internal/service/agent/actions.go @@ -0,0 +1,31 @@ +package agent + +import ( + "encoding/json" + "fmt" +) + +func buildActionResultUserMessage(result *ActionResult, originalAction *Action) string { + var actionDesc string + if originalAction != nil { + actionDesc = originalAction.Title + } else { + actionDesc = result.Action + } + + if !result.Success { + if result.Error != "" { + return fmt.Sprintf("[Action result: %s failed — %s]", actionDesc, result.Error) + } + return fmt.Sprintf("[Action result: %s failed]", actionDesc) + } + + if len(result.Data) > 0 { + dataJSON, err := json.Marshal(result.Data) + if err == nil { + return fmt.Sprintf("[Action result: %s succeeded — data: %s]", actionDesc, string(dataJSON)) + } + } + + return fmt.Sprintf("[Action result: %s succeeded]", actionDesc) +} diff --git a/internal/service/agent/actions_test.go b/internal/service/agent/actions_test.go new file mode 100644 index 0000000..bfd0165 --- /dev/null +++ b/internal/service/agent/actions_test.go @@ -0,0 +1,57 @@ +package agent + +import ( + "strings" + "testing" +) + +func TestBuildActionResultUserMessage(t *testing.T) { + t.Run("success without data", func(t *testing.T) { + result := &ActionResult{Action: "create_policy", Success: true} + got := buildActionResultUserMessage(result, nil) + if !strings.Contains(got, "succeeded") { + t.Errorf("expected 'succeeded' in %q", got) + } + }) + + t.Run("success with data", func(t *testing.T) { + result := &ActionResult{ + Action: "get_market_price", + Success: true, + Data: map[string]any{"price": 3500.0}, + } + got := buildActionResultUserMessage(result, nil) + if !strings.Contains(got, "3500") { + t.Errorf("expected price data in %q", got) + } + }) + + t.Run("failure with error", func(t *testing.T) { + result := &ActionResult{ + Action: "build_send_tx", + Success: false, + Error: "insufficient balance", + } + got := buildActionResultUserMessage(result, nil) + if !strings.Contains(got, "failed") || !strings.Contains(got, "insufficient balance") { + t.Errorf("expected failure message in %q", got) + } + }) + + t.Run("uses original action title", func(t *testing.T) { + result := &ActionResult{Action: "get_market_price", Success: true} + original := &Action{Title: "Fetch ETH Price"} + got := buildActionResultUserMessage(result, original) + if !strings.Contains(got, "Fetch ETH Price") { + t.Errorf("expected original title in %q", got) + } + }) + + t.Run("failure without error", func(t *testing.T) { + result := &ActionResult{Action: "add_chain", Success: false} + got := buildActionResultUserMessage(result, nil) + if !strings.Contains(got, "failed") { + t.Errorf("expected 'failed' in %q", got) + } + }) +} diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 9c26a12..bb8b70e 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -5,11 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "strings" + "time" "github.com/google/uuid" "github.com/sirupsen/logrus" - "github.com/vultisig/agent-backend/internal/ai/anthropic" + "github.com/vultisig/agent-backend/internal/ai" "github.com/vultisig/agent-backend/internal/cache/redis" "github.com/vultisig/agent-backend/internal/config" "github.com/vultisig/agent-backend/internal/service/verifier" @@ -17,56 +19,108 @@ import ( "github.com/vultisig/agent-backend/internal/types" ) -// PluginSkillsProvider provides plugin skills for prompt building. +const ( + claudeRequestTimeout = 90 * time.Second + actionTTL = 1 * time.Hour + maxLoopIterations = 8 +) + type PluginSkillsProvider interface { GetSkills(ctx context.Context) []PluginSkill } -// AgentService handles AI agent operations. +// MCPToolProvider provides tools and skills discovered from an MCP server. +type MCPToolProvider interface { + GetTools(ctx context.Context) []ai.Tool + ToolNames() []string + CallTool(ctx context.Context, name string, arguments json.RawMessage) (string, error) + ToolDescriptions() string + SkillSummary(ctx context.Context) string + ReadSkill(ctx context.Context, slug string) (string, error) +} + +type SwapTxBuilder interface { + BuildSwapTx(ctx context.Context, req SwapTxBuildRequest) (*SwapTxBuildResponse, error) +} + +type SwapTxBuildRequest struct { + FromChain string `json:"from_chain"` + FromSymbol string `json:"from_symbol"` + FromAddress string `json:"from_address,omitempty"` + FromDecimals *int `json:"from_decimals,omitempty"` + ToChain string `json:"to_chain"` + ToSymbol string `json:"to_symbol"` + ToAddress string `json:"to_address,omitempty"` + ToDecimals *int `json:"to_decimals,omitempty"` + Amount string `json:"amount"` + Sender string `json:"sender"` + Destination string `json:"destination"` +} + +type SwapTxBuildResponse struct { + Provider string `json:"provider"` + ExpectedOutput string `json:"expected_output"` + MinimumOutput string `json:"minimum_output"` + NeedsApproval bool `json:"needs_approval"` + ApprovalTx json.RawMessage `json:"approval_tx,omitempty"` + SwapTx json.RawMessage `json:"swap_tx"` + Memo string `json:"memo,omitempty"` +} + type AgentService struct { - anthropic *anthropic.Client + ai *ai.Client msgRepo *postgres.MessageRepository convRepo *postgres.ConversationRepository memRepo *postgres.MemoryRepository + taskRepo *postgres.ScheduledTaskRepository redis *redis.Client verifier *verifier.Client pluginProvider PluginSkillsProvider + mcpProvider MCPToolProvider + swapTxBuilder SwapTxBuilder logger *logrus.Logger + schedulerCfg config.SchedulerConfig summaryModel string windowSize int summarizeTrigger int summaryMaxTokens int } -// conversationWindow holds a windowed view of conversation messages plus optional summary. type conversationWindow struct { messages []types.Message summary *string total int } -// NewAgentService creates a new AgentService. func NewAgentService( - anthropicClient *anthropic.Client, + aiClient *ai.Client, msgRepo *postgres.MessageRepository, convRepo *postgres.ConversationRepository, memRepo *postgres.MemoryRepository, + taskRepo *postgres.ScheduledTaskRepository, redisClient *redis.Client, verifierClient *verifier.Client, pluginProvider PluginSkillsProvider, + mcpProvider MCPToolProvider, + swapTxBuilder SwapTxBuilder, logger *logrus.Logger, summaryModel string, ctxCfg config.ContextConfig, + schedulerCfg config.SchedulerConfig, ) *AgentService { return &AgentService{ - anthropic: anthropicClient, + ai: aiClient, msgRepo: msgRepo, convRepo: convRepo, memRepo: memRepo, + taskRepo: taskRepo, redis: redisClient, verifier: verifierClient, pluginProvider: pluginProvider, + mcpProvider: mcpProvider, + swapTxBuilder: swapTxBuilder, logger: logger, + schedulerCfg: schedulerCfg, summaryModel: summaryModel, windowSize: ctxCfg.WindowSize, summarizeTrigger: ctxCfg.SummarizeTrigger, @@ -74,13 +128,11 @@ func NewAgentService( } } -const maxLoopIterations = 8 - -// ProcessMessage handles a user message through an LLM-driven decision loop. -// The LLM freely picks tools and chains them together until it produces a final text response. func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, publicKey string, req *SendMessageRequest) (*SendMessageResponse, error) { - // 1. Validate conversation exists and belongs to user - _, err := s.convRepo.GetByID(ctx, convID, publicKey) + ctx, cancel := context.WithTimeout(ctx, claudeRequestTimeout) + defer cancel() + + conv, err := s.convRepo.GetByID(ctx, convID, publicKey) if err != nil { if errors.Is(err, postgres.ErrNotFound) { return nil, fmt.Errorf("conversation not found") @@ -88,166 +140,470 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub return nil, fmt.Errorf("get conversation: %w", err) } - // 2. Load conversation window + // Prime MCP session with vault info if set + if conv.VaultInfo != nil && s.mcpProvider != nil { + vaultInput, _ := json.Marshal(conv.VaultInfo) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", vaultInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + } else { + s.logger.Debug("mcp session primed with vault info") + } + } + window, err := s.getConversationWindow(ctx, convID, publicKey) if err != nil { return nil, fmt.Errorf("get conversation window: %w", err) } - // 3. Convert request to user message text userContent := s.resolveUserContent(ctx, req) - // 4. Store user message in DB - contentType := "text" - if req.ActionResult != nil { - contentType = "action_result" - } userMsg := &types.Message{ ConversationID: convID, Role: types.RoleUser, Content: userContent, - ContentType: contentType, + ContentType: s.resolveContentType(req), } if err := s.msgRepo.Create(ctx, userMsg); err != nil { return nil, fmt.Errorf("store user message: %w", err) } - // 5. Build unified system prompt - var balances []Balance - var addresses map[string]string - if req.Context != nil { - balances = req.Context.Balances - addresses = req.Context.Addresses - } + fullCtx := s.resolveContext(ctx, convID, req.Context) - var pluginSkills []PluginSkill - if s.pluginProvider != nil { - pluginSkills = s.pluginProvider.GetSkills(ctx) + basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + if conv.VaultInfo != nil { + basePrompt += BuildVaultInfoSection(conv.VaultInfo) } + if s.mcpProvider != nil { + mcpDesc := s.mcpProvider.ToolDescriptions() + if mcpDesc != "" { + basePrompt += mcpDesc + } + skillSummary := s.mcpProvider.SkillSummary(ctx) + if skillSummary != "" { + s.logger.WithField("skill_summary_len", len(skillSummary)).Debug("appending skill summary to system prompt") + basePrompt += skillSummary + } + } + systemPrompt := BuildSystemPromptWithSummary( + basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, + window.summary, + ) - basePrompt := BuildFullPrompt(balances, addresses, pluginSkills) - basePrompt += s.loadMemorySection(ctx, req.PublicKey) + MemoryManagementInstructions - systemPrompt := BuildSystemPromptWithSummary(basePrompt, window.summary) - - // 6. Build messages array from window + new user message - messages := anthropicMessagesFromWindow(window) - messages = append(messages, anthropic.Message{ + messages := aiMessagesFromWindow(window) + messages = append(messages, ai.Message{ Role: "user", Content: userContent, }) - // 7. Gather all tools tools := agentTools() tools = append(tools, s.memoryTools()...) + if s.mcpProvider != nil { + mcpTools := s.mcpProvider.GetTools(ctx) + if len(mcpTools) > 0 { + mcpNames := make([]string, len(mcpTools)) + for i, t := range mcpTools { + mcpNames[i] = t.Name + } + s.logger.WithFields(logrus.Fields{ + "mcp_tool_count": len(mcpTools), + "mcp_tool_names": mcpNames, + }).Debug("appending mcp tools to ai request") + } else { + s.logger.Warn("mcp provider active but no tools returned") + } + tools = append(tools, mcpTools...) + + // Add get_skill tool if skills are available + if s.mcpProvider.SkillSummary(ctx) != "" { + tools = append(tools, GetSkillTool) + } + } - // 8. Run decision loop + var toolResp *ToolResponse var textContent string - var suggestions []Suggestion - var policyReady *PolicyReady - var installRequired *InstallRequired + var tokens *TokenSearchResult for i := 0; i < maxLoopIterations; i++ { - anthropicReq := &anthropic.Request{ + aiReq := &ai.Request{ + Model: req.Model, System: systemPrompt, Messages: messages, Tools: tools, - ToolChoice: &anthropic.ToolChoice{ + ToolChoice: &ai.ToolChoice{ Type: "auto", }, } - resp, err := s.anthropic.SendMessage(ctx, anthropicReq) + resp, err := s.ai.SendMessage(ctx, aiReq) if err != nil { - return nil, fmt.Errorf("call anthropic (iteration %d): %w", i, err) + return nil, fmt.Errorf("call ai (iteration %d): %w", i, err) } - s.logger.WithFields(logrus.Fields{ - "iteration": i, - "stop_reason": resp.StopReason, - "input_tokens": resp.Usage.InputTokens, - "output_tokens": resp.Usage.OutputTokens, - "content_blocks": len(resp.Content), - }).Debug("decision loop iteration") - - // Collect text content - for _, block := range resp.Content { - if block.Type == "text" && block.Text != "" { - if textContent != "" { - textContent += "\n\n" - } - textContent += block.Text - } - } + s.persistMemoryUpdate(ctx, req.PublicKey, s.extractMemoryUpdate(resp)) - // If end_turn, we're done - if resp.StopReason == "end_turn" { + assistantText := resp.Content + toolCalls := resp.ToolCalls + + if resp.FinishReason == "stop" || len(toolCalls) == 0 { + textContent = assistantText break } - // Execute tool calls and build results - var toolUseBlocks []anthropic.ContentBlock - var toolResults []anthropic.ToolResultBlock - - for _, block := range resp.Content { - if block.Type != "tool_use" { + var toolMessages []ai.ToolMessage + for _, tc := range toolCalls { + if tc.Function.Name == "respond_to_user" { + var tr ToolResponse + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &tr); err == nil { + toolResp = &tr + } + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: `{"ok": true}`, + }) continue } - toolUseBlocks = append(toolUseBlocks, block) - - result, execErr := s.executeTool(ctx, block.Name, block.Input, req) - if execErr != nil { - result = jsonError(execErr.Error()) + result, err := s.executeTool(ctx, convID, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) + if err != nil { + result = jsonError(err.Error()) } s.logger.WithFields(logrus.Fields{ - "tool": block.Name, - "tool_id": block.ID, - "result_len": len(result), + "tool": tc.Function.Name, + "tool_id": tc.ID, }).Debug("tool executed") - toolResults = append(toolResults, anthropic.ToolResultBlock{ - Type: "tool_result", - ToolUseID: block.ID, - Content: result, + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, }) - // Track results for response assembly - s.trackToolResult(block.Name, result, &suggestions, &policyReady, &installRequired) + // Track find_token results for structured passthrough + if tc.Function.Name == "find_token" { + if parsed := extractTokens(result); parsed != nil { + tokens = parsed + s.logger.WithField("token_count", len(parsed.Tokens)).Info("tokens extracted from find_token result") + } else { + s.logger.WithField("result_preview", truncateResult(result, 200)).Warn("find_token result could not be parsed as token data") + } + } + } + + messages = append(messages, ai.AssistantMessage{ + Role: "assistant", + Content: assistantText, + ToolCalls: toolCalls, + }) + for _, tm := range toolMessages { + messages = append(messages, tm) } - if len(toolUseBlocks) == 0 { - // No tool calls and not end_turn — break to avoid infinite loop + if toolResp != nil { break } + } - // Append assistant message + tool results for next iteration - messages = append(messages, anthropic.AssistantMessage{ - Role: "assistant", - Content: resp.Content, - }) - messages = append(messages, anthropic.ToolResultMessage{ - Role: "user", - Content: toolResults, + if toolResp != nil { + resp, err := s.buildLoopResponse(ctx, convID, req, toolResp, window) + if err != nil { + return nil, err + } + resp.Tokens = tokens + return resp, nil + } + if textContent != "" { + resp, err := s.buildTextResponse(ctx, convID, textContent) + if err != nil { + return nil, err + } + resp.Tokens = tokens + return resp, nil + } + + return nil, errors.New("no response content from Claude") +} + +func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUID, publicKey string, req *SendMessageRequest, eventCh chan<- SSEEvent) { + defer close(eventCh) + + sendErr := func(msg string) { + eventCh <- SSEEvent{Event: "error", Data: ErrorPayload{Error: msg}} + } + + ctx, cancel := context.WithTimeout(ctx, claudeRequestTimeout) + defer cancel() + + conv, err := s.convRepo.GetByID(ctx, convID, publicKey) + if err != nil { + sendErr("conversation not found") + return + } + + // Prime MCP session with vault info if set + if conv.VaultInfo != nil && s.mcpProvider != nil { + vaultInput, _ := json.Marshal(conv.VaultInfo) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", vaultInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + } + } + + window, err := s.getConversationWindow(ctx, convID, publicKey) + if err != nil { + sendErr("failed to load conversation") + return + } + + userContent := s.resolveUserContent(ctx, req) + + userMsg := &types.Message{ + ConversationID: convID, + Role: types.RoleUser, + Content: userContent, + ContentType: s.resolveContentType(req), + } + err = s.msgRepo.Create(ctx, userMsg) + if err != nil { + sendErr("failed to store message") + return + } + + fullCtx := s.resolveContext(ctx, convID, req.Context) + + basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + if conv.VaultInfo != nil { + basePrompt += BuildVaultInfoSection(conv.VaultInfo) + } + if s.mcpProvider != nil { + mcpDesc := s.mcpProvider.ToolDescriptions() + if mcpDesc != "" { + basePrompt += mcpDesc + } + skillSummary := s.mcpProvider.SkillSummary(ctx) + if skillSummary != "" { + s.logger.WithField("skill_summary_len", len(skillSummary)).Debug("appending skill summary to system prompt (stream)") + basePrompt += skillSummary + } + } + systemPrompt := BuildSystemPromptWithSummary( + basePrompt+s.loadMemorySection(ctx, req.PublicKey)+MemoryManagementInstructions, + window.summary, + ) + + messages := aiMessagesFromWindow(window) + messages = append(messages, ai.Message{ + Role: "user", + Content: userContent, + }) + + tools := agentTools() + tools = append(tools, s.memoryTools()...) + if s.mcpProvider != nil { + mcpTools := s.mcpProvider.GetTools(ctx) + if len(mcpTools) > 0 { + tools = append(tools, mcpTools...) + } + if s.mcpProvider.SkillSummary(ctx) != "" { + tools = append(tools, GetSkillTool) + } + } + + var toolResp *ToolResponse + var textContent string + + for range maxLoopIterations { + aiReq := &ai.Request{ + Model: req.Model, + System: systemPrompt, + Messages: messages, + Tools: tools, + ToolChoice: &ai.ToolChoice{ + Type: "auto", + }, + } + + extractor := ai.NewResponseFieldExtractor() + callback := func(delta ai.StreamDelta) { + if delta.Content != "" { + eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: delta.Content}} + } + for _, tc := range delta.ToolCalls { + if tc.Function.Arguments != "" { + text := extractor.Feed(tc.Function.Arguments) + if text != "" { + eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: text}} + } + } + } + } + + resp, err := s.ai.SendMessageStream(ctx, aiReq, callback) + if err != nil { + sendErr(fmt.Sprintf("AI error: %v", err)) + return + } + + s.persistMemoryUpdate(ctx, req.PublicKey, s.extractMemoryUpdate(resp)) + + assistantText := resp.Content + toolCalls := resp.ToolCalls + + if resp.FinishReason == "stop" || len(toolCalls) == 0 { + textContent = assistantText + break + } + + var toolMessages []ai.ToolMessage + for _, tc := range toolCalls { + if tc.Function.Name == "respond_to_user" { + var tr ToolResponse + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &tr); err == nil { + toolResp = &tr + } + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: `{"ok": true}`, + }) + continue + } + + result, err := s.executeTool(ctx, convID, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) + if err != nil { + result = jsonError(err.Error()) + } + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, + }) + } + + messages = append(messages, ai.AssistantMessage{ + Role: "assistant", + Content: assistantText, + ToolCalls: toolCalls, }) + for _, tm := range toolMessages { + messages = append(messages, tm) + } + + if toolResp != nil { + break + } + } + + if toolResp != nil { + s.emitLoopResponse(ctx, convID, req, toolResp, window, eventCh) + } else if textContent != "" { + s.emitTextResponse(ctx, convID, textContent, eventCh) + } else { + sendErr("no response from AI") + } + + eventCh <- SSEEvent{Event: "done", Data: struct{}{}} +} + +func (s *AgentService) resolveUserContent(ctx context.Context, req *SendMessageRequest) string { + if req.ActionResult != nil { + var originalAction *Action + if req.ActionResult.ActionID != "" { + actJSON, err := s.redis.Get(ctx, req.ActionResult.ActionID) + if err == nil && actJSON != "" { + var act Action + if err := json.Unmarshal([]byte(actJSON), &act); err == nil { + originalAction = &act + } + _ = s.redis.Delete(ctx, req.ActionResult.ActionID) + } + } + return buildActionResultUserMessage(req.ActionResult, originalAction) + } + + if req.SelectedSuggestionID != nil { + suggJSON, err := s.redis.Get(ctx, *req.SelectedSuggestionID) + if err == nil && suggJSON != "" { + var sugg Suggestion + if err := json.Unmarshal([]byte(suggJSON), &sugg); err == nil { + return fmt.Sprintf("[User selected suggestion: %s — %s (plugin: %s)]", sugg.Title, sugg.Description, sugg.PluginID) + } + } + return fmt.Sprintf("[User selected suggestion: %s]", *req.SelectedSuggestionID) + } + + return req.Content +} + +func (s *AgentService) resolveContentType(req *SendMessageRequest) string { + if req.ActionResult != nil { + return "action_result" } + return "text" +} - // 9. Fallback if no text was generated - if textContent == "" { - textContent = "I'm here to help! What would you like to do?" +func (s *AgentService) buildLoopResponse(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, toolResp *ToolResponse, window *conversationWindow) (*SendMessageResponse, error) { + var suggestions []Suggestion + for _, ts := range toolResp.Suggestions { + suggID := "sug_" + uuid.New().String() + sugg := Suggestion{ + ID: suggID, + PluginID: ts.PluginID, + Title: ts.Title, + Description: ts.Description, + } + suggestions = append(suggestions, sugg) + + suggJSON, err := json.Marshal(sugg) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal suggestion") + continue + } + if err := s.redis.Set(ctx, suggID, string(suggJSON), suggestionTTL); err != nil { + s.logger.WithError(err).Warn("failed to store suggestion in redis") + } + } + + var actions []Action + for _, ta := range toolResp.Actions { + actID := "act_" + uuid.New().String() + act := Action{ + ID: actID, + Type: ta.Type, + Title: ta.Title, + Description: ta.Description, + Params: ta.Params, + AutoExecute: ta.AutoExecute, + } + actions = append(actions, act) + + actJSON, err := json.Marshal(act) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal action") + continue + } + if err := s.redis.Set(ctx, actID, string(actJSON), actionTTL); err != nil { + s.logger.WithError(err).Warn("failed to store action in redis") + } } - // 10. Store assistant message in DB - var metadata json.RawMessage + txReady := s.interceptBuildSwapTx(ctx, convID, req, toolResp) + + metadataMap := map[string]any{ + "intent": toolResp.Intent, + } if len(suggestions) > 0 { - metadata, _ = json.Marshal(map[string]any{ - "suggestions": suggestions, - }) + metadataMap["suggestions"] = suggestions + } + if len(actions) > 0 { + metadataMap["actions"] = actions } + metadata, _ := json.Marshal(metadataMap) + assistantMsg := &types.Message{ ConversationID: convID, Role: types.RoleAssistant, - Content: textContent, + Content: toolResp.Response, ContentType: "text", Metadata: metadata, } @@ -255,127 +611,455 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub return nil, fmt.Errorf("store assistant message: %w", err) } - // 11. Update conversation title on first exchange - if window.total <= 2 && req.Content != "" { - title := truncateTitle(req.Content) + var titlePtr *string + if toolResp.ConversationTitle != "" { + title := truncateTitle(toolResp.ConversationTitle, 60) + titlePtr = &title if err := s.convRepo.UpdateTitle(ctx, convID, req.PublicKey, title); err != nil { s.logger.WithError(err).Warn("failed to update conversation title") } } + resp := &SendMessageResponse{ + Message: *assistantMsg, + Title: titlePtr, + Suggestions: suggestions, + Actions: actions, + TxReady: txReady, + } + + if req.ActionResult != nil && req.ActionResult.Action == "install_plugin" && req.ActionResult.Success { + s.autoContinueAfterInstall(ctx, convID, req, window, resp) + } + + s.logger.WithField("num_actions", len(actions)).Info("sending response to desktop") + return resp, nil +} + +func (s *AgentService) buildTextResponse(ctx context.Context, convID uuid.UUID, text string) (*SendMessageResponse, error) { + assistantMsg := &types.Message{ + ConversationID: convID, + Role: types.RoleAssistant, + Content: text, + ContentType: "text", + } + if err := s.msgRepo.Create(ctx, assistantMsg); err != nil { + return nil, fmt.Errorf("store assistant message: %w", err) + } return &SendMessageResponse{ - Message: *assistantMsg, - Suggestions: suggestions, - PolicyReady: policyReady, - InstallRequired: installRequired, + Message: *assistantMsg, }, nil } -// resolveUserContent converts the request fields into a user message string. -func (s *AgentService) resolveUserContent(ctx context.Context, req *SendMessageRequest) string { - if req.ActionResult != nil { - return buildActionResultMessage(req.ActionResult) +func (s *AgentService) emitLoopResponse(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, toolResp *ToolResponse, window *conversationWindow, eventCh chan<- SSEEvent) { + var suggestions []Suggestion + for _, ts := range toolResp.Suggestions { + suggID := "sug_" + uuid.New().String() + sugg := Suggestion{ + ID: suggID, + PluginID: ts.PluginID, + Title: ts.Title, + Description: ts.Description, + } + suggestions = append(suggestions, sugg) + + suggJSON, err := json.Marshal(sugg) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal suggestion") + continue + } + if err := s.redis.Set(ctx, suggID, string(suggJSON), suggestionTTL); err != nil { + s.logger.WithError(err).Warn("failed to store suggestion in redis") + } } - if req.SelectedSuggestionID != nil { - suggJSON, err := s.redis.Get(ctx, *req.SelectedSuggestionID) - if err == nil && suggJSON != "" { - var suggestion Suggestion - if err := json.Unmarshal([]byte(suggJSON), &suggestion); err == nil { - return fmt.Sprintf("I'd like to proceed with: %s - %s (plugin: %s)", suggestion.Title, suggestion.Description, suggestion.PluginID) + var actions []Action + for _, ta := range toolResp.Actions { + actID := "act_" + uuid.New().String() + act := Action{ + ID: actID, + Type: ta.Type, + Title: ta.Title, + Description: ta.Description, + Params: ta.Params, + AutoExecute: ta.AutoExecute, + } + actions = append(actions, act) + + actJSON, err := json.Marshal(act) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal action") + continue + } + if err := s.redis.Set(ctx, actID, string(actJSON), actionTTL); err != nil { + s.logger.WithError(err).Warn("failed to store action in redis") + } + } + + txReady := s.interceptBuildSwapTx(ctx, convID, req, toolResp) + + metadataMap := map[string]any{ + "intent": toolResp.Intent, + } + if len(suggestions) > 0 { + metadataMap["suggestions"] = suggestions + } + if len(actions) > 0 { + metadataMap["actions"] = actions + } + metadata, _ := json.Marshal(metadataMap) + + assistantMsg := &types.Message{ + ConversationID: convID, + Role: types.RoleAssistant, + Content: toolResp.Response, + ContentType: "text", + Metadata: metadata, + } + err := s.msgRepo.Create(ctx, assistantMsg) + if err != nil { + s.logger.WithError(err).Error("failed to store assistant message") + eventCh <- SSEEvent{Event: "error", Data: ErrorPayload{Error: "failed to store message"}} + return + } + + if toolResp.ConversationTitle != "" { + title := truncateTitle(toolResp.ConversationTitle, 60) + if err := s.convRepo.UpdateTitle(ctx, convID, req.PublicKey, title); err != nil { + s.logger.WithError(err).Warn("failed to update conversation title") + } + eventCh <- SSEEvent{Event: "title", Data: TitlePayload{Title: title}} + } + + if len(suggestions) > 0 { + eventCh <- SSEEvent{Event: "suggestions", Data: SuggestionsPayload{Suggestions: suggestions}} + } + if len(actions) > 0 { + eventCh <- SSEEvent{Event: "actions", Data: ActionsPayload{Actions: actions}} + } + if txReady != nil { + eventCh <- SSEEvent{Event: "tx_ready", Data: txReady} + } + + eventCh <- SSEEvent{Event: "message", Data: MessagePayload{Message: *assistantMsg}} + + if req.ActionResult != nil && req.ActionResult.Action == "install_plugin" && req.ActionResult.Success { + pendingKey := fmt.Sprintf("pending_build:%s", convID) + suggID, err := s.redis.Get(ctx, pendingKey) + if err == nil && suggID != "" { + _ = s.redis.Delete(ctx, pendingKey) + buildReq := &SendMessageRequest{ + SelectedSuggestionID: &suggID, + Context: req.Context, + AccessToken: req.AccessToken, + PublicKey: req.PublicKey, + } + buildResp, buildErr := s.ProcessMessage(ctx, convID, req.PublicKey, buildReq) + if buildErr != nil { + s.logger.WithError(buildErr).Warn("auto-continue to policy build failed") + } else { + if buildResp.PolicyReady != nil { + eventCh <- SSEEvent{Event: "text_delta", Data: TextDeltaPayload{Delta: buildResp.Message.Content}} + eventCh <- SSEEvent{Event: "policy_ready", Data: *buildResp.PolicyReady} + eventCh <- SSEEvent{Event: "message", Data: MessagePayload{Message: buildResp.Message}} + } } } - return fmt.Sprintf("I'd like to proceed with suggestion: %s", *req.SelectedSuggestionID) } +} - return req.Content +func (s *AgentService) emitTextResponse(ctx context.Context, convID uuid.UUID, text string, eventCh chan<- SSEEvent) { + assistantMsg := &types.Message{ + ConversationID: convID, + Role: types.RoleAssistant, + Content: text, + ContentType: "text", + } + err := s.msgRepo.Create(ctx, assistantMsg) + if err != nil { + eventCh <- SSEEvent{Event: "error", Data: ErrorPayload{Error: "failed to store message"}} + return + } + eventCh <- SSEEvent{Event: "message", Data: MessagePayload{Message: *assistantMsg}} } -// buildActionResultMessage creates a user message describing the action result. -func buildActionResultMessage(result *ActionResult) string { - if result.Success { - return fmt.Sprintf("[Action completed: %s was successful]", result.Action) +// extractTokens tries to parse a TokenSearchResult from an MCP tool result. +// MCP text content may not be pure JSON (e.g., multiple text blocks joined with \n, +// or descriptive text surrounding JSON), so we try multiple strategies. +func extractTokens(result string) *TokenSearchResult { + // Strategy 1: direct unmarshal (pure JSON) + var direct TokenSearchResult + if err := json.Unmarshal([]byte(result), &direct); err == nil && len(direct.Tokens) > 0 { + return &direct } - if result.Error != "" { - return fmt.Sprintf("[Action failed: %s failed with error: %s]", result.Action, result.Error) + + // Strategy 2: the result may contain non-JSON text around the JSON object. + // Scan for the first '{' and try to decode from there. json.Decoder + // stops after the first complete JSON value, ignoring trailing text. + for i := strings.IndexByte(result, '{'); i >= 0 && i < len(result); { + var candidate TokenSearchResult + dec := json.NewDecoder(strings.NewReader(result[i:])) + if err := dec.Decode(&candidate); err == nil && len(candidate.Tokens) > 0 { + return &candidate + } + // Try the next '{' occurrence + next := strings.IndexByte(result[i+1:], '{') + if next < 0 { + break + } + i = i + 1 + next + } + + return nil +} + +// truncateResult returns the first n bytes of a string for log previews. +func truncateResult(s string, n int) string { + if len(s) <= n { + return s } - return fmt.Sprintf("[Action failed: %s was not successful]", result.Action) + return s[:n] + "..." } -// trackToolResult maps tool results to response fields for the frontend. -// NOTE: check_billing_status can overwrite installRequired set by check_plugin_installed (and vice versa). -// The system prompt instructs the LLM to call check_plugin_installed first and check_billing_status second, -// so in practice the last writer wins with the correct value. If ordering becomes unreliable, switch to first-write-wins. -func (s *AgentService) trackToolResult(toolName string, result string, suggestions *[]Suggestion, policyReady **PolicyReady, installRequired **InstallRequired) { - switch toolName { - case "create_suggestion": - var sugg Suggestion - if err := json.Unmarshal([]byte(result), &sugg); err == nil && sugg.ID != "" { - *suggestions = append(*suggestions, sugg) - } - - case "suggest_policy": - var data struct { - PluginID string `json:"plugin_id"` - Configuration map[string]any `json:"configuration"` - PolicySuggest any `json:"policy_suggest"` - } - if err := json.Unmarshal([]byte(result), &data); err == nil && data.PluginID != "" { - *policyReady = &PolicyReady{ - PluginID: data.PluginID, - Configuration: data.Configuration, - PolicySuggest: data.PolicySuggest, - } +func (s *AgentService) autoContinueAfterInstall(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, window *conversationWindow, resp *SendMessageResponse) { + pendingKey := fmt.Sprintf("pending_build:%s", convID) + suggID, err := s.redis.Get(ctx, pendingKey) + if err != nil || suggID == "" { + return + } + + _ = s.redis.Delete(ctx, pendingKey) + buildReq := &SendMessageRequest{ + SelectedSuggestionID: &suggID, + Context: req.Context, + AccessToken: req.AccessToken, + PublicKey: req.PublicKey, + } + buildResp, err := s.ProcessMessage(ctx, convID, req.PublicKey, buildReq) + if err != nil { + s.logger.WithError(err).Warn("auto-continue to policy build failed") + return + } + + resp.PolicyReady = buildResp.PolicyReady + resp.InstallRequired = buildResp.InstallRequired +} + +func (s *AgentService) getPluginSkills(ctx context.Context) []PluginSkill { + if s.pluginProvider == nil { + return nil + } + return s.pluginProvider.GetSkills(ctx) +} + +func (s *AgentService) BuildTx(ctx context.Context, convID uuid.UUID, req *BuildTxRequest) (*BuildTxResponse, error) { + if s.swapTxBuilder == nil { + return nil, fmt.Errorf("swap builder not configured") + } + + fullCtx := s.resolveContext(ctx, convID, req.Context) + + sendReq := &SendMessageRequest{ + PublicKey: req.PublicKey, + Context: fullCtx, + } + toolResp := &ToolResponse{ + Actions: []ToolAction{ + { + Type: "build_tx", + Params: req.Params, + }, + }, + } + + txReady := s.interceptBuildSwapTx(ctx, convID, sendReq, toolResp) + + resp := &BuildTxResponse{ + TxReady: txReady, + } + + for _, ta := range toolResp.Actions { + resp.Actions = append(resp.Actions, Action{ + ID: "act_" + uuid.New().String(), + Type: ta.Type, + Title: ta.Title, + Description: ta.Description, + Params: ta.Params, + AutoExecute: ta.AutoExecute, + }) + } + + if txReady == nil { + msg := strings.TrimSpace(toolResp.Response) + if len(resp.Actions) > 0 { + resp.Message = msg + } else { + resp.Error = msg } + } - case "check_plugin_installed": - var data struct { - Installed bool `json:"installed"` - PluginID string `json:"plugin_id"` + return resp, nil +} + +func (s *AgentService) interceptBuildSwapTx(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, toolResp *ToolResponse) *TxReady { + for i, action := range toolResp.Actions { + if action.Type != "build_tx" { + continue } - if err := json.Unmarshal([]byte(result), &data); err == nil && !data.Installed { - *installRequired = &InstallRequired{ - PluginID: data.PluginID, - Title: data.PluginID, - Description: fmt.Sprintf("Install %s to set up your automation", data.PluginID), - } + + fromChain := getStringParam(action.Params, "from_chain") + fromSymbol := getStringParam(action.Params, "from_symbol") + toChain := getStringParam(action.Params, "to_chain") + toSymbol := getStringParam(action.Params, "to_symbol") + amount := getStringParam(action.Params, "amount") + + if fromChain == "" || fromSymbol == "" || toChain == "" || toSymbol == "" || amount == "" { + s.logger.Warn("build_tx missing required params") + toolResp.Actions = append(toolResp.Actions[:i], toolResp.Actions[i+1:]...) + toolResp.Response += "\n\nCould not build swap transaction: missing required parameters." + continue } - case "check_billing_status": - var data struct { - BillingOK bool `json:"billing_ok"` - BillingPluginID string `json:"billing_plugin_id"` + var sender, destination string + if req.Context != nil && req.Context.Addresses != nil { + sender = findAddress(req.Context.Addresses, fromChain) + destination = findAddress(req.Context.Addresses, toChain) } - if err := json.Unmarshal([]byte(result), &data); err == nil && !data.BillingOK { - *installRequired = &InstallRequired{ - PluginID: data.BillingPluginID, - Title: "Vultisig Billing App", - Description: "Install the billing app to continue using plugins after your free trial", - } + if sender == "" || destination == "" { + s.logger.WithField("from_chain", fromChain).WithField("to_chain", toChain). + WithField("addresses", req.Context.Addresses). + Warn("build_tx: missing sender/destination addresses") + toolResp.Actions = append(toolResp.Actions[:i], toolResp.Actions[i+1:]...) + toolResp.Response += "\n\nCould not build swap transaction: missing wallet addresses." + continue + } + + _, fromDecimals := resolveTokenFromContext(req.Context, fromChain, fromSymbol) + _, toDecimals := resolveTokenFromContext(req.Context, toChain, toSymbol) + + buildReq := SwapTxBuildRequest{ + FromChain: fromChain, + FromSymbol: fromSymbol, + FromDecimals: fromDecimals, + ToChain: toChain, + ToSymbol: toSymbol, + ToDecimals: toDecimals, + Amount: amount, + Sender: sender, + Destination: destination, + } + + buildResp, err := s.swapTxBuilder.BuildSwapTx(ctx, buildReq) + if err != nil { + s.logger.WithError(err).Warn("build_tx failed") + toolResp.Actions = append(toolResp.Actions[:i], toolResp.Actions[i+1:]...) + toolResp.Response += "\n\nFailed to build swap transaction: " + err.Error() + continue + } + + toolResp.Actions[i].Type = "sign_swap_tx" + + var txFields struct { + To string `json:"to"` + Value string `json:"value"` + Data string `json:"data"` + } + unmarshalErr := json.Unmarshal(buildResp.SwapTx, &txFields) + if unmarshalErr == nil && txFields.Data != "" { + toolResp.Actions = append(toolResp.Actions, ToolAction{ + Type: "scan_tx", + AutoExecute: true, + Params: map[string]any{ + "chain": fromChain, + "from": sender, + "to": txFields.To, + "value": txFields.Value, + "data": txFields.Data, + }, + }) + } + + fromDec := 18 + if fromDecimals != nil { + fromDec = *fromDecimals + } + toDec := 18 + if toDecimals != nil { + toDec = *toDecimals + } + + return &TxReady{ + Provider: buildResp.Provider, + ExpectedOutput: buildResp.ExpectedOutput, + MinimumOutput: buildResp.MinimumOutput, + NeedsApproval: buildResp.NeedsApproval, + ApprovalTx: buildResp.ApprovalTx, + SwapTx: buildResp.SwapTx, + FromChain: fromChain, + FromSymbol: fromSymbol, + FromDecimals: fromDec, + ToChain: toChain, + ToSymbol: toSymbol, + ToDecimals: toDec, + Amount: amount, + Sender: sender, + Destination: destination, + } + } + + return nil +} + +func truncateTitle(title string, maxLen int) string { + if len(title) <= maxLen { + return title + } + return title[:maxLen-3] + "..." +} + +func resolveTokenFromContext(msgCtx *MessageContext, chain, symbol string) (string, *int) { + if msgCtx == nil { + return "", nil + } + for _, coin := range msgCtx.Coins { + if strings.EqualFold(coin.Chain, chain) && strings.EqualFold(coin.Ticker, symbol) { + dec := coin.Decimals + return coin.ContractAddress, &dec + } + } + return "", nil +} + +func findAddress(addresses map[string]string, chain string) string { + if addr := addresses[chain]; addr != "" { + return addr + } + for k, v := range addresses { + if strings.EqualFold(k, chain) { + return v } } + return "" } -// truncateTitle truncates content to create a conversation title. -func truncateTitle(content string) string { - const maxLen = 50 - if len(content) <= maxLen { - return content +func getStringParam(params map[string]any, key string) string { + v, ok := params[key] + if !ok || v == nil { + return "" } - return content[:maxLen-3] + "..." + str, ok := v.(string) + if !ok { + return fmt.Sprintf("%v", v) + } + return str } -// getConversationWindow returns a windowed view of the conversation. -// Uses a summary_up_to cursor to only count/load messages after the last summarization point. -// This prevents re-summarizing on every request once the trigger threshold is crossed. func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UUID, publicKey string) (*conversationWindow, error) { - // Load summary and cursor together summary, cursor, err := s.convRepo.GetSummaryWithCursor(ctx, convID, publicKey) if err != nil { return nil, fmt.Errorf("get summary with cursor: %w", err) } - // Cursor-aware path: only count messages after the cursor if cursor != nil { count, err := s.msgRepo.CountSince(ctx, convID, *cursor) if err != nil { @@ -390,7 +1074,6 @@ func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UU "has_cursor": true, }).Debug("context window state") - // Active messages fit in window — load all since cursor if count <= s.windowSize { msgs, err := s.msgRepo.GetSince(ctx, convID, *cursor) if err != nil { @@ -399,34 +1082,23 @@ func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UU return &conversationWindow{messages: msgs, summary: summary, total: count}, nil } - // Active messages exceed trigger — re-summarize if count > s.summarizeTrigger { allSinceCursor, err := s.msgRepo.GetSince(ctx, convID, *cursor) if err != nil { return nil, fmt.Errorf("get messages since cursor: %w", err) } - if err := s.summarizeOldMessages(ctx, convID, publicKey, allSinceCursor); err != nil { - s.logger.WithError(err).Error("synchronous summarization failed") - // Fall back to recent window + existing summary - } - - // Reload summary+cursor after summarization (cursor has advanced) - summary, cursor, err = s.convRepo.GetSummaryWithCursor(ctx, convID, publicKey) - if err != nil { - return nil, fmt.Errorf("get summary after summarization: %w", err) - } - - if cursor != nil { - recentMsgs, err := s.msgRepo.GetRecentSince(ctx, convID, *cursor, s.windowSize) - if err != nil { - return nil, fmt.Errorf("get recent messages since cursor: %w", err) + msgsCopy := make([]types.Message, len(allSinceCursor)) + copy(msgsCopy, allSinceCursor) + go func() { + bgCtx, bgCancel := context.WithTimeout(context.Background(), claudeRequestTimeout) + defer bgCancel() + if err := s.summarizeOldMessages(bgCtx, convID, publicKey, msgsCopy); err != nil { + s.logger.WithError(err).Error("async summarization failed") } - return &conversationWindow{messages: recentMsgs, summary: summary, total: len(recentMsgs)}, nil - } + }() } - // Between window and trigger — load recent since cursor msgs, err := s.msgRepo.GetRecentSince(ctx, convID, *cursor, s.windowSize) if err != nil { return nil, fmt.Errorf("get recent messages since cursor: %w", err) @@ -434,7 +1106,6 @@ func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UU return &conversationWindow{messages: msgs, summary: summary, total: count}, nil } - // No cursor — first summarization hasn't happened yet total, err := s.msgRepo.CountByConversationID(ctx, convID) if err != nil { return nil, fmt.Errorf("count messages: %w", err) @@ -448,7 +1119,6 @@ func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UU "has_cursor": false, }).Debug("context window state") - // All messages fit in window if total <= s.windowSize { msgs, err := s.msgRepo.GetByConversationID(ctx, convID) if err != nil { @@ -457,33 +1127,22 @@ func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UU return &conversationWindow{messages: msgs, total: total}, nil } - // Past trigger — first-time summarization if total > s.summarizeTrigger { allMsgs, err := s.msgRepo.GetByConversationID(ctx, convID) if err != nil { return nil, fmt.Errorf("get messages: %w", err) } - if err := s.summarizeOldMessages(ctx, convID, publicKey, allMsgs); err != nil { - s.logger.WithError(err).Error("synchronous summarization failed") - return &conversationWindow{messages: allMsgs, total: total}, nil - } - - // Reload summary+cursor after first summarization - summary, cursor, err = s.convRepo.GetSummaryWithCursor(ctx, convID, publicKey) - if err != nil { - return nil, fmt.Errorf("get summary after summarization: %w", err) - } - - if cursor != nil { - recentMsgs, err := s.msgRepo.GetRecentSince(ctx, convID, *cursor, s.windowSize) - if err != nil { - return nil, fmt.Errorf("get recent messages since cursor: %w", err) + msgsCopy := make([]types.Message, len(allMsgs)) + copy(msgsCopy, allMsgs) + go func() { + bgCtx, bgCancel := context.WithTimeout(context.Background(), claudeRequestTimeout) + defer bgCancel() + if err := s.summarizeOldMessages(bgCtx, convID, publicKey, msgsCopy); err != nil { + s.logger.WithError(err).Error("async summarization failed") } - return &conversationWindow{messages: recentMsgs, summary: summary, total: len(recentMsgs)}, nil - } + }() - // Fallback if cursor wasn't set (shouldn't happen) recentMsgs, err := s.msgRepo.GetRecent(ctx, convID, s.windowSize) if err != nil { return nil, fmt.Errorf("get recent messages: %w", err) @@ -491,7 +1150,6 @@ func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UU return &conversationWindow{messages: recentMsgs, summary: summary, total: total}, nil } - // Between window and trigger, no cursor yet — load all messages msgs, err := s.msgRepo.GetByConversationID(ctx, convID) if err != nil { return nil, fmt.Errorf("get messages: %w", err) @@ -499,23 +1157,18 @@ func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UU return &conversationWindow{messages: msgs, total: total}, nil } -// summarizeOldMessages summarizes messages outside the recent window and stores the summary. -// It runs synchronously and advances the summary_up_to cursor to the last summarized message. func (s *AgentService) summarizeOldMessages(ctx context.Context, convID uuid.UUID, publicKey string, allMsgs []types.Message) error { if len(allMsgs) <= s.windowSize { return nil } - // Split: old messages to summarize, recent window to keep oldMsgs := allMsgs[:len(allMsgs)-s.windowSize] - // Build content to summarize var oldContent string for _, msg := range oldMsgs { oldContent += fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content) } - // Include existing summary for incremental summarization existingSummary, _, _ := s.convRepo.GetSummaryWithCursor(ctx, convID, publicKey) prompt := SummarizationPrompt if existingSummary != nil { @@ -524,33 +1177,24 @@ func (s *AgentService) summarizeOldMessages(ctx context.Context, convID uuid.UUI prompt += "\n\n## Messages to Summarize\n\n" + oldContent // Call Claude Haiku for summarization - req := &anthropic.Request{ + req := &ai.Request{ Model: s.summaryModel, MaxTokens: s.summaryMaxTokens, Messages: []any{ - anthropic.Message{Role: "user", Content: prompt}, + ai.Message{Role: "user", Content: prompt}, }, } - resp, err := s.anthropic.SendMessage(ctx, req) + resp, err := s.ai.SendMessage(ctx, req) if err != nil { - return fmt.Errorf("call anthropic: %w", err) - } - - // Extract text response - var summaryText string - for _, block := range resp.Content { - if block.Type == "text" { - summaryText = block.Text - break - } + return fmt.Errorf("call ai: %w", err) } + summaryText := resp.Content if summaryText == "" { - return fmt.Errorf("empty response from anthropic") + return fmt.Errorf("empty response from ai") } - // Advance cursor to the last summarized message's timestamp summaryUpTo := oldMsgs[len(oldMsgs)-1].CreatedAt if err := s.convRepo.UpdateSummaryWithCursor(ctx, convID, publicKey, summaryText, summaryUpTo); err != nil { return fmt.Errorf("store summary with cursor: %w", err) @@ -564,15 +1208,15 @@ func (s *AgentService) summarizeOldMessages(ctx context.Context, convID uuid.UUI return nil } -// anthropicMessagesFromWindow converts conversation window messages to Anthropic message format, +// aiMessagesFromWindow converts conversation window messages to AI message format, // skipping system messages. -func anthropicMessagesFromWindow(window *conversationWindow) []any { +func aiMessagesFromWindow(window *conversationWindow) []any { msgs := make([]any, 0, len(window.messages)) for _, msg := range window.messages { if msg.Role == types.RoleSystem { continue } - msgs = append(msgs, anthropic.Message{ + msgs = append(msgs, ai.Message{ Role: string(msg.Role), Content: msg.Content, }) diff --git a/internal/service/agent/context.go b/internal/service/agent/context.go new file mode 100644 index 0000000..c7d127b --- /dev/null +++ b/internal/service/agent/context.go @@ -0,0 +1,93 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" +) + +const vaultContextTTL = 24 * time.Hour + +type cachedVaultContext struct { + VaultName string `json:"vault_name,omitempty"` + Coins []CoinInfo `json:"coins,omitempty"` + AddressBook []AddressBookEntry `json:"address_book,omitempty"` +} + +func vaultContextKey(convID uuid.UUID) string { + return fmt.Sprintf("vault_ctx:%s", convID) +} + +func (s *AgentService) cacheVaultContext(ctx context.Context, convID uuid.UUID, msgCtx *MessageContext) { + if msgCtx == nil { + return + } + if len(msgCtx.Coins) == 0 && len(msgCtx.AddressBook) == 0 && msgCtx.VaultName == "" { + return + } + + cached := cachedVaultContext{ + VaultName: msgCtx.VaultName, + Coins: msgCtx.Coins, + AddressBook: msgCtx.AddressBook, + } + + data, err := json.Marshal(cached) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal vault context for cache") + return + } + + err = s.redis.Set(ctx, vaultContextKey(convID), string(data), vaultContextTTL) + if err != nil { + s.logger.WithError(err).Warn("failed to cache vault context in redis") + } +} + +func (s *AgentService) loadCachedVaultContext(ctx context.Context, convID uuid.UUID) *cachedVaultContext { + raw, err := s.redis.Get(ctx, vaultContextKey(convID)) + if err != nil || raw == "" { + return nil + } + + var cached cachedVaultContext + err = json.Unmarshal([]byte(raw), &cached) + if err != nil { + s.logger.WithError(err).Warn("failed to unmarshal cached vault context") + return nil + } + return &cached +} + +func (s *AgentService) resolveContext(ctx context.Context, convID uuid.UUID, msgCtx *MessageContext) *MessageContext { + s.cacheVaultContext(ctx, convID, msgCtx) + cached := s.loadCachedVaultContext(ctx, convID) + return mergeContext(cached, msgCtx) +} + +func mergeContext(cached *cachedVaultContext, msg *MessageContext) *MessageContext { + if msg == nil { + msg = &MessageContext{} + } + + merged := *msg + + if cached == nil { + return &merged + } + + if merged.VaultName == "" { + merged.VaultName = cached.VaultName + } + if len(merged.Coins) == 0 { + merged.Coins = cached.Coins + } + if len(merged.AddressBook) == 0 { + merged.AddressBook = cached.AddressBook + } + + return &merged +} diff --git a/internal/service/agent/context_test.go b/internal/service/agent/context_test.go new file mode 100644 index 0000000..1ca5a8e --- /dev/null +++ b/internal/service/agent/context_test.go @@ -0,0 +1,69 @@ +package agent + +import ( + "testing" +) + +func TestMergeContext(t *testing.T) { + t.Run("both nil", func(t *testing.T) { + got := mergeContext(nil, nil) + if got == nil { + t.Fatal("expected non-nil result") + } + }) + + t.Run("cached nil, msg with data", func(t *testing.T) { + msg := &MessageContext{VaultName: "MyVault"} + got := mergeContext(nil, msg) + if got.VaultName != "MyVault" { + t.Errorf("got %q, want MyVault", got.VaultName) + } + }) + + t.Run("cached fills empty msg fields", func(t *testing.T) { + cached := &cachedVaultContext{ + VaultName: "CachedVault", + Coins: []CoinInfo{{Ticker: "ETH", Chain: "Ethereum"}}, + } + msg := &MessageContext{} + got := mergeContext(cached, msg) + if got.VaultName != "CachedVault" { + t.Errorf("VaultName: got %q, want CachedVault", got.VaultName) + } + if len(got.Coins) != 1 { + t.Errorf("Coins: got %d, want 1", len(got.Coins)) + } + }) + + t.Run("msg fields take precedence", func(t *testing.T) { + cached := &cachedVaultContext{ + VaultName: "CachedVault", + Coins: []CoinInfo{{Ticker: "ETH"}}, + } + msg := &MessageContext{ + VaultName: "FreshVault", + Coins: []CoinInfo{{Ticker: "BTC"}, {Ticker: "ETH"}}, + } + got := mergeContext(cached, msg) + if got.VaultName != "FreshVault" { + t.Errorf("VaultName: got %q, want FreshVault", got.VaultName) + } + if len(got.Coins) != 2 { + t.Errorf("Coins: got %d, want 2", len(got.Coins)) + } + }) + + t.Run("msg nil uses cached", func(t *testing.T) { + cached := &cachedVaultContext{ + VaultName: "CachedVault", + AddressBook: []AddressBookEntry{{Title: "Alice", Address: "0x123", Chain: "Ethereum"}}, + } + got := mergeContext(cached, nil) + if got.VaultName != "CachedVault" { + t.Errorf("VaultName: got %q, want CachedVault", got.VaultName) + } + if len(got.AddressBook) != 1 { + t.Errorf("AddressBook: got %d, want 1", len(got.AddressBook)) + } + }) +} diff --git a/internal/service/agent/executor.go b/internal/service/agent/executor.go index c7aefd5..3bc900e 100644 --- a/internal/service/agent/executor.go +++ b/internal/service/agent/executor.go @@ -3,16 +3,23 @@ package agent import ( "context" "encoding/json" + "errors" + "fmt" "time" "github.com/google/uuid" + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/mcp" + "github.com/vultisig/agent-backend/internal/storage/postgres" + "github.com/vultisig/agent-backend/internal/types" ) const suggestionTTL = 1 * time.Hour // executeTool dispatches a tool call to the appropriate handler. // Returns a JSON string result for Claude. Errors are returned as JSON {"error": "..."} so the LLM can communicate them naturally. -func (s *AgentService) executeTool(ctx context.Context, name string, input json.RawMessage, req *SendMessageRequest) (string, error) { +func (s *AgentService) executeTool(ctx context.Context, convID uuid.UUID, name string, input json.RawMessage, req *SendMessageRequest) (string, error) { switch name { case "check_plugin_installed": return s.execCheckPluginInstalled(ctx, input, req) @@ -26,11 +33,120 @@ func (s *AgentService) executeTool(ctx context.Context, name string, input json. return s.execCreateSuggestion(ctx, input) case "update_memory": return s.execUpdateMemory(ctx, input, req) + case "get_skill": + return s.execGetSkill(ctx, input) + case "set_vault": + return s.execSetVault(ctx, convID, input, req) + case "schedule_task": + return s.execScheduleTask(ctx, convID, input, req) + case "list_scheduled_tasks": + return s.execListScheduledTasks(ctx, req) + case "update_scheduled_task": + return s.execUpdateScheduledTask(ctx, input, req) + case "cancel_scheduled_task": + return s.execCancelScheduledTask(ctx, input, req) default: + // Check MCP tools before returning unknown + if s.mcpProvider != nil { + for _, mcpName := range s.mcpProvider.ToolNames() { + if mcpName == name { + result, err := s.mcpProvider.CallTool(ctx, name, input) + if err != nil { + // ToolError carries the text content — return it so + // trackToolResult can still extract structured data + // and Claude can narrate the error to the user. + var toolErr *mcp.ToolError + if errors.As(err, &toolErr) && toolErr.Text != "" { + s.logger.WithField("tool", name).Warn("mcp tool returned isError with content") + return toolErr.Text, nil + } + s.logger.WithError(err).WithField("tool", name).Warn("mcp tool call failed") + return jsonError("mcp tool error: " + err.Error()), nil + } + return result, nil + } + } + } return jsonError("unknown tool: " + name), nil } } +// execSetVault stores vault keys for this conversation and primes the MCP session. +func (s *AgentService) execSetVault(ctx context.Context, convID uuid.UUID, input json.RawMessage, req *SendMessageRequest) (string, error) { + var params struct { + ECDSAPublicKey string `json:"ecdsa_public_key"` + EDDSAPublicKey string `json:"eddsa_public_key"` + ChainCode string `json:"chain_code"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + + if params.ECDSAPublicKey == "" || params.EDDSAPublicKey == "" || params.ChainCode == "" { + return jsonError("all three fields are required: ecdsa_public_key, eddsa_public_key, chain_code"), nil + } + + // Store in DB + if err := s.convRepo.UpdateVaultInfo(ctx, convID, req.PublicKey, params.ECDSAPublicKey, params.EDDSAPublicKey, params.ChainCode); err != nil { + s.logger.WithError(err).Error("failed to store vault info") + return jsonError("failed to store vault info: " + err.Error()), nil + } + + s.logger.WithFields(logrus.Fields{ + "conversation_id": convID, + "ecdsa_prefix": truncateKey(params.ECDSAPublicKey), + "eddsa_prefix": truncateKey(params.EDDSAPublicKey), + "chain_code_prefix": truncateKey(params.ChainCode), + }).Info("vault info set for conversation") + + // Prime MCP session with vault info + if s.mcpProvider != nil { + mcpInput, _ := json.Marshal(params) + if _, err := s.mcpProvider.CallTool(ctx, "set_vault_info", mcpInput); err != nil { + s.logger.WithError(err).Warn("failed to prime mcp session with vault info") + // Non-fatal: vault is stored locally, MCP will be primed on next request + } + } + + result, _ := json.Marshal(map[string]any{ + "ok": true, + "message": "Vault set for this conversation.", + }) + return string(result), nil +} + +// execGetSkill loads a skill's full instructions from the MCP server. +func (s *AgentService) execGetSkill(ctx context.Context, input json.RawMessage) (string, error) { + var params struct { + SkillName string `json:"skill_name"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + if params.SkillName == "" { + return jsonError("skill_name is required"), nil + } + if s.mcpProvider == nil { + return jsonError("skills not available"), nil + } + + content, err := s.mcpProvider.ReadSkill(ctx, params.SkillName) + if err != nil { + s.logger.WithError(err).WithField("skill", params.SkillName).Warn("failed to load skill") + return jsonError("failed to load skill: " + err.Error()), nil + } + + return content, nil +} + +// truncateKey returns the first 12 chars of a key for logging. +func truncateKey(key string) string { + if len(key) <= 12 { + return key + } + return key[:12] + "..." +} + // execCheckPluginInstalled checks if a plugin is installed for the user's vault. func (s *AgentService) execCheckPluginInstalled(ctx context.Context, input json.RawMessage, req *SendMessageRequest) (string, error) { var params struct { @@ -146,10 +262,12 @@ func (s *AgentService) execSuggestPolicy(ctx context.Context, input json.RawMess // Convert fromAmount to base units using balance decimals var balances []Balance + var coins []CoinInfo if req.Context != nil { balances = req.Context.Balances + coins = req.Context.Coins } - convertAmountToBaseUnits(params.Configuration, balances) + convertAmountToBaseUnits(params.Configuration, balances, coins) policySuggest, err := s.verifier.GetPolicySuggest(ctx, params.PluginID, params.Configuration) if err != nil { @@ -214,3 +332,192 @@ func jsonError(msg string) string { result, _ := json.Marshal(map[string]string{"error": msg}) return string(result) } + +// execScheduleTask creates a new scheduled task. +func (s *AgentService) execScheduleTask(ctx context.Context, convID uuid.UUID, input json.RawMessage, req *SendMessageRequest) (string, error) { + var params struct { + Intent string `json:"intent"` + Context map[string]any `json:"context"` + NextRunAt string `json:"next_run_at"` + IntervalSeconds *int32 `json:"interval_seconds"` + MaxRuns *int32 `json:"max_runs"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + + if params.Intent == "" { + return jsonError("intent is required"), nil + } + if params.NextRunAt == "" { + return jsonError("next_run_at is required"), nil + } + + nextRunAt, err := time.Parse(time.RFC3339, params.NextRunAt) + if err != nil { + return jsonError("invalid next_run_at format, use ISO 8601 (e.g., '2025-03-01T09:00:00Z')"), nil + } + if nextRunAt.Before(time.Now()) { + return jsonError("next_run_at must be in the future"), nil + } + + // Validate interval for recurring tasks + if params.IntervalSeconds != nil { + if *params.IntervalSeconds < int32(s.schedulerCfg.MinIntervalMinutes*60) { + return jsonError(fmt.Sprintf("minimum interval is %d minutes (%d seconds)", s.schedulerCfg.MinIntervalMinutes, s.schedulerCfg.MinIntervalMinutes*60)), nil + } + } + + // Check active task limit + count, err := s.taskRepo.CountActive(ctx, req.PublicKey) + if err != nil { + return jsonError("failed to check task limits: " + err.Error()), nil + } + if count >= int64(s.schedulerCfg.MaxActivePerUser) { + return jsonError(fmt.Sprintf("maximum of %d active tasks reached", s.schedulerCfg.MaxActivePerUser)), nil + } + + contextJSON, err := json.Marshal(params.Context) + if err != nil { + return jsonError("failed to marshal context: " + err.Error()), nil + } + + task := &types.ScheduledTask{ + PublicKey: req.PublicKey, + Intent: params.Intent, + Context: contextJSON, + NextRunAt: nextRunAt, + IntervalSeconds: params.IntervalSeconds, + MaxRuns: params.MaxRuns, + } + if convID != uuid.Nil { + task.ConversationID = &convID + } + + created, err := s.taskRepo.Create(ctx, task) + if err != nil { + return jsonError("failed to create task: " + err.Error()), nil + } + + result, _ := json.Marshal(map[string]any{ + "task_id": created.ID.String(), + "status": created.Status, + "recurring": created.IsRecurring(), + "next_run_at": created.NextRunAt, + }) + return string(result), nil +} + +// execListScheduledTasks lists the user's active scheduled tasks. +func (s *AgentService) execListScheduledTasks(ctx context.Context, req *SendMessageRequest) (string, error) { + tasks, err := s.taskRepo.ListActive(ctx, req.PublicKey) + if err != nil { + return jsonError("failed to list tasks: " + err.Error()), nil + } + + type taskSummary struct { + ID string `json:"id"` + Intent string `json:"intent"` + Context json.RawMessage `json:"context"` + Recurring bool `json:"recurring"` + NextRunAt time.Time `json:"next_run_at"` + IntervalSeconds *int32 `json:"interval_seconds,omitempty"` + MaxRuns *int32 `json:"max_runs,omitempty"` + RunCount int32 `json:"run_count"` + } + + summaries := make([]taskSummary, len(tasks)) + for i, t := range tasks { + summaries[i] = taskSummary{ + ID: t.ID.String(), + Intent: t.Intent, + Context: t.Context, + Recurring: t.IsRecurring(), + NextRunAt: t.NextRunAt, + IntervalSeconds: t.IntervalSeconds, + MaxRuns: t.MaxRuns, + RunCount: t.RunCount, + } + } + + result, _ := json.Marshal(map[string]any{ + "tasks": summaries, + "count": len(summaries), + }) + return string(result), nil +} + +// execUpdateScheduledTask updates an existing scheduled task. +func (s *AgentService) execUpdateScheduledTask(ctx context.Context, input json.RawMessage, req *SendMessageRequest) (string, error) { + var params struct { + TaskID string `json:"task_id"` + Intent *string `json:"intent"` + Context json.RawMessage `json:"context"` + NextRunAt *string `json:"next_run_at"` + IntervalSeconds *int32 `json:"interval_seconds"` + MaxRuns *int32 `json:"max_runs"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + + taskID, err := uuid.Parse(params.TaskID) + if err != nil { + return jsonError("invalid task_id"), nil + } + + updateParams := &postgres.UpdateScheduledTaskParams{ + Intent: params.Intent, + IntervalSeconds: params.IntervalSeconds, + MaxRuns: params.MaxRuns, + } + + if params.Context != nil { + updateParams.Context = []byte(params.Context) + } + + if params.NextRunAt != nil { + nextRunAt, err := time.Parse(time.RFC3339, *params.NextRunAt) + if err != nil { + return jsonError("invalid next_run_at format"), nil + } + updateParams.NextRunAt = &nextRunAt + } + + updated, err := s.taskRepo.Update(ctx, taskID, req.PublicKey, updateParams) + if err != nil { + return jsonError("failed to update task: " + err.Error()), nil + } + + result, _ := json.Marshal(map[string]any{ + "task_id": updated.ID.String(), + "status": "updated", + "next_run_at": updated.NextRunAt, + }) + return string(result), nil +} + +// execCancelScheduledTask cancels a scheduled task. +func (s *AgentService) execCancelScheduledTask(ctx context.Context, input json.RawMessage, req *SendMessageRequest) (string, error) { + var params struct { + TaskID string `json:"task_id"` + } + if err := json.Unmarshal(input, ¶ms); err != nil { + return jsonError("invalid input: " + err.Error()), nil + } + + taskID, err := uuid.Parse(params.TaskID) + if err != nil { + return jsonError("invalid task_id"), nil + } + + if err := s.taskRepo.Cancel(ctx, taskID, req.PublicKey); err != nil { + return jsonError("failed to cancel task: " + err.Error()), nil + } + + result, _ := json.Marshal(map[string]any{ + "task_id": taskID.String(), + "status": "cancelled", + }) + return string(result), nil +} diff --git a/internal/service/agent/headless.go b/internal/service/agent/headless.go new file mode 100644 index 0000000..d5ddecd --- /dev/null +++ b/internal/service/agent/headless.go @@ -0,0 +1,147 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/ai" +) + +const headlessTimeout = 120 * time.Second + +// ToolCallLog records a single tool call for auditing. +type ToolCallLog struct { + Name string `json:"name"` + Input string `json:"input"` + Result string `json:"result"` +} + +// HeadlessResult is the output of a headless (non-interactive) Claude execution. +type HeadlessResult struct { + Response string `json:"response"` + ToolLogs []ToolCallLog `json:"tool_logs,omitempty"` +} + +// RunHeadless executes a Claude conversation without user interaction. +// Used by the scheduler to process scheduled tasks. +func (s *AgentService) RunHeadless(ctx context.Context, systemPrompt, userMessage, model string, tools []ai.Tool, publicKey string) (*HeadlessResult, error) { + ctx, cancel := context.WithTimeout(ctx, headlessTimeout) + defer cancel() + + messages := []any{ + ai.Message{ + Role: "user", + Content: userMessage, + }, + } + + // Build a minimal SendMessageRequest for tool executors that need it + req := &SendMessageRequest{ + PublicKey: publicKey, + } + + var toolResp *ToolResponse + var textContent string + var toolLogs []ToolCallLog + + for i := 0; i < maxLoopIterations; i++ { + aiReq := &ai.Request{ + Model: model, + System: systemPrompt, + Messages: messages, + Tools: tools, + ToolChoice: &ai.ToolChoice{ + Type: "auto", + }, + } + + resp, err := s.ai.SendMessage(ctx, aiReq) + if err != nil { + return nil, fmt.Errorf("call ai (iteration %d): %w", i, err) + } + + assistantText := resp.Content + toolCalls := resp.ToolCalls + + if resp.FinishReason == "stop" || len(toolCalls) == 0 { + textContent = assistantText + break + } + + var toolMessages []ai.ToolMessage + for _, tc := range toolCalls { + if tc.Function.Name == "respond_to_user" { + var tr ToolResponse + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &tr); err == nil { + toolResp = &tr + } + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: `{"ok": true}`, + }) + toolLogs = append(toolLogs, ToolCallLog{ + Name: tc.Function.Name, + Input: tc.Function.Arguments, + Result: `{"ok": true}`, + }) + continue + } + + result, err := s.executeTool(ctx, uuid.Nil, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) + if err != nil { + result = jsonError(err.Error()) + } + + s.logger.WithFields(logrus.Fields{ + "tool": tc.Function.Name, + "tool_id": tc.ID, + "headless": true, + }).Debug("tool executed") + + toolLogs = append(toolLogs, ToolCallLog{ + Name: tc.Function.Name, + Input: tc.Function.Arguments, + Result: result, + }) + + toolMessages = append(toolMessages, ai.ToolMessage{ + Role: "tool", + ToolCallID: tc.ID, + Content: result, + }) + } + + messages = append(messages, ai.AssistantMessage{ + Role: "assistant", + Content: assistantText, + ToolCalls: toolCalls, + }) + for _, tm := range toolMessages { + messages = append(messages, tm) + } + + if toolResp != nil { + break + } + } + + response := textContent + if toolResp != nil { + response = toolResp.Response + } + + if response == "" { + return nil, fmt.Errorf("no response from Claude") + } + + return &HeadlessResult{ + Response: response, + ToolLogs: toolLogs, + }, nil +} diff --git a/internal/service/agent/memory.go b/internal/service/agent/memory.go index df408b2..f2a1ef0 100644 --- a/internal/service/agent/memory.go +++ b/internal/service/agent/memory.go @@ -2,10 +2,11 @@ package agent import ( "context" + "encoding/json" "github.com/sirupsen/logrus" - "github.com/vultisig/agent-backend/internal/ai/anthropic" + "github.com/vultisig/agent-backend/internal/ai" ) const maxMemoryBytes = 4000 @@ -60,11 +61,23 @@ func (s *AgentService) persistMemoryUpdate(ctx context.Context, publicKey string } } +func (s *AgentService) extractMemoryUpdate(resp *ai.Response) *updateMemoryInput { + for _, tc := range resp.ToolCalls { + if tc.Function.Name == "update_memory" { + var mu updateMemoryInput + if err := json.Unmarshal(json.RawMessage(tc.Function.Arguments), &mu); err == nil && mu.Content != "" { + return &mu + } + } + } + return nil +} + // memoryTools returns the update_memory tool if memRepo is configured, for appending to ability tool lists. -func (s *AgentService) memoryTools() []anthropic.Tool { +func (s *AgentService) memoryTools() []ai.Tool { if s.memRepo == nil { return nil } - return []anthropic.Tool{UpdateMemoryTool} + return []ai.Tool{UpdateMemoryTool} } diff --git a/internal/service/agent/policy.go b/internal/service/agent/policy.go index 60b4294..d4578cc 100644 --- a/internal/service/agent/policy.go +++ b/internal/service/agent/policy.go @@ -6,39 +6,47 @@ import ( "strings" ) -// convertAmountToBaseUnits converts fromAmount in the configuration from human-readable -// format (e.g. "3.5") to base units (e.g. "3500000" for 6-decimal tokens like USDC). -// It extracts the token address from the nested from.token field and matches it against -// the user's balances to find the correct decimals. -func convertAmountToBaseUnits(config map[string]any, balances []Balance) { +func convertAmountToBaseUnits(config map[string]any, balances []Balance, coins []CoinInfo) (matched bool) { amountVal, ok := config["fromAmount"] if !ok { - return + return true } amountStr := fmt.Sprintf("%v", amountVal) - // Find decimals from balances by matching from.token - decimals := 18 // default to 18 (ETH-like) - if from, ok := config["from"].(map[string]any); ok { - if token, ok := from["token"].(string); ok && token != "" { + decimals := 18 + matched = false + + from, ok := config["from"].(map[string]any) + if ok { + token, _ := from["token"].(string) + if token != "" { for _, b := range balances { - if strings.EqualFold(b.Asset, token) { + if strings.EqualFold(b.Asset, token) || strings.EqualFold(b.Symbol, token) { decimals = b.Decimals + matched = true break } } + + if !matched { + for _, c := range coins { + if strings.EqualFold(c.ContractAddress, token) || strings.EqualFold(c.Ticker, token) { + decimals = c.Decimals + matched = true + break + } + } + } } } baseUnits := toBaseUnits(amountStr, decimals) config["fromAmount"] = baseUnits + return matched } -// toBaseUnits converts a human-readable decimal string to base units. -// e.g. toBaseUnits("3.5", 6) returns "3500000" func toBaseUnits(amount string, decimals int) string { - // Split on decimal point parts := strings.SplitN(amount, ".", 2) whole := parts[0] frac := "" @@ -46,18 +54,16 @@ func toBaseUnits(amount string, decimals int) string { frac = parts[1] } - // Pad or truncate fractional part to exactly `decimals` digits if len(frac) > decimals { frac = frac[:decimals] } else { frac = frac + strings.Repeat("0", decimals-len(frac)) } - // Combine and parse as big.Int to strip leading zeros raw := whole + frac result, ok := new(big.Int).SetString(raw, 10) if !ok { - return amount // return original if parsing fails + return amount } return result.String() } diff --git a/internal/service/agent/policy_test.go b/internal/service/agent/policy_test.go new file mode 100644 index 0000000..0d9bf1f --- /dev/null +++ b/internal/service/agent/policy_test.go @@ -0,0 +1,136 @@ +package agent + +import ( + "testing" +) + +func TestToBaseUnits(t *testing.T) { + tests := []struct { + name string + amount string + decimals int + want string + }{ + {"whole number 6 decimals", "10", 6, "10000000"}, + {"decimal 6 decimals", "3.5", 6, "3500000"}, + {"small decimal 6 decimals", "0.000001", 6, "1"}, + {"whole number 18 decimals", "1", 18, "1000000000000000000"}, + {"decimal 18 decimals", "0.5", 18, "500000000000000000"}, + {"zero", "0", 6, "0"}, + {"whole number 8 decimals", "1", 8, "100000000"}, + {"decimal 8 decimals", "0.00000001", 8, "1"}, + {"truncates excess precision", "1.1234567890", 6, "1123456"}, + {"large amount", "1000000", 6, "1000000000000"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := toBaseUnits(tt.amount, tt.decimals) + if got != tt.want { + t.Errorf("toBaseUnits(%q, %d) = %q, want %q", tt.amount, tt.decimals, got, tt.want) + } + }) + } +} + +func TestConvertAmountToBaseUnits(t *testing.T) { + t.Run("matches by asset", func(t *testing.T) { + config := map[string]any{ + "fromAmount": "10", + "from": map[string]any{ + "token": "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48", + }, + } + balances := []Balance{ + {Asset: "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48", Symbol: "USDC", Decimals: 6}, + } + matched := convertAmountToBaseUnits(config, balances, nil) + if !matched { + t.Error("expected match") + } + if config["fromAmount"] != "10000000" { + t.Errorf("got %v, want 10000000", config["fromAmount"]) + } + }) + + t.Run("matches by symbol", func(t *testing.T) { + config := map[string]any{ + "fromAmount": "5", + "from": map[string]any{ + "token": "USDC", + }, + } + balances := []Balance{ + {Asset: "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48", Symbol: "USDC", Decimals: 6}, + } + matched := convertAmountToBaseUnits(config, balances, nil) + if !matched { + t.Error("expected match") + } + if config["fromAmount"] != "5000000" { + t.Errorf("got %v, want 5000000", config["fromAmount"]) + } + }) + + t.Run("matches by coin contract address", func(t *testing.T) { + config := map[string]any{ + "fromAmount": "1", + "from": map[string]any{ + "token": "0x2260fac5e5542a773aa44fbcfedf7c193bc2c599", + }, + } + coins := []CoinInfo{ + {ContractAddress: "0x2260fac5e5542a773aa44fbcfedf7c193bc2c599", Ticker: "WBTC", Decimals: 8}, + } + matched := convertAmountToBaseUnits(config, nil, coins) + if !matched { + t.Error("expected match") + } + if config["fromAmount"] != "100000000" { + t.Errorf("got %v, want 100000000", config["fromAmount"]) + } + }) + + t.Run("matches by coin ticker", func(t *testing.T) { + config := map[string]any{ + "fromAmount": "2", + "from": map[string]any{ + "token": "WBTC", + }, + } + coins := []CoinInfo{ + {Ticker: "WBTC", Decimals: 8}, + } + matched := convertAmountToBaseUnits(config, nil, coins) + if !matched { + t.Error("expected match") + } + if config["fromAmount"] != "200000000" { + t.Errorf("got %v, want 200000000", config["fromAmount"]) + } + }) + + t.Run("no match defaults to 18", func(t *testing.T) { + config := map[string]any{ + "fromAmount": "1", + "from": map[string]any{ + "token": "0xunknown", + }, + } + matched := convertAmountToBaseUnits(config, nil, nil) + if matched { + t.Error("expected no match") + } + if config["fromAmount"] != "1000000000000000000" { + t.Errorf("got %v, want 1000000000000000000", config["fromAmount"]) + } + }) + + t.Run("no fromAmount is noop", func(t *testing.T) { + config := map[string]any{"other": "value"} + matched := convertAmountToBaseUnits(config, nil, nil) + if !matched { + t.Error("expected matched=true when no fromAmount") + } + }) +} diff --git a/internal/service/agent/prompt.go b/internal/service/agent/prompt.go index 55d58bf..2e3b48a 100644 --- a/internal/service/agent/prompt.go +++ b/internal/service/agent/prompt.go @@ -1,92 +1,400 @@ package agent import ( + "fmt" "strings" + "time" - "github.com/vultisig/agent-backend/internal/ai/anthropic" + "github.com/vultisig/agent-backend/internal/ai" + "github.com/vultisig/agent-backend/internal/types" ) -// SystemPrompt is the unified system prompt for the Vultisig AI assistant. -const SystemPrompt = `You are the Vultisig AI assistant, integrated into the Vultisig mobile wallet app. You help users manage their crypto assets through natural conversation. +// ActionsTable is the shared actions reference used by both the system prompt and starters. +// Adding a new action here automatically makes it available to conversation starters. +const ActionsTable = `| Type | Params | auto_execute | +|---|---|---| +| get_market_price | asset, fiat (default "usd") | true | +| get_balances | chain (optional) | true | +| get_portfolio | fiat (default "usd") | true | +| add_chain | chain | true | +| add_coin | chain, ticker, contract_address (optional) | true | +| search_token | query, chain (optional) | true | +| remove_coin | chain, ticker | true | +| remove_chain | chain | true | +| build_send_tx | chain, symbol, address, amount, memo (optional) | true | +| build_custom_tx | tx_type, chain, symbol, amount, memo, contract_address, function_name, params, value, execute_msg, funds | true | +| build_tx | from_chain, from_symbol, from_contract, from_decimals, to_chain, to_symbol, to_contract, to_decimals, amount | true | +| sign_tx | (no params needed — app fills from stored tx) | true | +| list_vaults | | true | +| plugin_install | plugin_id | true | +| create_policy | plugin_id, configuration | true | +| delete_policy | policy_id | true | +| address_book_add | name, chain, address | true | +| address_book_remove | name, chain | true | +| read_evm_contract | chain, contract_address, function_name, params, output_types | true | +| scan_tx | chain, from, to, value, data | true | +| thorchain_query | query_type, asset (optional) | true |` -## About Vultisig +// SystemPrompt is the base system prompt for the Vultisig AI assistant. +var SystemPrompt = `You are the Vultisig AI assistant inside the Vultisig mobile wallet app. -Vultisig is a **self-custodial, seedless cryptocurrency wallet** that uses **Threshold Signature Scheme (TSS)** technology. Unlike traditional wallets: +Vultisig is a self-custodial, seedless crypto wallet using Threshold Signature Scheme (TSS) — no seed phrases, multi-device signing (e.g. 2-of-3), one vault across many chains. -- **No seed phrases**: Instead of a 12/24 word recovery phrase that can be stolen or lost, Vultisig splits your private key across multiple devices using cryptographic secret sharing -- **Multi-device security**: Transactions require signatures from multiple devices (e.g., 2-of-3), so no single compromised device can steal funds -- **Vault-based architecture**: Each "vault" is a collection of key shares across your devices that together control your crypto assets -- **Cross-chain support**: One vault can hold assets across many blockchains +Supported chains: Ethereum, Arbitrum, Avalanche, BNB Chain, Base, Blast, Optimism, Polygon, Bitcoin, Litecoin, Dogecoin, Bitcoin Cash, Dash, Zcash, Solana, XRP, Cosmos, THORChain, MayaChain, Tron. -### Supported Blockchains -**EVM Chains**: Ethereum, Arbitrum, Avalanche, BNB Chain, Base, Blast, Optimism, Polygon -**UTXO Chains**: Bitcoin, Litecoin, Dogecoin, Bitcoin Cash, Dash, Zcash -**Other Chains**: Solana, XRP, Cosmos (Gaia), THORChain, MayaChain, Tron +## Actions -### Key Features -- **Vault Sharing**: Share vault access with family or team members with configurable signing thresholds -- **Plugin System**: Extend wallet functionality with verified plugins for automation -- **THORChain/MayaChain Integration**: Native cross-chain swaps without bridges -- **Hardware-level Security**: Key shares can be stored on separate physical devices +Include actions when the user's request requires the app to do something. You can return MULTIPLE actions in a single response — the app executes them all in parallel. Batch related actions together (e.g. get_market_price + get_balances, or multiple add_coin calls). -## Your Role +` + ActionsTable + ` -You are the conversational interface for Vultisig users. You can: +auto_execute=true actions run immediately in parallel. auto_execute=false actions render as tappable cards for user confirmation. -1. **Answer questions** about Vultisig, crypto, DeFi, and blockchain technology -2. **Detect user intent** when they want to perform actions (DCA, swaps, sends) -3. **Suggest actions** by offering plugin-based automation options via the create_suggestion tool -4. **Guide users** through setting up recurring transactions -5. **Confirm action results** when the user reports success or failure of an action +## Swap Transaction Building -## Workflow +For swaps, ALWAYS use build_tx. This builds the actual unsigned transaction so the user can review exact output amounts and sign directly. -You have tools to help users set up DCA automations. Follow this general approach: +CRITICAL RULES for build_tx params: +- from_contract and from_decimals: copy EXACTLY from "Coins in Vault" context. The "contract" field is from_contract, the "decimals" field is from_decimals. For native tokens (contract = "native"), set from_contract to empty string "". +- to_contract and to_decimals: copy EXACTLY from "Coins in Vault" if the destination token is there. If not in vault, omit both fields (server resolves automatically). +- NEVER fabricate or guess contract addresses. Only use values copied verbatim from context. +- If the source token is NOT in "Coins in Vault", the user doesn't have it — tell them. +- If the destination token is NOT in "Coins in Vault", omit to_contract and to_decimals. +- amount is in HUMAN-READABLE units (e.g. "10" for 10 USDC, NOT base units). The server handles conversion. +- When the user specifies a fiat/dollar amount (e.g., "$10 of ETH", "100 USD worth of BTC"), do NOT put the fiat number in the amount field. Instead: + 1. First call get_market_price for the source token (auto_execute=true) to get the current price. + 2. After receiving the price result, calculate: token_amount = fiat_amount / price. + 3. Compare the calculated token_amount against the user's balance in Balances context. If insufficient, tell the user (e.g. "You only have 0.899 ETH (~$1,786), which isn't enough for a $10,000 swap.") and do NOT call build_tx. + 4. If sufficient, call build_tx with the calculated token amount in the amount field. -1. Understand what the user wants. Ask clarifying questions if needed. -2. When the user wants an automation, use create_suggestion to show them options. -3. When they confirm a suggestion, check if the required plugin is installed using check_plugin_installed. -4. If not installed, tell them they need to install it first and stop. -5. Check billing status using check_billing_status. If the free trial has expired and the billing app is not installed, tell the user they need to install the billing app first and stop. -6. If all prerequisites are met, get the recipe schema using get_recipe_schema, build a configuration, and call suggest_policy. -7. Present the policy details for the user to review and confirm. +## Transaction Confirmation (build_tx result handling) -You MAY skip steps if the user provides complete information upfront. -You MUST always present policy details before creating a policy. -You MUST use correct token addresses and chain names from the schema. +When you receive a successful build_tx action result, you MUST respond with EXACTLY this template and NOTHING else. Copy it verbatim, only replacing the bracketed placeholders. -## Policy Building Instructions +TEMPLATE: "Swap [amount] [FROM] for ~[expected_output] [TO] via [provider][CROSS_CHAIN]. [APPROVAL_LINE]Ready to swap?" -When building a configuration for suggest_policy: +Replace these placeholders: +- [amount]: the amount from the build request +- [FROM]: source token symbol +- [expected_output]: expected_output from the build result +- [TO]: destination token symbol +- [provider]: provider name from the build result +- [CROSS_CHAIN]: if from_chain != to_chain, write " (cross-chain)". Otherwise omit. +- [APPROVAL_LINE]: if needs_approval is true, write "Requires token approval. " (with trailing space). Otherwise omit entirely. -1. Extract relevant parameters from the conversation (amounts, tokens, chains, frequency, etc.) -2. Map them to the plugin's schema fields -3. Use the user's wallet addresses for source addresses -4. For tokens, use the correct token contract addresses (e.g., "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48" for USDC). For native assets (ETH, BTC, etc.), leave the token field as an empty string "" -5. Ensure amounts are in human-readable format (e.g., "10" for 10 USDC, "0.5" for 0.5 ETH) -6. Use the addresses from the user's context for the "from" address -7. Never set a swap amount below ~$5 equivalent, as DEX providers will reject swaps that are too small -8. If the user's balance for the source asset is below ~$5 equivalent, the swap will likely fail — warn the user -9. If no balance information is available, use the user's requested amount but note they should ensure sufficient funds -10. If frequency was discussed, include it -11. If any required field is unclear, make a reasonable default based on the conversation +The response MUST end with "Ready to swap?" — this is the confirmation prompt. -## Action Results +ABSOLUTE RULES: +- Do NOT add any words before or after the template. No "Perfect", "Great", "Your swap is ready", no preamble, no narration. +- Do NOT use exclamation marks, bullet points, contract addresses, gas limits, or technical details. +- Do NOT explain what will happen. Just state the swap and ask "Ready to swap?" -When the user reports an action result (e.g., policy created, plugin installed): -- **For successful actions**: Celebrate briefly and summarize what was accomplished -- **For failed actions**: Be empathetic, explain what went wrong, and offer next steps -- If a plugin was just installed successfully, you can proceed to help them set up their automation +When the user confirms (yes, confirm, do it, go, etc.) → return sign_tx action with auto_execute=true and empty params. +When the user wants changes → adjust and call build_tx again. +When the user cancels → acknowledge briefly, no sign_tx. + +## Send Transaction Building + +When the user wants to send tokens and all required params are known (coin + address + amount), use build_send_tx. This builds the transaction inline so the user can confirm and sign directly in chat. + +build_send_tx params: +- chain: the blockchain name (e.g. "Ethereum", "Bitcoin") +- symbol: the token ticker (e.g. "ETH", "USDC") +- address: the recipient address +- amount: human-readable units (e.g. "0.1" for 0.1 ETH) +- memo: optional memo/tag. All native gas token sends support memos (EVM chains encode it in the tx data field, UTXO chains use OP_RETURN, Cosmos/THORChain use the memo field, etc.). Only omit for ERC20/SPL/other non-native token transfers. + +If any required param (chain/symbol, address, amount) is missing, ask the user for it. Do NOT call build_send_tx until all params are known. + +CRITICAL: ALWAYS check the user's Balances context before calling build_send_tx. If insufficient balance, tell the user and do NOT build. + +## Send Transaction Confirmation + +When you receive a successful build_send_tx action result, respond with EXACTLY this template: + +TEMPLATE: "Send [amount] [SYMBOL] to [truncated_address] on [chain]. Ready to send?" + +Replace: +- [amount]: the amount from the build request +- [SYMBOL]: token symbol +- [truncated_address]: first 6 and last 4 characters of the address (e.g. "0x1234...5678") +- [chain]: the chain name + +ABSOLUTE RULES: +- Do NOT add any words before or after the template. +- Do NOT use exclamation marks, bullet points, or technical details. + +When the user confirms → return sign_tx action with auto_execute=true and empty params. +When the user cancels → acknowledge briefly, no sign_tx. + +## Custom Transaction Building + +Use build_custom_tx for advanced on-chain operations: THORChain/Maya deposits, EVM smart contract calls, and CosmWasm execution. + +build_custom_tx params vary by tx_type: + +### Deposit (tx_type: "deposit") +For THORChain/Maya MsgDeposit operations (bond, unbond, leave, etc.). +- chain: "THORChain" or "MayaChain" +- symbol: the token ticker (e.g. "RUNE", "CACAO") +- amount: human-readable units (e.g. "1000" for 1000 RUNE) +- memo: the THORChain memo (e.g. "BOND:thor1abc...", "UNBOND:thor1abc...:100000000", "LEAVE:thor1abc...") + +CRITICAL: For UNBOND, set amount to "0" — the unbond amount goes ONLY in the memo in base units (1 RUNE = 100000000). Example: "UNBOND:thor1nodeaddr:100000000" unbonds 1 RUNE. + +### EVM Contract Call (tx_type: "evm_contract") +For calling smart contracts on EVM chains (Ethereum, Arbitrum, etc.). +- chain: the EVM chain name (e.g. "Ethereum", "Arbitrum") +- contract_address: the contract address (e.g. "0xa0b86991...") +- function_name: the Solidity function name (e.g. "approve", "transfer") +- params: array of {type, value} objects. Supported types: "address", "uint256", "string", "bytes", "bool" +- value: optional ETH/native token value to send with the call, in human-readable units (default "0") + +Example approve call: tx_type="evm_contract", chain="Ethereum", contract_address="0xUSDC", function_name="approve", params=[{type:"address",value:"0xSpender"},{type:"uint256",value:"1000000"}] + +### CosmWasm Execute (tx_type: "wasm_execute") +For executing CosmWasm smart contracts (e.g. on THORChain). +- chain: the chain name (e.g. "THORChain") +- contract_address: the WASM contract address +- execute_msg: JSON string of the execute message (e.g. '{"stake":{}}') +- funds: optional array of {denom, amount} objects for coins to send with execution. Denoms are lowercase (e.g. "rune"), amounts in base units. + +### Custom Transaction Confirmation + +When you receive a successful build_custom_tx action result, respond with a brief summary of the transaction and ask "Ready to execute?" + +TEMPLATE for deposit: "[ACTION] [amount] [SYMBOL] on [chain]. Ready to execute?" +TEMPLATE for evm_contract: "Call [function_name] on [truncated_contract] on [chain]. Ready to execute?" +TEMPLATE for wasm_execute: "Execute contract [truncated_contract] on [chain]. Ready to execute?" + +When the user confirms → return sign_tx action with auto_execute=true and empty params. +When the user cancels → acknowledge briefly, no sign_tx. + +## THORChain Position Queries + +Use thorchain_query to look up THORChain/Midgard data when the user asks about their THORChain positions, LP, bonds, savers, stakes, or network info. + +thorchain_query params: +- query_type (required): one of "lp_positions", "saver_positions", "bond_positions", "node_details", "pool_info", "rune_pool", "network_info", "stake_positions", "trade_accounts" +- asset (optional): for pool_info, specify the pool asset (e.g. "BTC.BTC", "ETH.ETH") + +The user's THORChain address is auto-resolved. Query mapping: +- "What are my LP positions?" → lp_positions +- "Show my savers" → saver_positions +- "What nodes am I bonded to?" → bond_positions +- "Show my stakes" → stake_positions +- "What's the BTC pool?" → pool_info, asset="BTC.BTC" +- "My RUNE pool position" → rune_pool +- "THORChain network stats" → network_info +- "My trade accounts" → trade_accounts + +## Reading EVM Contract State + +Use read_evm_contract to call read-only (view/pure) functions on EVM smart contracts. This does NOT create a transaction — it's a free eth_call. + +read_evm_contract params: +- chain: the EVM chain name (e.g. "Ethereum", "Arbitrum") +- contract_address: the contract address (e.g. "0xa0b86991...") +- function_name: the Solidity function signature (e.g. "allowance(address,address)", "balanceOf(address)") +- params: array of {type, value} objects matching the function inputs. Supported types: "address", "uint256", "string", "bytes", "bool" +- output_types: array of output type strings (e.g. ["uint256"]). Supported: "address", "uint256", "string", "bytes", "bool" + +Common uses: +- Check ERC20 allowance: function_name="allowance(address,address)", params=[{type:"address",value:"OWNER"},{type:"address",value:"SPENDER"}], output_types=["uint256"] +- Check ERC20 balance: function_name="balanceOf(address)", params=[{type:"address",value:"HOLDER"}], output_types=["uint256"] + +The result is returned as action data with decoded output values. Use the user's address from Addresses context as the owner/holder. + +## Token Discovery + +When the user mentions a token NOT present in their vault (check "Coins in Vault" context), use search_token to discover it before adding or swapping: +1. Call search_token with the token name/ticker (and optional chain filter). +2. Always present the results as a list showing: chain, symbol, logo, price, contract_address. Let the user review the candidates. +3. If only one result → still show it for confirmation, do NOT auto-add. +4. After the user picks one, use add_coin with the full details (chain, ticker, contract_address, decimals, logo, price_provider_id) from the search result. +5. Then proceed with the original request (swap, etc.). + +Never guess or fabricate contract addresses — ONLY use values from "Coins in Vault" context or search_token results. +Never auto-add tokens from search results — always let the user review and confirm first. ## Guidelines -1. **Be concise**: Users are on mobile devices. Keep responses brief but helpful. -2. **Be specific**: When suggesting actions, include concrete details based on user's balances. -3. **Be balance-aware**: Always check the user's balances before suggesting swap or send amounts. If a balance is too low (under ~$5 equivalent) for the source asset, warn the user that the swap may fail due to provider minimums. If no balances are provided, ask the user to confirm they have sufficient funds and suggest at least $5 equivalent as a starting point. -4. **Be security-conscious**: Remind users about best practices when relevant. -5. **Ask clarifying questions** if the user's intent is unclear. -6. **Stay in scope**: For actions outside your capabilities, explain what Vultisig can do instead. -7. **Don't fabricate**: Only state facts about Vultisig that are provided in this prompt. If you're unsure about something Vultisig-specific (tokenomics, roadmap, partnerships, etc.), say you don't have that information and suggest checking the official Vultisig website or community channels.` +- Be extremely concise — 1-2 short sentences max. Users are on mobile. No narration, no restating what the user said, no "I'll set up..." preamble. Just act. +- When you can fulfill the request, do it immediately with minimal commentary. Only explain if something is wrong (insufficient balance, ambiguity). +- Fix obvious typos silently (e.g. "USDDC" → USDC). Don't ask for confirmation on obvious corrections. +- ALWAYS check the user's Balances context before calling build_tx or build_send_tx. Use the balance for the specific chain (e.g. "USDT on Ethereum"), NOT the sum across all chains. If insufficient, tell the user and do NOT build. Warn if source balance is under ~$5 (DEX minimums for swaps). +- When a token exists on multiple chains, auto-select the chain where the user has sufficient balance. Only ask which chain if MULTIPLE chains have enough balance for the requested amount. +- Use auto_execute=true actions for price/balance/portfolio queries. +- Use suggestions for recurring automation (DCA, scheduled swaps) to guide plugin-based policy flow. +- Only ask clarifying questions when genuinely ambiguous — not for typos or when a reasonable default exists. +- Don't fabricate Vultisig-specific facts (tokenomics, roadmap, etc.) — suggest checking official channels. +- Always respond via the respond_to_user tool. + +## Scheduled Tasks + +You can schedule tasks to run later or on a recurring basis using ` + "`schedule_task`" + `. + +When to use scheduling: +- "Check ETH price every Monday" → recurring, next_run_at = next Monday 9am UTC, interval_seconds = 604800 +- "Remind me to check my portfolio tomorrow at 9am" → one-time, next_run_at = tomorrow 9am UTC, no interval +- "Alert me every 6 hours if BTC drops below $50k" → recurring, next_run_at = 6 hours from now, interval_seconds = 21600 + +Common intervals: 3600 = 1 hour, 21600 = 6 hours, 86400 = 1 day, 604800 = 1 week. + +All times are UTC. The current UTC time is included in your context — use it to compute next_run_at accurately. + +CRITICAL: When creating a task, freeze ALL parameters into the context field. The task runs without access to conversation history, so it must be fully self-contained. Include: asset names, chains, amounts, thresholds, addresses — everything needed. + +Scheduled tasks can observe (check prices, balances) and prepare actions (suggest policies, create suggestions) but CANNOT execute transactions directly. The user must approve any actions when they review the result. + +Use ` + "`list_scheduled_tasks`" + ` to show the user their tasks (returns IDs for reference). +Use ` + "`update_scheduled_task`" + ` or ` + "`cancel_scheduled_task`" + ` with a task ID from the list. + +## Send & Address Resolution + +You have the user's address book in context. A contact may have multiple entries for different chains (e.g. Ed on Ethereum and Ed on Bitcoin). When they mention a contact name: +- Look up ALL entries for that name, then pick the one matching the chain of the coin being sent. E.g. "send USDC to Ed" → USDC is on Ethereum → use Ed's Ethereum address. +- Found matching chain entry → use the resolved address with build_send_tx +- Contact exists but not on the required chain → tell the user you don't have their address for that chain +- Not found at all → tell the user the contact wasn't found, ask for the address + +For vault-to-vault sends ("send to my other vault"): +- Use list_vaults action (auto_execute=true) to look up vault addresses +- After receiving vault data, determine the correct chain address and call build_send_tx + +If any required param (coin, address, amount) is missing, ask the user for it — do NOT call build_send_tx until all params are known.` + +// RespondToUserTool is the tool definition for responding to users. +var RespondToUserTool = ai.Tool{ + Name: "respond_to_user", + Description: "Respond to the user with detected intent, optional suggestions for plugin-based automation, and optional actions for the app to execute.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "intent": map[string]any{ + "type": "string", + "enum": []string{"action_request", "general_question", "unclear"}, + "description": "The detected user intent: 'action_request' for DCA/swap/send requests, 'general_question' for informational queries, 'unclear' when more context is needed.", + }, + "conversation_title": map[string]any{ + "type": "string", + "description": "A short (3-6 word) title summarising what this conversation is about. Generate on every response. Update if the conversation topic changes.", + }, + "response": map[string]any{ + "type": "string", + "description": "The response text to show the user.", + }, + "suggestions": map[string]any{ + "type": "array", + "description": "Optional plugin-based action suggestions. Only include for recurring automation requests (DCA, scheduled swaps).", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "plugin_id": map[string]any{ + "type": "string", + "description": "The plugin ID that can handle this action.", + }, + "title": map[string]any{ + "type": "string", + "description": "A short, descriptive title for the suggestion (e.g., 'Weekly DCA into ETH').", + }, + "description": map[string]any{ + "type": "string", + "description": "A brief description of what this suggestion will do.", + }, + }, + "required": []string{"plugin_id", "title", "description"}, + }, + }, + "actions": map[string]any{ + "type": "array", + "description": "Optional actions for the app to execute. Use for immediate operations like fetching prices, adding chains, initiating sends/swaps, etc.", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "type": map[string]any{ + "type": "string", + "enum": []string{ + "get_market_price", "get_balances", "get_portfolio", + "add_chain", "add_coin", "search_token", "remove_coin", "remove_chain", + "build_send_tx", "build_custom_tx", "build_tx", "sign_tx", "scan_tx", + "read_evm_contract", + "thorchain_query", + "plugin_install", "create_policy", "delete_policy", + "address_book_add", "address_book_remove", + "list_vaults", + }, + "description": "The action type to execute.", + }, + "title": map[string]any{ + "type": "string", + "description": "A short, human-readable title for the action card (e.g., 'Fetch ETH Price').", + }, + "description": map[string]any{ + "type": "string", + "description": "Optional description shown on the action card.", + }, + "params": map[string]any{ + "type": "object", + "description": "Parameters for the action. Keys depend on the action type.", + "additionalProperties": true, + }, + "auto_execute": map[string]any{ + "type": "boolean", + "description": "If true, the app executes this action immediately without user interaction. Use for read-only actions (get_market_price, get_balances, get_portfolio). Default false.", + }, + }, + "required": []string{"type", "title"}, + }, + }, + }, + "required": []string{"intent", "conversation_title", "response"}, + }, +} + +// ConfirmActionTool is the tool definition for confirming action results. +var ConfirmActionTool = ai.Tool{ + Name: "confirm_action", + Description: "Generate a confirmation message for a completed action (success or failure).", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "response": map[string]any{ + "type": "string", + "description": "A friendly, concise message confirming the action result. For success: celebrate and summarize what was set up. For failure: explain what went wrong and offer help.", + }, + "next_steps": map[string]any{ + "type": "array", + "description": "Optional list of suggested next actions the user might want to take.", + "items": map[string]any{ + "type": "string", + }, + }, + }, + "required": []string{"response"}, + }, +} + +// BuildPolicyTool is the tool definition for building policy configurations. +var BuildPolicyTool = ai.Tool{ + Name: "build_policy", + Description: "Build a policy configuration based on the user's conversation and the plugin's schema.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "configuration": map[string]any{ + "type": "object", + "description": "The configuration object matching the plugin's RecipeSchema. Include all required fields based on conversation context.", + "additionalProperties": true, + }, + "explanation": map[string]any{ + "type": "string", + "description": "A brief human-readable explanation of what was configured.", + }, + }, + "required": []string{"configuration", "explanation"}, + }, +} // PluginSkill represents a plugin's capabilities loaded from skills.md type PluginSkill struct { @@ -96,13 +404,7 @@ type PluginSkill struct { } // SummarizationPrompt is the prompt used to summarize older conversation messages. -const SummarizationPrompt = `Summarize the following conversation between a user and the Vultisig AI assistant. Focus on: -- Key user intents and requests -- Important decisions made -- Assets, amounts, chains, and addresses mentioned -- Actions taken or pending - -Be concise but preserve all actionable details. This summary will be used as context for future messages.` +const SummarizationPrompt = `Summarize this conversation concisely. Preserve: user intents, decisions, assets/amounts/chains/addresses mentioned, and actions taken or pending. This summary provides context for future messages.` // BuildSystemPromptWithSummary appends an earlier conversation summary to the base system prompt. func BuildSystemPromptWithSummary(basePrompt string, summary *string) string { @@ -112,12 +414,108 @@ func BuildSystemPromptWithSummary(basePrompt string, summary *string) string { return basePrompt + "\n\n## Earlier Conversation Summary\n\n" + *summary } +// walletContextOpts controls minor rendering differences between prompt contexts. +type walletContextOpts struct { + includeAssetInBalances bool + addressHeader string + includeNativeTag bool + includeAddressBookHint bool +} + +func writeWalletContext(sb *strings.Builder, msgCtx *MessageContext, opts walletContextOpts) { + if msgCtx == nil { + return + } + + sb.WriteString("\n\n## User's Wallet Context\n") + + if msgCtx.VaultName != "" { + sb.WriteString("\n### Vault Name\n") + sb.WriteString(msgCtx.VaultName) + sb.WriteString("\n") + } + + if len(msgCtx.Balances) > 0 { + sb.WriteString("\n### Balances\n") + for _, b := range msgCtx.Balances { + sb.WriteString("- ") + sb.WriteString(b.Symbol) + sb.WriteString(" on ") + sb.WriteString(b.Chain) + sb.WriteString(": ") + sb.WriteString(b.Amount) + if opts.includeAssetInBalances { + sb.WriteString(" (") + sb.WriteString(b.Asset) + sb.WriteString(")") + } + sb.WriteString("\n") + } + } + + if len(msgCtx.Addresses) > 0 { + sb.WriteString("\n### ") + sb.WriteString(opts.addressHeader) + sb.WriteString("\n") + for chain, addr := range msgCtx.Addresses { + sb.WriteString("- ") + sb.WriteString(chain) + sb.WriteString(": ") + sb.WriteString(addr) + sb.WriteString("\n") + } + } + + if len(msgCtx.Coins) > 0 { + sb.WriteString("\n### Coins in Vault\n") + for _, coin := range msgCtx.Coins { + sb.WriteString("- ") + sb.WriteString(coin.Ticker) + sb.WriteString(" on ") + sb.WriteString(coin.Chain) + sb.WriteString(" (contract: ") + if coin.ContractAddress != "" { + sb.WriteString(coin.ContractAddress) + } else { + sb.WriteString("native") + } + sb.WriteString(", decimals: ") + sb.WriteString(fmt.Sprintf("%d", coin.Decimals)) + sb.WriteString(")") + if opts.includeNativeTag && coin.IsNativeToken { + sb.WriteString(" [native]") + } + sb.WriteString("\n") + } + } + + if len(msgCtx.AddressBook) > 0 { + sb.WriteString("\n### Address Book\n") + if opts.includeAddressBookHint { + sb.WriteString("You have the user's address book. When they refer to contacts by name, resolve the address directly.\n") + } + for _, entry := range msgCtx.AddressBook { + sb.WriteString("- ") + sb.WriteString(entry.Title) + sb.WriteString(": ") + sb.WriteString(entry.Address) + sb.WriteString(" (") + sb.WriteString(entry.Chain) + sb.WriteString(")\n") + } + } +} + // BuildFullPrompt constructs the complete system prompt with context and plugin skills. -func BuildFullPrompt(balances []Balance, addresses map[string]string, plugins []PluginSkill) string { +func BuildFullPrompt(msgCtx *MessageContext, plugins []PluginSkill) string { var sb strings.Builder sb.WriteString(SystemPrompt) - // Add plugin skills + // Inject current UTC time so Claude can compute next_run_at for scheduled tasks + sb.WriteString("\n\n## Current Time\n\n") + sb.WriteString(time.Now().UTC().Format(time.RFC3339)) + sb.WriteString("\n") + if len(plugins) > 0 { sb.WriteString("\n\n## Available Plugins\n\n") sb.WriteString("The following plugins are available for automation. When users express intent matching a plugin's capabilities, suggest using that plugin.\n") @@ -132,40 +530,30 @@ func BuildFullPrompt(balances []Balance, addresses map[string]string, plugins [] } } - // Add user wallet context - if len(balances) > 0 || len(addresses) > 0 { - sb.WriteString("\n\n## User's Wallet Context\n") - - if len(balances) > 0 { - sb.WriteString("\n### Balances\n") - for _, b := range balances { - sb.WriteString("- ") - sb.WriteString(b.Symbol) - sb.WriteString(" on ") - sb.WriteString(b.Chain) - sb.WriteString(": ") - sb.WriteString(b.Amount) - sb.WriteString("\n") - } - } - - if len(addresses) > 0 { - sb.WriteString("\n### Addresses\n") - for chain, addr := range addresses { - sb.WriteString("- ") - sb.WriteString(chain) - sb.WriteString(": ") - sb.WriteString(addr) - sb.WriteString("\n") - } - } - } + writeWalletContext(&sb, msgCtx, walletContextOpts{ + addressHeader: "Addresses", + includeNativeTag: true, + includeAddressBookHint: true, + }) return sb.String() } +// BuildVaultInfoSection returns a system prompt section describing the active vault. +func BuildVaultInfoSection(v *types.VaultInfo) string { + if v == nil { + return "" + } + return "\n\n## Active Vault\n\n" + + "This conversation has a vault bound to it. The MCP server has been primed with these keys " + + "so tools like get_eth_balance and get_token_balance will derive addresses automatically.\n\n" + + "- ECDSA public key: `" + v.ECDSAPublicKey + "`\n" + + "- EdDSA public key: `" + v.EDDSAPublicKey + "`\n" + + "- Chaincode: `" + v.ChaincodeHex + "`\n" +} + // UpdateMemoryTool is the tool definition for updating the user's memory document. -var UpdateMemoryTool = anthropic.Tool{ +var UpdateMemoryTool = ai.Tool{ Name: "update_memory", Description: "Update your persistent memory about this user. Send the COMPLETE " + "updated memory document (markdown). This replaces the entire document. " + @@ -182,30 +570,14 @@ var UpdateMemoryTool = anthropic.Tool{ }, } -// MemoryManagementInstructions is appended to the system prompt when memory is available. +// MemoryManagementInstructions is appended to the system prompt for Ability 1 only. const MemoryManagementInstructions = ` -## Memory Management +## Memory -You have a persistent memory document about this user that survives across conversations. You can update it anytime using the ` + "`update_memory`" + ` tool. +You have a persistent memory document about this user (survives across conversations). Update it via ` + "`update_memory`" + ` when the user shares preferences, personal info, or strategies worth remembering. Don't update for greetings, transient chat, or data already available from the app. -### When to Update -- User shares a preference ("I prefer weekly DCA", "I like ETH") -- User reveals personal info ("My name is Alex") -- User describes their strategy ("I only DCA into top 10 coins") -- You learn something that would help in future conversations -- An action completes (policy created, plugin installed) - -### When NOT to Update -- Trivial greetings or transient chat -- Information already in your memory document -- Data available from the app (balances, addresses, prices) - -### How to Update -- Send the COMPLETE updated document — it replaces the entire memory -- Keep it under 4000 characters -- Organize naturally using markdown sections -- Remove outdated information when updating` +` + "`update_memory`" + ` replaces the entire document — send the COMPLETE updated version (max 4000 chars, markdown). Always include ` + "`respond_to_user`" + ` alongside it.` // BuildMemorySection wraps the user's memory document content for injection into system prompts. // Returns empty string if content is empty. @@ -214,8 +586,6 @@ func BuildMemorySection(content string) string { return "" } - return "\n\n## Your Memories About This User\n\n" + - "This is your persistent memory document about this user. Use it to personalize\n" + - "your responses naturally — don't repeat it back unless relevant.\n\n" + - content + return "\n\n## User Memory\n\n" + content } + diff --git a/internal/service/agent/prompt_test.go b/internal/service/agent/prompt_test.go new file mode 100644 index 0000000..6f0ce62 --- /dev/null +++ b/internal/service/agent/prompt_test.go @@ -0,0 +1,91 @@ +package agent + +import ( + "strings" + "testing" +) + +func TestBuildFullPrompt(t *testing.T) { + t.Run("nil context", func(t *testing.T) { + got := BuildFullPrompt(nil, nil) + if !strings.Contains(got, "Vultisig AI assistant") { + t.Error("expected system prompt") + } + }) + + t.Run("with balances", func(t *testing.T) { + ctx := &MessageContext{ + Balances: []Balance{ + {Chain: "Ethereum", Symbol: "ETH", Amount: "1.5"}, + }, + } + got := BuildFullPrompt(ctx, nil) + if !strings.Contains(got, "ETH on Ethereum: 1.5") { + t.Errorf("expected balance in prompt, got:\n%s", got) + } + }) + + t.Run("with plugins", func(t *testing.T) { + plugins := []PluginSkill{ + {PluginID: "dca-plugin", Name: "DCA", Skills: "Dollar cost average into tokens"}, + } + got := BuildFullPrompt(nil, plugins) + if !strings.Contains(got, "DCA (dca-plugin)") { + t.Errorf("expected plugin in prompt, got:\n%s", got) + } + }) + + t.Run("with address book", func(t *testing.T) { + ctx := &MessageContext{ + AddressBook: []AddressBookEntry{ + {Title: "Alice", Address: "0xabc", Chain: "Ethereum"}, + }, + } + got := BuildFullPrompt(ctx, nil) + if !strings.Contains(got, "Alice: 0xabc (Ethereum)") { + t.Errorf("expected address book entry in prompt") + } + if !strings.Contains(got, "resolve the address directly") { + t.Errorf("expected address book hint in full prompt") + } + }) +} + +func TestBuildSystemPromptWithSummary(t *testing.T) { + t.Run("nil summary", func(t *testing.T) { + got := BuildSystemPromptWithSummary("base", nil) + if got != "base" { + t.Errorf("got %q, want 'base'", got) + } + }) + + t.Run("with summary", func(t *testing.T) { + summary := "User discussed DCA into ETH" + got := BuildSystemPromptWithSummary("base", &summary) + if !strings.Contains(got, "Earlier Conversation Summary") { + t.Error("expected summary header") + } + if !strings.Contains(got, "DCA into ETH") { + t.Error("expected summary content") + } + }) +} + +func TestBuildMemorySection(t *testing.T) { + t.Run("empty", func(t *testing.T) { + got := BuildMemorySection("") + if got != "" { + t.Errorf("expected empty, got %q", got) + } + }) + + t.Run("with content", func(t *testing.T) { + got := BuildMemorySection("User prefers ETH") + if !strings.Contains(got, "User Memory") { + t.Error("expected memory header") + } + if !strings.Contains(got, "User prefers ETH") { + t.Error("expected memory content") + } + }) +} diff --git a/internal/service/agent/starters.go b/internal/service/agent/starters.go new file mode 100644 index 0000000..afb5479 --- /dev/null +++ b/internal/service/agent/starters.go @@ -0,0 +1,115 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/vultisig/agent-backend/internal/ai" +) + +const ( + startersTTL = 15 * time.Minute + startersTimeout = 20 * time.Second +) + +var StartersPrompt = `Generate exactly 8 conversation starter messages for a crypto wallet AI assistant. These are clickable prompts shown to the user when they open the chat. + +## Available Capabilities + +` + ActionsTable + ` + +## Rules + +- Each starter must be a natural, first-person request the user would type +- Starters MUST reference the user's actual coins, balances, and chains from the wallet context below +- Each starter should map to one of the available capabilities above +- Use realistic amounts based on the user's actual balances (never more than they hold) +- Mix of simple queries (price, balance) and actionable requests (swap, send, DCA) +- Keep each starter under 60 characters +- Do NOT suggest actions involving chains or coins the user doesn't have +- Do NOT include greetings, questions about capabilities, or meta-conversation +- Respond with ONLY a JSON array of 8 strings, no other text` + +func startersCacheKey(publicKey string) string { + return fmt.Sprintf("starters:%s", publicKey) +} + +func (s *AgentService) GenerateStarters(ctx context.Context, req *GetStartersRequest) *GetStartersResponse { + empty := &GetStartersResponse{Starters: []string{}} + + if req.PublicKey == "" { + return empty + } + + cacheKey := startersCacheKey(req.PublicKey) + cached, err := s.redis.Get(ctx, cacheKey) + if err == nil && cached != "" { + var starters []string + parseErr := json.Unmarshal([]byte(cached), &starters) + if parseErr == nil && len(starters) > 0 { + return &GetStartersResponse{Starters: starters} + } + } + + var sb strings.Builder + sb.WriteString(StartersPrompt) + + writeWalletContext(&sb, req.Context, walletContextOpts{ + includeAssetInBalances: false, + addressHeader: "Addresses", + includeNativeTag: false, + includeAddressBookHint: false, + }) + + aiCtx, cancel := context.WithTimeout(ctx, startersTimeout) + defer cancel() + + aiReq := &ai.Request{ + Model: s.summaryModel, + MaxTokens: 1024, + Messages: []any{ + ai.Message{Role: "user", Content: sb.String()}, + }, + } + + resp, err := s.ai.SendMessage(aiCtx, aiReq) + if err != nil { + s.logger.WithError(err).Warn("failed to generate starters") + return empty + } + + text := resp.Content + if text == "" { + return empty + } + + text = strings.TrimSpace(text) + text = strings.TrimPrefix(text, "```json") + text = strings.TrimPrefix(text, "```") + text = strings.TrimSuffix(text, "```") + text = strings.TrimSpace(text) + + var starters []string + err = json.Unmarshal([]byte(text), &starters) + if err != nil { + s.logger.WithError(err).WithField("text", text).Warn("failed to parse starters response") + return empty + } + + if len(starters) == 0 { + return empty + } + + cacheData, marshalErr := json.Marshal(starters) + if marshalErr == nil { + cacheErr := s.redis.Set(ctx, cacheKey, string(cacheData), startersTTL) + if cacheErr != nil { + s.logger.WithError(cacheErr).Warn("failed to cache starters") + } + } + + return &GetStartersResponse{Starters: starters} +} diff --git a/internal/service/agent/tools.go b/internal/service/agent/tools.go index 471da84..c807121 100644 --- a/internal/service/agent/tools.go +++ b/internal/service/agent/tools.go @@ -1,9 +1,9 @@ package agent -import "github.com/vultisig/agent-backend/internal/ai/anthropic" +import "github.com/vultisig/agent-backend/internal/ai" // CheckPluginInstalledTool checks if a plugin is installed for the user's vault. -var CheckPluginInstalledTool = anthropic.Tool{ +var CheckPluginInstalledTool = ai.Tool{ Name: "check_plugin_installed", Description: "Check if a specific plugin is installed for the user's vault. " + "Call this when the user wants to use a plugin's features (e.g., create a DCA policy). " + @@ -21,7 +21,7 @@ var CheckPluginInstalledTool = anthropic.Tool{ } // GetRecipeSchemaTool fetches the configuration schema and examples for a plugin. -var GetRecipeSchemaTool = anthropic.Tool{ +var GetRecipeSchemaTool = ai.Tool{ Name: "get_recipe_schema", Description: "Fetch the configuration schema and examples for a plugin. " + "Use this to understand what fields a plugin requires before building a configuration. " + @@ -39,7 +39,7 @@ var GetRecipeSchemaTool = anthropic.Tool{ } // SuggestPolicyTool validates a configuration and gets policy rules from the verifier. -var SuggestPolicyTool = anthropic.Tool{ +var SuggestPolicyTool = ai.Tool{ Name: "suggest_policy", Description: "Validate a configuration and get policy rules from the verifier. " + "Call this ONLY after you have a complete configuration from get_recipe_schema. " + @@ -63,7 +63,7 @@ var SuggestPolicyTool = anthropic.Tool{ } // CreateSuggestionTool stores a suggestion card for the frontend to display. -var CreateSuggestionTool = anthropic.Tool{ +var CreateSuggestionTool = ai.Tool{ Name: "create_suggestion", Description: "Create a suggestion card for the user to select. " + "Use this when you want to offer the user an action option (e.g., 'Weekly DCA into ETH'). " + @@ -90,7 +90,7 @@ var CreateSuggestionTool = anthropic.Tool{ } // CheckBillingStatusTool checks if the user's billing app is set up. -var CheckBillingStatusTool = anthropic.Tool{ +var CheckBillingStatusTool = ai.Tool{ Name: "check_billing_status", Description: "Check if the user has the billing app installed or has an active free trial. " + "Most plugins require the billing app (vultisig-fees-feee) to be installed after the 7-day free trial expires. " + @@ -102,13 +102,169 @@ var CheckBillingStatusTool = anthropic.Tool{ }, } +// SetVaultTool sets the active vault for this conversation. +var SetVaultTool = ai.Tool{ + Name: "set_vault", + Description: "Set the active vault for this conversation. " + + "Call this when the user provides their vault's public keys and chaincode. " + + "This binds the vault to the conversation so tools like get_eth_balance " + + "and get_token_balance can derive addresses automatically. " + + "The user may switch vaults during a conversation by calling this again.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "ecdsa_public_key": map[string]any{ + "type": "string", + "description": "The vault's ECDSA (secp256k1) public key in hex.", + }, + "eddsa_public_key": map[string]any{ + "type": "string", + "description": "The vault's EdDSA (ed25519) public key in hex.", + }, + "chain_code": map[string]any{ + "type": "string", + "description": "The vault's chaincode in hex, used for key derivation.", + }, + }, + "required": []string{"ecdsa_public_key", "eddsa_public_key", "chain_code"}, + }, +} + +// GetSkillTool loads a specific skill's full instructions on demand. +// Added to the tool list dynamically only when skills are available from MCP. +var GetSkillTool = ai.Tool{ + Name: "get_skill", + Description: "Load the full instructions for a specific skill. " + + "Use this when you identify a skill from the Available Skills list that is relevant to the user's request. " + + "Only load skills that are directly needed — do not speculatively load skills.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "skill_name": map[string]any{ + "type": "string", + "description": "The slug name of the skill to load (as listed in Available Skills).", + }, + }, + "required": []string{"skill_name"}, + }, +} + +// ScheduleTaskTool creates a new scheduled task. +var ScheduleTaskTool = ai.Tool{ + Name: "schedule_task", + Description: "Schedule a task to run later or on a recurring basis. " + + "Use this when the user wants to set up reminders, recurring price checks, " + + "periodic portfolio reviews, or any deferred action. " + + "Freeze ALL necessary context (asset names, amounts, thresholds, addresses) " + + "into the context field so the task is fully self-contained.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "intent": map[string]any{ + "type": "string", + "description": "Natural language instruction for what the task should do when it runs. Be specific and actionable.", + }, + "context": map[string]any{ + "type": "object", + "description": "Frozen parameters needed for execution. Include everything: asset names, chains, amounts, thresholds, addresses. This is opaque to the system — you decide what to freeze.", + "additionalProperties": true, + }, + "next_run_at": map[string]any{ + "type": "string", + "description": "ISO 8601 UTC timestamp for when the task should first run (e.g., '2025-03-01T09:00:00Z'). The current time is provided in your system prompt.", + }, + "interval_seconds": map[string]any{ + "type": "integer", + "description": "For recurring tasks: interval in seconds between runs. Omit for one-time tasks. Examples: 3600 = every hour, 86400 = every day, 604800 = every week.", + }, + "max_runs": map[string]any{ + "type": "integer", + "description": "Maximum number of times to run (for recurring tasks). Omit for unlimited.", + }, + }, + "required": []string{"intent", "context", "next_run_at"}, + }, +} + +// ListScheduledTasksTool lists the user's active scheduled tasks. +var ListScheduledTasksTool = ai.Tool{ + Name: "list_scheduled_tasks", + Description: "List the user's active scheduled tasks. Returns task IDs, intents, " + + "schedules, and context so you can reference them for updates or cancellation.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, +} + +// UpdateScheduledTaskTool modifies an existing scheduled task. +var UpdateScheduledTaskTool = ai.Tool{ + Name: "update_scheduled_task", + Description: "Update an existing scheduled task. Use list_scheduled_tasks first to get the task ID. " + + "Only include fields you want to change.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "task_id": map[string]any{ + "type": "string", + "description": "The UUID of the task to update.", + }, + "intent": map[string]any{ + "type": "string", + "description": "Updated intent text.", + }, + "context": map[string]any{ + "type": "object", + "description": "Updated frozen context. Replaces the entire context.", + "additionalProperties": true, + }, + "next_run_at": map[string]any{ + "type": "string", + "description": "Updated next run time (ISO 8601 UTC).", + }, + "interval_seconds": map[string]any{ + "type": "integer", + "description": "Updated interval in seconds between runs.", + }, + "max_runs": map[string]any{ + "type": "integer", + "description": "Updated max runs.", + }, + }, + "required": []string{"task_id"}, + }, +} + +// CancelScheduledTaskTool cancels an existing scheduled task. +var CancelScheduledTaskTool = ai.Tool{ + Name: "cancel_scheduled_task", + Description: "Cancel an active scheduled task by its ID. " + + "Use list_scheduled_tasks first to find the task ID.", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "task_id": map[string]any{ + "type": "string", + "description": "The UUID of the task to cancel.", + }, + }, + "required": []string{"task_id"}, + }, +} + // agentTools returns all granular tools for the decision loop. -func agentTools() []anthropic.Tool { - return []anthropic.Tool{ +func agentTools() []ai.Tool { + return []ai.Tool{ + RespondToUserTool, CheckPluginInstalledTool, CheckBillingStatusTool, GetRecipeSchemaTool, SuggestPolicyTool, CreateSuggestionTool, + SetVaultTool, + ScheduleTaskTool, + ListScheduledTasksTool, + UpdateScheduledTaskTool, + CancelScheduledTaskTool, } } diff --git a/internal/service/agent/types.go b/internal/service/agent/types.go index 738d3fd..b7ed301 100644 --- a/internal/service/agent/types.go +++ b/internal/service/agent/types.go @@ -1,6 +1,8 @@ package agent import ( + "encoding/json" + "github.com/vultisig/agent-backend/internal/types" ) @@ -8,6 +10,7 @@ import ( type SendMessageRequest struct { PublicKey string `json:"public_key"` Content string `json:"content"` + Model string `json:"model,omitempty"` Context *MessageContext `json:"context,omitempty"` SelectedSuggestionID *string `json:"selected_suggestion_id,omitempty"` ActionResult *ActionResult `json:"action_result,omitempty"` @@ -16,9 +19,27 @@ type SendMessageRequest struct { // MessageContext provides context about the user's wallet state. type MessageContext struct { - VaultAddress string `json:"vault_address,omitempty"` - Balances []Balance `json:"balances,omitempty"` - Addresses map[string]string `json:"addresses,omitempty"` + VaultAddress string `json:"vault_address,omitempty"` + VaultName string `json:"vault_name,omitempty"` + Balances []Balance `json:"balances,omitempty"` + Addresses map[string]string `json:"addresses,omitempty"` + Coins []CoinInfo `json:"coins,omitempty"` + AddressBook []AddressBookEntry `json:"address_book,omitempty"` +} + +type CoinInfo struct { + Chain string `json:"chain"` + Ticker string `json:"ticker"` + ContractAddress string `json:"contract_address,omitempty"` + IsNativeToken bool `json:"is_native_token"` + Decimals int `json:"decimals"` + Logo string `json:"logo,omitempty"` +} + +type AddressBookEntry struct { + Title string `json:"title"` + Address string `json:"address"` + Chain string `json:"chain"` } // Balance represents a token balance in the user's wallet. @@ -30,19 +51,66 @@ type Balance struct { Decimals int `json:"decimals"` } +// Action represents an instruction from the agent for the app to execute. +type Action struct { + ID string `json:"id"` + Type string `json:"type"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + Params map[string]any `json:"params,omitempty"` + AutoExecute bool `json:"auto_execute"` +} + // ActionResult contains the result of a user action. type ActionResult struct { - Action string `json:"action"` - Success bool `json:"success"` - Error string `json:"error,omitempty"` + Action string `json:"action"` + ActionID string `json:"action_id,omitempty"` + Success bool `json:"success"` + Data map[string]any `json:"data,omitempty"` + Error string `json:"error,omitempty"` } // SendMessageResponse is the response for sending a message. type SendMessageResponse struct { - Message types.Message `json:"message"` - Suggestions []Suggestion `json:"suggestions,omitempty"` - PolicyReady *PolicyReady `json:"policy_ready,omitempty"` - InstallRequired *InstallRequired `json:"install_required,omitempty"` + Message types.Message `json:"message"` + Title *string `json:"title,omitempty"` + Suggestions []Suggestion `json:"suggestions,omitempty"` + Actions []Action `json:"actions,omitempty"` + PolicyReady *PolicyReady `json:"policy_ready,omitempty"` + InstallRequired *InstallRequired `json:"install_required,omitempty"` + TxReady *TxReady `json:"tx_ready,omitempty"` + Transactions []Transaction `json:"transactions,omitempty"` + Tokens *TokenSearchResult `json:"tokens,omitempty"` +} + +// Transaction represents an unsigned transaction returned by an MCP tool +// that the wallet must sign and broadcast. +type Transaction struct { + Sequence int `json:"sequence"` + Chain string `json:"chain"` + ChainID string `json:"chain_id"` + Action string `json:"action"` + SigningMode string `json:"signing_mode"` + UnsignedTxHex string `json:"unsigned_tx_hex"` + TxDetails map[string]string `json:"tx_details"` +} + +type TxReady struct { + Provider string `json:"provider"` + ExpectedOutput string `json:"expected_output"` + MinimumOutput string `json:"minimum_output"` + NeedsApproval bool `json:"needs_approval"` + ApprovalTx json.RawMessage `json:"approval_tx,omitempty"` + SwapTx json.RawMessage `json:"swap_tx"` + FromChain string `json:"from_chain"` + FromSymbol string `json:"from_symbol"` + FromDecimals int `json:"from_decimals"` + ToChain string `json:"to_chain"` + ToSymbol string `json:"to_symbol"` + ToDecimals int `json:"to_decimals"` + Amount string `json:"amount"` + Sender string `json:"sender"` + Destination string `json:"destination"` } // InstallRequired signals that a plugin must be installed before proceeding. @@ -56,7 +124,7 @@ type InstallRequired struct { type PolicyReady struct { PluginID string `json:"plugin_id"` Configuration map[string]any `json:"configuration"` - PolicySuggest any `json:"policy_suggest"` // verifier.PolicySuggest + PolicySuggest any `json:"policy_suggest"` } // Suggestion represents an action suggestion for the user. @@ -67,3 +135,100 @@ type Suggestion struct { Description string `json:"description"` } +type BuildTxRequest struct { + PublicKey string `json:"public_key"` + Params map[string]any `json:"params"` + Context *MessageContext `json:"context,omitempty"` +} + +type BuildTxResponse struct { + TxReady *TxReady `json:"tx_ready,omitempty"` + Actions []Action `json:"actions,omitempty"` + Message string `json:"message,omitempty"` + Error string `json:"error,omitempty"` +} + +type SSEEvent struct { + Event string `json:"event"` + Data any `json:"data"` +} + +type TextDeltaPayload struct { + Delta string `json:"delta"` +} + +type TitlePayload struct { + Title string `json:"title"` +} + +type SuggestionsPayload struct { + Suggestions []Suggestion `json:"suggestions"` +} + +type ActionsPayload struct { + Actions []Action `json:"actions"` +} + +type MessagePayload struct { + Message types.Message `json:"message"` +} + +type ErrorPayload struct { + Error string `json:"error"` +} + +type GetStartersRequest struct { + PublicKey string `json:"public_key"` + Context *MessageContext `json:"context,omitempty"` +} + +type GetStartersResponse struct { + Starters []string `json:"starters"` +} + +// ToolResponse is the parsed response from the respond_to_user tool. +type ToolResponse struct { + Intent string `json:"intent"` + ConversationTitle string `json:"conversation_title"` + Response string `json:"response"` + Suggestions []ToolSuggestion `json:"suggestions,omitempty"` + Actions []ToolAction `json:"actions,omitempty"` +} + +// ToolSuggestion is a suggestion from the tool response. +type ToolSuggestion struct { + PluginID string `json:"plugin_id"` + Title string `json:"title"` + Description string `json:"description"` +} + +// ToolAction is an action from the tool response instructing the app to execute something. +type ToolAction struct { + Type string `json:"type"` + Title string `json:"title"` + Description string `json:"description,omitempty"` + Params map[string]any `json:"params,omitempty"` + AutoExecute bool `json:"auto_execute"` +} + +// TokenSearchResult contains tokens returned by the find_token MCP tool. +type TokenSearchResult struct { + Tokens []Token `json:"tokens"` +} + +// Token represents a cryptocurrency token with its on-chain deployments. +type Token struct { + ID string `json:"id"` + Name string `json:"name"` + Symbol string `json:"symbol"` + MarketCapRank int `json:"market_cap_rank"` + Logo string `json:"logo"` + Deployments []TokenDeployment `json:"deployments"` +} + +// TokenDeployment represents a token's deployment on a specific chain. +type TokenDeployment struct { + Chain string `json:"chain"` + ContractAddress string `json:"contract_address"` + Decimals int `json:"decimals"` +} diff --git a/internal/service/auth.go b/internal/service/auth.go deleted file mode 100644 index f413c04..0000000 --- a/internal/service/auth.go +++ /dev/null @@ -1,55 +0,0 @@ -package service - -import ( - "errors" - - "github.com/golang-jwt/jwt/v5" -) - -// TokenTypeAccess is the expected token type for access tokens. -const TokenTypeAccess = "access" - -// Claims represents the JWT claims structure used by the verifier. -type Claims struct { - jwt.RegisteredClaims - PublicKey string `json:"public_key"` - TokenID string `json:"token_id"` - TokenType string `json:"token_type"` -} - -// AuthService handles JWT token validation. -type AuthService struct { - jwtSecret []byte -} - -// NewAuthService creates a new AuthService with the given JWT secret. -func NewAuthService(secret string) *AuthService { - return &AuthService{jwtSecret: []byte(secret)} -} - -// ValidateToken validates a JWT token and returns the claims. -func (a *AuthService) ValidateToken(tokenStr string) (*Claims, error) { - claims := &Claims{} - token, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.New("unexpected signing method") - } - return a.jwtSecret, nil - }) - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("invalid or expired token") - } - if claims.PublicKey == "" { - return nil, errors.New("token missing public key") - } - if claims.TokenID == "" { - return nil, errors.New("token missing token ID") - } - if claims.TokenType != TokenTypeAccess { - return nil, errors.New("access token required") - } - return claims, nil -} diff --git a/internal/service/mcp/client.go b/internal/service/mcp/client.go new file mode 100644 index 0000000..ebf01e1 --- /dev/null +++ b/internal/service/mcp/client.go @@ -0,0 +1,130 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync/atomic" + "time" +) + +type Client struct { + baseURL string + httpClient *http.Client + nextID atomic.Int64 +} + +func NewClient(baseURL string) *Client { + return &Client{ + baseURL: baseURL, + httpClient: &http.Client{ + Timeout: 60 * time.Second, + }, + } +} + +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params any `json:"params"` +} + +type toolCallParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type toolResult struct { + Content []toolContent `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +type toolContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +func (c *Client) CallTool(ctx context.Context, toolName string, args map[string]any) (string, error) { + id := c.nextID.Add(1) + + rpcReq := jsonRPCRequest{ + JSONRPC: "2.0", + ID: id, + Method: "tools/call", + Params: toolCallParams{ + Name: toolName, + Arguments: args, + }, + } + + body, err := json.Marshal(rpcReq) + if err != nil { + return "", fmt.Errorf("marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/mcp", bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode >= 400 { + return "", fmt.Errorf("MCP server error (status %d): %s", resp.StatusCode, string(respBody)) + } + + var rpcResp jsonRPCResponse + err = json.Unmarshal(respBody, &rpcResp) + if err != nil { + return "", fmt.Errorf("unmarshal response: %w", err) + } + + if rpcResp.Error != nil { + return "", fmt.Errorf("MCP error %d: %s", rpcResp.Error.Code, rpcResp.Error.Message) + } + + var result toolResult + err = json.Unmarshal(rpcResp.Result, &result) + if err != nil { + return "", fmt.Errorf("unmarshal tool result: %w", err) + } + + if result.IsError { + if len(result.Content) > 0 { + return "", fmt.Errorf("tool error: %s", result.Content[0].Text) + } + return "", fmt.Errorf("tool returned error with no content") + } + + if len(result.Content) == 0 { + return "", fmt.Errorf("tool returned empty result") + } + + return result.Content[0].Text, nil +} diff --git a/internal/service/mcp/swap_adapter.go b/internal/service/mcp/swap_adapter.go new file mode 100644 index 0000000..8d66436 --- /dev/null +++ b/internal/service/mcp/swap_adapter.go @@ -0,0 +1,78 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/vultisig/agent-backend/internal/service/agent" +) + +type SwapTxAdapter struct { + client *Client +} + +func NewSwapTxAdapter(client *Client) *SwapTxAdapter { + return &SwapTxAdapter{client: client} +} + +func (a *SwapTxAdapter) BuildSwapTx(ctx context.Context, req agent.SwapTxBuildRequest) (*agent.SwapTxBuildResponse, error) { + args := map[string]any{ + "from_chain": req.FromChain, + "from_symbol": req.FromSymbol, + "to_chain": req.ToChain, + "to_symbol": req.ToSymbol, + "amount": req.Amount, + "sender": req.Sender, + "destination": req.Destination, + } + + if req.FromAddress != "" { + args["from_address"] = req.FromAddress + } + if req.FromDecimals != nil { + args["from_decimals"] = *req.FromDecimals + } + if req.ToAddress != "" { + args["to_address"] = req.ToAddress + } + if req.ToDecimals != nil { + args["to_decimals"] = *req.ToDecimals + } + + resultText, err := a.client.CallTool(ctx, "build_swap_tx", args) + if err != nil { + return nil, fmt.Errorf("MCP build_swap_tx: %w", err) + } + + var mcpResp mcpSwapResponse + err = json.Unmarshal([]byte(resultText), &mcpResp) + if err != nil { + return nil, fmt.Errorf("unmarshal MCP swap response: %w", err) + } + + resp := &agent.SwapTxBuildResponse{ + Provider: mcpResp.Provider, + ExpectedOutput: mcpResp.ExpectedOutput, + MinimumOutput: mcpResp.MinimumOutput, + NeedsApproval: mcpResp.NeedsApproval, + SwapTx: mcpResp.SwapTx, + Memo: mcpResp.Memo, + } + + if mcpResp.NeedsApproval && mcpResp.ApprovalTx != nil { + resp.ApprovalTx = mcpResp.ApprovalTx + } + + return resp, nil +} + +type mcpSwapResponse struct { + Provider string `json:"provider"` + ExpectedOutput string `json:"expected_output"` + MinimumOutput string `json:"minimum_output"` + NeedsApproval bool `json:"needs_approval"` + ApprovalTx json.RawMessage `json:"approval_tx,omitempty"` + SwapTx json.RawMessage `json:"swap_tx"` + Memo string `json:"memo,omitempty"` +} diff --git a/internal/service/scheduler/scheduler.go b/internal/service/scheduler/scheduler.go new file mode 100644 index 0000000..c63770e --- /dev/null +++ b/internal/service/scheduler/scheduler.go @@ -0,0 +1,245 @@ +package scheduler + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgtype" + "github.com/sirupsen/logrus" + + "github.com/vultisig/agent-backend/internal/ai" + "github.com/vultisig/agent-backend/internal/config" + "github.com/vultisig/agent-backend/internal/service/agent" + "github.com/vultisig/agent-backend/internal/storage/postgres" + "github.com/vultisig/agent-backend/internal/types" +) + +// Scheduler polls for due tasks and executes them headlessly. +type Scheduler struct { + agentService *agent.AgentService + taskRepo *postgres.ScheduledTaskRepository + convRepo *postgres.ConversationRepository + msgRepo *postgres.MessageRepository + mcpProvider agent.MCPToolProvider + cfg config.SchedulerConfig + aiModel string + logger *logrus.Logger +} + +// New creates a new Scheduler. +func New( + agentService *agent.AgentService, + taskRepo *postgres.ScheduledTaskRepository, + convRepo *postgres.ConversationRepository, + msgRepo *postgres.MessageRepository, + mcpProvider agent.MCPToolProvider, + cfg config.SchedulerConfig, + aiModel string, + logger *logrus.Logger, +) *Scheduler { + return &Scheduler{ + agentService: agentService, + taskRepo: taskRepo, + convRepo: convRepo, + msgRepo: msgRepo, + mcpProvider: mcpProvider, + cfg: cfg, + aiModel: aiModel, + logger: logger, + } +} + +// Run starts the scheduler loop. Blocks until ctx is cancelled. +func (s *Scheduler) Run(ctx context.Context) error { + interval := time.Duration(s.cfg.PollIntervalSeconds) * time.Second + ticker := time.NewTicker(interval) + defer ticker.Stop() + + s.logger.WithField("poll_interval", interval).Info("scheduler started") + + for { + select { + case <-ctx.Done(): + s.logger.Info("scheduler stopping") + return ctx.Err() + case <-ticker.C: + s.poll(ctx) + } + } +} + +func (s *Scheduler) poll(ctx context.Context) { + tasks, err := s.taskRepo.ClaimDueTasks(ctx) + if err != nil { + s.logger.WithError(err).Error("failed to claim due tasks") + return + } + + if len(tasks) == 0 { + return + } + + s.logger.WithField("count", len(tasks)).Info("claimed due tasks") + + for _, task := range tasks { + s.executeTask(ctx, task) + } +} + +func (s *Scheduler) executeTask(ctx context.Context, task types.ScheduledTask) { + log := s.logger.WithFields(logrus.Fields{ + "task_id": task.ID, + "public_key": task.PublicKey, + "intent": task.Intent, + }) + + // Create a task run record + run, err := s.taskRepo.CreateTaskRun(ctx, task.ID) + if err != nil { + log.WithError(err).Error("failed to create task run") + s.advanceTask(ctx, task) + return + } + + // Build system prompt for headless execution + systemPrompt := buildSchedulerPrompt() + + // Build user message from frozen intent + context + userMessage := fmt.Sprintf("Execute this scheduled task:\n\nIntent: %s\n\nContext: %s", task.Intent, string(task.Context)) + + // Get safe tool set (includes MCP observation tools) + tools := s.schedulerTools(ctx) + + // Execute headlessly + result, err := s.agentService.RunHeadless(ctx, systemPrompt, userMessage, s.aiModel, tools, task.PublicKey) + + if err != nil { + log.WithError(err).Error("headless execution failed") + errStr := err.Error() + _ = s.taskRepo.CompleteTaskRun(ctx, run.ID, types.TaskRunFailed, nil, &errStr, nil) + s.advanceTask(ctx, task) + return + } + + // Resolve conversation: use originating conversation, or create a new one as fallback + var convID uuid.UUID + if task.ConversationID != nil { + convID = *task.ConversationID + } else { + conv, err := s.convRepo.Create(ctx, task.PublicKey) + if err != nil { + log.WithError(err).Error("failed to create conversation for task result") + resultJSON, _ := json.Marshal(result) + _ = s.taskRepo.CompleteTaskRun(ctx, run.ID, types.TaskRunSuccess, resultJSON, nil, nil) + s.advanceTask(ctx, task) + return + } + convID = conv.ID + title := truncateTitle("Scheduled: "+task.Intent, 60) + _ = s.convRepo.UpdateTitle(ctx, convID, task.PublicKey, title) + } + + // Store the assistant message in the conversation + assistantMsg := &types.Message{ + ConversationID: convID, + Role: types.RoleAssistant, + Content: result.Response, + ContentType: "text", + } + + metadata, _ := json.Marshal(map[string]any{ + "scheduled_task_id": task.ID.String(), + }) + assistantMsg.Metadata = metadata + + if err := s.msgRepo.Create(ctx, assistantMsg); err != nil { + log.WithError(err).Error("failed to store assistant message") + } + + // Complete the run + resultJSON, _ := json.Marshal(result) + _ = s.taskRepo.CompleteTaskRun(ctx, run.ID, types.TaskRunSuccess, resultJSON, nil, &convID) + + // TODO: integrate push notifications to alert user of scheduled task results + + log.WithField("conversation_id", convID).Info("task executed successfully") + + s.advanceTask(ctx, task) +} + +func (s *Scheduler) advanceTask(ctx context.Context, task types.ScheduledTask) { + log := s.logger.WithField("task_id", task.ID) + + // For completed tasks, keep the existing next_run_at (column is NOT NULL) + keepTime := pgtype.Timestamptz{Time: task.NextRunAt, Valid: true} + + if !task.IsRecurring() { + // One-time task — mark completed + if err := s.taskRepo.AdvanceTask(ctx, task.ID, keepTime, types.TaskStatusCompleted); err != nil { + log.WithError(err).Error("failed to advance one-time task") + } + return + } + + // Check if max runs reached + newRunCount := task.RunCount + 1 + if task.MaxRuns != nil && newRunCount >= *task.MaxRuns { + if err := s.taskRepo.AdvanceTask(ctx, task.ID, keepTime, types.TaskStatusCompleted); err != nil { + log.WithError(err).Error("failed to complete max-runs task") + } + return + } + + // Compute next run time by adding interval + next := time.Now().Add(time.Duration(*task.IntervalSeconds) * time.Second) + if err := s.taskRepo.AdvanceTask(ctx, task.ID, pgtype.Timestamptz{Time: next, Valid: true}, types.TaskStatusActive); err != nil { + log.WithError(err).Error("failed to advance recurring task") + } +} + +func buildSchedulerPrompt() string { + return `You are the Vultisig AI assistant running a scheduled task. + +IMPORTANT: You are running WITHOUT a live user session. The user is NOT present and CANNOT approve or sign anything. Never attempt to auto-execute transactions. + +Your job is to execute the task's intent using the frozen context provided, then respond with a clear summary of what you found or did. The user will see your response later as a notification. + +Be concise and actionable in your response. If you found something noteworthy (e.g., a price threshold was crossed), highlight it clearly.` +} + +// schedulerTools returns the safe tool set for headless execution. +// Includes observation tools (respond, billing, plugins, MCP tools like find_token/get_balance) +// but never scheduling tools (no recursive scheduling). +func (s *Scheduler) schedulerTools(ctx context.Context) []ai.Tool { + tools := []ai.Tool{ + agent.RespondToUserTool, + agent.CheckBillingStatusTool, + agent.CheckPluginInstalledTool, + agent.GetRecipeSchemaTool, + agent.SuggestPolicyTool, + agent.CreateSuggestionTool, + } + + // Include MCP tools (find_token, get_balance, get_token_price, etc.) + if s.mcpProvider != nil { + mcpTools := s.mcpProvider.GetTools(ctx) + if len(mcpTools) > 0 { + tools = append(tools, mcpTools...) + } + } + + // Never include: schedule_task, list_scheduled_tasks, update_scheduled_task, + // cancel_scheduled_task (no recursive scheduling) + + return tools +} + +func truncateTitle(title string, maxLen int) string { + if len(title) <= maxLen { + return title + } + return title[:maxLen-3] + "..." +} diff --git a/internal/service/verifier/client.go b/internal/service/verifier/client.go index 96f8e27..5164eab 100644 --- a/internal/service/verifier/client.go +++ b/internal/service/verifier/client.go @@ -7,9 +7,12 @@ import ( "fmt" "io" "net/http" + "regexp" "time" ) +var validPluginID = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + // Client is a client for the verifier service. type Client struct { baseURL string @@ -111,6 +114,46 @@ type SuggestRequest struct { Configuration map[string]any `json:"configuration"` } +// AuthMeResponse is the response from GET /auth/me. +type AuthMeResponse struct { + Data struct { + PublicKey string `json:"public_key"` + } `json:"data"` +} + +// GetMe validates an access token via the verifier and returns the user's public key. +func (c *Client) GetMe(ctx context.Context, accessToken string) (string, error) { + url := fmt.Sprintf("%s/auth/me", c.baseURL) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("http request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) + } + + var authResp AuthMeResponse + if err := json.NewDecoder(resp.Body).Decode(&authResp); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } + + if authResp.Data.PublicKey == "" { + return "", fmt.Errorf("empty public key in response") + } + + return authResp.Data.PublicKey, nil +} + // FeeStatus represents the billing status for a user. type FeeStatus struct { IsTrialActive bool `json:"is_trial_active"` @@ -183,6 +226,9 @@ func (c *Client) IsPluginInstalled(ctx context.Context, accessToken, pluginID st // GetRecipeSchema fetches the recipe specification for a plugin. func (c *Client) GetRecipeSchema(ctx context.Context, pluginID string) (*RecipeSchema, error) { + if !validPluginID.MatchString(pluginID) { + return nil, fmt.Errorf("invalid plugin ID: %q", pluginID) + } url := fmt.Sprintf("%s/plugins/%s/recipe-specification", c.baseURL, pluginID) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -211,6 +257,9 @@ func (c *Client) GetRecipeSchema(ctx context.Context, pluginID string) (*RecipeS // GetPolicySuggest calls the plugin's suggest endpoint to build a policy. func (c *Client) GetPolicySuggest(ctx context.Context, pluginID string, configuration map[string]any) (*PolicySuggest, error) { + if !validPluginID.MatchString(pluginID) { + return nil, fmt.Errorf("invalid plugin ID: %q", pluginID) + } url := fmt.Sprintf("%s/plugins/%s/recipe-specification/suggest", c.baseURL, pluginID) body, err := json.Marshal(SuggestRequest{Configuration: configuration}) diff --git a/internal/storage/postgres/conversation.go b/internal/storage/postgres/conversation.go index 367baf1..b462709 100644 --- a/internal/storage/postgres/conversation.go +++ b/internal/storage/postgres/conversation.go @@ -135,6 +135,46 @@ func (r *ConversationRepository) UpdateSummaryWithCursor(ctx context.Context, id return nil } +// UpdateVaultInfo updates the vault keys for a conversation. +func (r *ConversationRepository) UpdateVaultInfo(ctx context.Context, id uuid.UUID, publicKey string, ecdsa, eddsa, chaincode string) error { + rowsAffected, err := r.q.UpdateVaultInfo(ctx, &queries.UpdateVaultInfoParams{ + EcdsaPublicKey: stringPtrToPgtext(&ecdsa), + EddsaPublicKey: stringPtrToPgtext(&eddsa), + ChaincodeHex: stringPtrToPgtext(&chaincode), + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + return fmt.Errorf("update vault info: %w", err) + } + if rowsAffected == 0 { + return ErrNotFound + } + return nil +} + +// GetVaultInfo returns the vault keys for a conversation. +func (r *ConversationRepository) GetVaultInfo(ctx context.Context, id uuid.UUID, publicKey string) (*types.VaultInfo, error) { + row, err := r.q.GetVaultInfo(ctx, &queries.GetVaultInfoParams{ + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("get vault info: %w", err) + } + if !row.EcdsaPublicKey.Valid || !row.EddsaPublicKey.Valid || !row.ChaincodeHex.Valid { + return nil, nil + } + return &types.VaultInfo{ + ECDSAPublicKey: row.EcdsaPublicKey.String, + EDDSAPublicKey: row.EddsaPublicKey.String, + ChaincodeHex: row.ChaincodeHex.String, + }, nil +} + // GetSummaryWithCursor returns the summary and summary_up_to cursor of a conversation. func (r *ConversationRepository) GetSummaryWithCursor(ctx context.Context, id uuid.UUID, publicKey string) (*string, *time.Time, error) { row, err := r.q.GetConversationSummaryWithCursor(ctx, &queries.GetConversationSummaryWithCursorParams{ diff --git a/internal/storage/postgres/convert.go b/internal/storage/postgres/convert.go index 282126c..daa22a6 100644 --- a/internal/storage/postgres/convert.go +++ b/internal/storage/postgres/convert.go @@ -72,7 +72,7 @@ func conversationFromDB(c *queries.AgentConversation) *types.Conversation { if c == nil { return nil } - return &types.Conversation{ + conv := &types.Conversation{ ID: pgtypeToUUID(c.ID), PublicKey: c.PublicKey, Title: pgtextToStringPtr(c.Title), @@ -82,6 +82,14 @@ func conversationFromDB(c *queries.AgentConversation) *types.Conversation { UpdatedAt: pgtimestamptzToTime(c.UpdatedAt), ArchivedAt: pgtimestamptzToTimePtr(c.ArchivedAt), } + if c.EcdsaPublicKey.Valid && c.EddsaPublicKey.Valid && c.ChaincodeHex.Valid { + conv.VaultInfo = &types.VaultInfo{ + ECDSAPublicKey: c.EcdsaPublicKey.String, + EDDSAPublicKey: c.EddsaPublicKey.String, + ChaincodeHex: c.ChaincodeHex.String, + } + } + return conv } func conversationsFromDB(cs []*queries.AgentConversation) []types.Conversation { @@ -140,3 +148,75 @@ func userMemoryFromDB(m *queries.AgentUserMemory) *types.UserMemory { UpdatedAt: pgtimestamptzToTime(m.UpdatedAt), } } + +// Scheduled task conversions + +func int4ToInt32Ptr(i pgtype.Int4) *int32 { + if !i.Valid { + return nil + } + return &i.Int32 +} + +func int32PtrToPgint4(i *int32) pgtype.Int4 { + if i == nil { + return pgtype.Int4{Valid: false} + } + return pgtype.Int4{Int32: *i, Valid: true} +} + +func scheduledTaskFromDB(t *queries.AgentScheduledTask) *types.ScheduledTask { + if t == nil { + return nil + } + st := &types.ScheduledTask{ + ID: pgtypeToUUID(t.ID), + PublicKey: t.PublicKey, + Intent: t.Intent, + Context: json.RawMessage(t.Context), + NextRunAt: pgtimestamptzToTime(t.NextRunAt), + IntervalSeconds: int4ToInt32Ptr(t.IntervalSeconds), + MaxRuns: int4ToInt32Ptr(t.MaxRuns), + RunCount: t.RunCount, + Status: types.TaskStatus(t.Status), + CreatedAt: pgtimestamptzToTime(t.CreatedAt), + UpdatedAt: pgtimestamptzToTime(t.UpdatedAt), + } + if t.ConversationID.Valid { + id := pgtypeToUUID(t.ConversationID) + st.ConversationID = &id + } + return st +} + +func scheduledTasksFromDB(ts []*queries.AgentScheduledTask) []types.ScheduledTask { + result := make([]types.ScheduledTask, len(ts)) + for i, t := range ts { + task := scheduledTaskFromDB(t) + if task != nil { + result[i] = *task + } + } + return result +} + +func taskRunFromDB(r *queries.AgentTaskRun) *types.TaskRun { + if r == nil { + return nil + } + tr := &types.TaskRun{ + ID: pgtypeToUUID(r.ID), + TaskID: pgtypeToUUID(r.TaskID), + Status: types.TaskRunStatus(r.Status), + Result: json.RawMessage(r.Result), + Error: pgtextToStringPtr(r.Error), + Notified: r.Notified, + StartedAt: pgtimestamptzToTime(r.StartedAt), + FinishedAt: pgtimestamptzToTimePtr(r.FinishedAt), + } + if r.ConversationID.Valid { + id := pgtypeToUUID(r.ConversationID) + tr.ConversationID = &id + } + return tr +} diff --git a/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql b/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql new file mode 100644 index 0000000..d84e2ef --- /dev/null +++ b/internal/storage/postgres/migrations/20260221000001_add_vault_info.sql @@ -0,0 +1,11 @@ +-- +goose Up +ALTER TABLE agent_conversations + ADD COLUMN ecdsa_public_key TEXT, + ADD COLUMN eddsa_public_key TEXT, + ADD COLUMN chaincode_hex TEXT; + +-- +goose Down +ALTER TABLE agent_conversations + DROP COLUMN IF EXISTS ecdsa_public_key, + DROP COLUMN IF EXISTS eddsa_public_key, + DROP COLUMN IF EXISTS chaincode_hex; diff --git a/internal/storage/postgres/migrations/20260224000001_create_scheduled_tasks.sql b/internal/storage/postgres/migrations/20260224000001_create_scheduled_tasks.sql new file mode 100644 index 0000000..a8e121d --- /dev/null +++ b/internal/storage/postgres/migrations/20260224000001_create_scheduled_tasks.sql @@ -0,0 +1,50 @@ +-- +goose Up +-- +goose StatementBegin + +CREATE TYPE agent_task_status AS ENUM ('active', 'paused', 'completed', 'cancelled'); +CREATE TYPE agent_task_run_status AS ENUM ('running', 'success', 'failed'); + +CREATE TABLE agent_scheduled_tasks ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + public_key VARCHAR(66) NOT NULL, + conversation_id UUID REFERENCES agent_conversations(id) ON DELETE SET NULL, + intent TEXT NOT NULL, + context JSONB NOT NULL DEFAULT '{}', + next_run_at TIMESTAMPTZ NOT NULL, + interval_seconds INT, + max_runs INT, + run_count INT NOT NULL DEFAULT 0, + status agent_task_status NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX idx_agent_scheduled_tasks_public_key ON agent_scheduled_tasks(public_key); +CREATE INDEX idx_agent_scheduled_tasks_due ON agent_scheduled_tasks(next_run_at) WHERE status = 'active'; + +CREATE TABLE agent_task_runs ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + task_id UUID NOT NULL REFERENCES agent_scheduled_tasks(id) ON DELETE CASCADE, + conversation_id UUID REFERENCES agent_conversations(id) ON DELETE SET NULL, + status agent_task_run_status NOT NULL DEFAULT 'running', + result JSONB, + error TEXT, + notified BOOLEAN NOT NULL DEFAULT FALSE, + started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + finished_at TIMESTAMPTZ +); + +CREATE INDEX idx_agent_task_runs_task_id ON agent_task_runs(task_id); +CREATE INDEX idx_agent_task_runs_unnotified ON agent_task_runs(notified) WHERE notified = FALSE; + +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin + +DROP TABLE IF EXISTS agent_task_runs; +DROP TABLE IF EXISTS agent_scheduled_tasks; +DROP TYPE IF EXISTS agent_task_run_status; +DROP TYPE IF EXISTS agent_task_status; + +-- +goose StatementEnd diff --git a/internal/storage/postgres/queries/conversations.sql.go b/internal/storage/postgres/queries/conversations.sql.go index 819c0a1..423a5af 100644 --- a/internal/storage/postgres/queries/conversations.sql.go +++ b/internal/storage/postgres/queries/conversations.sql.go @@ -46,7 +46,7 @@ const createConversation = `-- name: CreateConversation :one INSERT INTO agent_conversations (public_key) VALUES ($1) -RETURNING id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at +RETURNING id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at ` // Conversations table queries @@ -59,6 +59,9 @@ func (q *Queries) CreateConversation(ctx context.Context, publicKey string) (*Ag &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -67,7 +70,7 @@ func (q *Queries) CreateConversation(ctx context.Context, publicKey string) (*Ag } const getConversationByID = `-- name: GetConversationByID :one -SELECT id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at FROM agent_conversations +SELECT id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at FROM agent_conversations WHERE id = $1 AND public_key = $2 AND archived_at IS NULL ` @@ -85,6 +88,9 @@ func (q *Queries) GetConversationByID(ctx context.Context, arg *GetConversationB &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -114,8 +120,31 @@ func (q *Queries) GetConversationSummaryWithCursor(ctx context.Context, arg *Get return &i, err } +const getVaultInfo = `-- name: GetVaultInfo :one +SELECT ecdsa_public_key, eddsa_public_key, chaincode_hex FROM agent_conversations +WHERE id = $1 AND public_key = $2 AND archived_at IS NULL +` + +type GetVaultInfoParams struct { + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +type GetVaultInfoRow struct { + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` +} + +func (q *Queries) GetVaultInfo(ctx context.Context, arg *GetVaultInfoParams) (*GetVaultInfoRow, error) { + row := q.db.QueryRow(ctx, getVaultInfo, arg.ID, arg.PublicKey) + var i GetVaultInfoRow + err := row.Scan(&i.EcdsaPublicKey, &i.EddsaPublicKey, &i.ChaincodeHex) + return &i, err +} + const listConversations = `-- name: ListConversations :many -SELECT id, public_key, title, summary, summary_up_to, created_at, updated_at, archived_at FROM agent_conversations +SELECT id, public_key, title, summary, summary_up_to, ecdsa_public_key, eddsa_public_key, chaincode_hex, created_at, updated_at, archived_at FROM agent_conversations WHERE public_key = $1 AND archived_at IS NULL ORDER BY updated_at DESC LIMIT $2 OFFSET $3 @@ -142,6 +171,9 @@ func (q *Queries) ListConversations(ctx context.Context, arg *ListConversationsP &i.Title, &i.Summary, &i.SummaryUpTo, + &i.EcdsaPublicKey, + &i.EddsaPublicKey, + &i.ChaincodeHex, &i.CreatedAt, &i.UpdatedAt, &i.ArchivedAt, @@ -201,3 +233,31 @@ func (q *Queries) UpdateConversationTitle(ctx context.Context, arg *UpdateConver } return result.RowsAffected(), nil } + +const updateVaultInfo = `-- name: UpdateVaultInfo :execrows +UPDATE agent_conversations +SET ecdsa_public_key = $1, eddsa_public_key = $2, chaincode_hex = $3, updated_at = NOW() +WHERE id = $4 AND public_key = $5 AND archived_at IS NULL +` + +type UpdateVaultInfoParams struct { + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +func (q *Queries) UpdateVaultInfo(ctx context.Context, arg *UpdateVaultInfoParams) (int64, error) { + result, err := q.db.Exec(ctx, updateVaultInfo, + arg.EcdsaPublicKey, + arg.EddsaPublicKey, + arg.ChaincodeHex, + arg.ID, + arg.PublicKey, + ) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} diff --git a/internal/storage/postgres/queries/models.go b/internal/storage/postgres/queries/models.go index abfb866..707f830 100644 --- a/internal/storage/postgres/queries/models.go +++ b/internal/storage/postgres/queries/models.go @@ -54,15 +54,105 @@ func (ns NullAgentMessageRole) Value() (driver.Value, error) { return string(ns.AgentMessageRole), nil } +type AgentTaskRunStatus string + +const ( + AgentTaskRunStatusRunning AgentTaskRunStatus = "running" + AgentTaskRunStatusSuccess AgentTaskRunStatus = "success" + AgentTaskRunStatusFailed AgentTaskRunStatus = "failed" +) + +func (e *AgentTaskRunStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AgentTaskRunStatus(s) + case string: + *e = AgentTaskRunStatus(s) + default: + return fmt.Errorf("unsupported scan type for AgentTaskRunStatus: %T", src) + } + return nil +} + +type NullAgentTaskRunStatus struct { + AgentTaskRunStatus AgentTaskRunStatus `json:"agent_task_run_status"` + Valid bool `json:"valid"` // Valid is true if AgentTaskRunStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAgentTaskRunStatus) Scan(value interface{}) error { + if value == nil { + ns.AgentTaskRunStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AgentTaskRunStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAgentTaskRunStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AgentTaskRunStatus), nil +} + +type AgentTaskStatus string + +const ( + AgentTaskStatusActive AgentTaskStatus = "active" + AgentTaskStatusPaused AgentTaskStatus = "paused" + AgentTaskStatusCompleted AgentTaskStatus = "completed" + AgentTaskStatusCancelled AgentTaskStatus = "cancelled" +) + +func (e *AgentTaskStatus) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = AgentTaskStatus(s) + case string: + *e = AgentTaskStatus(s) + default: + return fmt.Errorf("unsupported scan type for AgentTaskStatus: %T", src) + } + return nil +} + +type NullAgentTaskStatus struct { + AgentTaskStatus AgentTaskStatus `json:"agent_task_status"` + Valid bool `json:"valid"` // Valid is true if AgentTaskStatus is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullAgentTaskStatus) Scan(value interface{}) error { + if value == nil { + ns.AgentTaskStatus, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.AgentTaskStatus.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullAgentTaskStatus) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.AgentTaskStatus), nil +} + type AgentConversation struct { - ID pgtype.UUID `json:"id"` - PublicKey string `json:"public_key"` - Title pgtype.Text `json:"title"` - Summary pgtype.Text `json:"summary"` - SummaryUpTo pgtype.Timestamptz `json:"summary_up_to"` - CreatedAt pgtype.Timestamptz `json:"created_at"` - UpdatedAt pgtype.Timestamptz `json:"updated_at"` - ArchivedAt pgtype.Timestamptz `json:"archived_at"` + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` + Title pgtype.Text `json:"title"` + Summary pgtype.Text `json:"summary"` + SummaryUpTo pgtype.Timestamptz `json:"summary_up_to"` + EcdsaPublicKey pgtype.Text `json:"ecdsa_public_key"` + EddsaPublicKey pgtype.Text `json:"eddsa_public_key"` + ChaincodeHex pgtype.Text `json:"chaincode_hex"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` + ArchivedAt pgtype.Timestamptz `json:"archived_at"` } type AgentMessage struct { @@ -76,6 +166,33 @@ type AgentMessage struct { CreatedAt pgtype.Timestamptz `json:"created_at"` } +type AgentScheduledTask struct { + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` + ConversationID pgtype.UUID `json:"conversation_id"` + Intent string `json:"intent"` + Context []byte `json:"context"` + NextRunAt pgtype.Timestamptz `json:"next_run_at"` + IntervalSeconds pgtype.Int4 `json:"interval_seconds"` + MaxRuns pgtype.Int4 `json:"max_runs"` + RunCount int32 `json:"run_count"` + Status AgentTaskStatus `json:"status"` + CreatedAt pgtype.Timestamptz `json:"created_at"` + UpdatedAt pgtype.Timestamptz `json:"updated_at"` +} + +type AgentTaskRun struct { + ID pgtype.UUID `json:"id"` + TaskID pgtype.UUID `json:"task_id"` + ConversationID pgtype.UUID `json:"conversation_id"` + Status AgentTaskRunStatus `json:"status"` + Result []byte `json:"result"` + Error pgtype.Text `json:"error"` + Notified bool `json:"notified"` + StartedAt pgtype.Timestamptz `json:"started_at"` + FinishedAt pgtype.Timestamptz `json:"finished_at"` +} + type AgentUserMemory struct { PublicKey string `json:"public_key"` Content string `json:"content"` diff --git a/internal/storage/postgres/queries/scheduled_tasks.sql.go b/internal/storage/postgres/queries/scheduled_tasks.sql.go new file mode 100644 index 0000000..d50311a --- /dev/null +++ b/internal/storage/postgres/queries/scheduled_tasks.sql.go @@ -0,0 +1,318 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: scheduled_tasks.sql + +package queries + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const advanceTask = `-- name: AdvanceTask :exec +UPDATE agent_scheduled_tasks +SET run_count = run_count + 1, + next_run_at = $2, + status = $3, + updated_at = NOW() +WHERE id = $1 +` + +type AdvanceTaskParams struct { + ID pgtype.UUID `json:"id"` + NextRunAt pgtype.Timestamptz `json:"next_run_at"` + Status AgentTaskStatus `json:"status"` +} + +func (q *Queries) AdvanceTask(ctx context.Context, arg *AdvanceTaskParams) error { + _, err := q.db.Exec(ctx, advanceTask, arg.ID, arg.NextRunAt, arg.Status) + return err +} + +const cancelScheduledTask = `-- name: CancelScheduledTask :execrows +UPDATE agent_scheduled_tasks +SET status = 'cancelled', updated_at = NOW() +WHERE id = $1 AND public_key = $2 AND status IN ('active', 'paused') +` + +type CancelScheduledTaskParams struct { + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +func (q *Queries) CancelScheduledTask(ctx context.Context, arg *CancelScheduledTaskParams) (int64, error) { + result, err := q.db.Exec(ctx, cancelScheduledTask, arg.ID, arg.PublicKey) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + +const claimDueTasks = `-- name: ClaimDueTasks :many +UPDATE agent_scheduled_tasks +SET status = 'paused', updated_at = NOW() +WHERE status = 'active' AND next_run_at <= NOW() +RETURNING id, public_key, conversation_id, intent, context, next_run_at, interval_seconds, max_runs, run_count, status, created_at, updated_at +` + +func (q *Queries) ClaimDueTasks(ctx context.Context) ([]*AgentScheduledTask, error) { + rows, err := q.db.Query(ctx, claimDueTasks) + if err != nil { + return nil, err + } + defer rows.Close() + items := []*AgentScheduledTask{} + for rows.Next() { + var i AgentScheduledTask + if err := rows.Scan( + &i.ID, + &i.PublicKey, + &i.ConversationID, + &i.Intent, + &i.Context, + &i.NextRunAt, + &i.IntervalSeconds, + &i.MaxRuns, + &i.RunCount, + &i.Status, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const completeTaskRun = `-- name: CompleteTaskRun :exec +UPDATE agent_task_runs +SET status = $2, result = $3, error = $4, + conversation_id = $5, finished_at = NOW() +WHERE id = $1 +` + +type CompleteTaskRunParams struct { + ID pgtype.UUID `json:"id"` + Status AgentTaskRunStatus `json:"status"` + Result []byte `json:"result"` + Error pgtype.Text `json:"error"` + ConversationID pgtype.UUID `json:"conversation_id"` +} + +func (q *Queries) CompleteTaskRun(ctx context.Context, arg *CompleteTaskRunParams) error { + _, err := q.db.Exec(ctx, completeTaskRun, + arg.ID, + arg.Status, + arg.Result, + arg.Error, + arg.ConversationID, + ) + return err +} + +const countActiveTasksByPublicKey = `-- name: CountActiveTasksByPublicKey :one +SELECT COUNT(*) FROM agent_scheduled_tasks +WHERE public_key = $1 AND status = 'active' +` + +func (q *Queries) CountActiveTasksByPublicKey(ctx context.Context, publicKey string) (int64, error) { + row := q.db.QueryRow(ctx, countActiveTasksByPublicKey, publicKey) + var count int64 + err := row.Scan(&count) + return count, err +} + +const createScheduledTask = `-- name: CreateScheduledTask :one +INSERT INTO agent_scheduled_tasks ( + public_key, conversation_id, intent, context, + next_run_at, interval_seconds, max_runs, status +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, 'active' +) RETURNING id, public_key, conversation_id, intent, context, next_run_at, interval_seconds, max_runs, run_count, status, created_at, updated_at +` + +type CreateScheduledTaskParams struct { + PublicKey string `json:"public_key"` + ConversationID pgtype.UUID `json:"conversation_id"` + Intent string `json:"intent"` + Context []byte `json:"context"` + NextRunAt pgtype.Timestamptz `json:"next_run_at"` + IntervalSeconds pgtype.Int4 `json:"interval_seconds"` + MaxRuns pgtype.Int4 `json:"max_runs"` +} + +func (q *Queries) CreateScheduledTask(ctx context.Context, arg *CreateScheduledTaskParams) (*AgentScheduledTask, error) { + row := q.db.QueryRow(ctx, createScheduledTask, + arg.PublicKey, + arg.ConversationID, + arg.Intent, + arg.Context, + arg.NextRunAt, + arg.IntervalSeconds, + arg.MaxRuns, + ) + var i AgentScheduledTask + err := row.Scan( + &i.ID, + &i.PublicKey, + &i.ConversationID, + &i.Intent, + &i.Context, + &i.NextRunAt, + &i.IntervalSeconds, + &i.MaxRuns, + &i.RunCount, + &i.Status, + &i.CreatedAt, + &i.UpdatedAt, + ) + return &i, err +} + +const createTaskRun = `-- name: CreateTaskRun :one +INSERT INTO agent_task_runs (task_id, status) +VALUES ($1, 'running') +RETURNING id, task_id, conversation_id, status, result, error, notified, started_at, finished_at +` + +func (q *Queries) CreateTaskRun(ctx context.Context, taskID pgtype.UUID) (*AgentTaskRun, error) { + row := q.db.QueryRow(ctx, createTaskRun, taskID) + var i AgentTaskRun + err := row.Scan( + &i.ID, + &i.TaskID, + &i.ConversationID, + &i.Status, + &i.Result, + &i.Error, + &i.Notified, + &i.StartedAt, + &i.FinishedAt, + ) + return &i, err +} + +const getScheduledTaskByID = `-- name: GetScheduledTaskByID :one +SELECT id, public_key, conversation_id, intent, context, next_run_at, interval_seconds, max_runs, run_count, status, created_at, updated_at FROM agent_scheduled_tasks +WHERE id = $1 AND public_key = $2 +` + +type GetScheduledTaskByIDParams struct { + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +func (q *Queries) GetScheduledTaskByID(ctx context.Context, arg *GetScheduledTaskByIDParams) (*AgentScheduledTask, error) { + row := q.db.QueryRow(ctx, getScheduledTaskByID, arg.ID, arg.PublicKey) + var i AgentScheduledTask + err := row.Scan( + &i.ID, + &i.PublicKey, + &i.ConversationID, + &i.Intent, + &i.Context, + &i.NextRunAt, + &i.IntervalSeconds, + &i.MaxRuns, + &i.RunCount, + &i.Status, + &i.CreatedAt, + &i.UpdatedAt, + ) + return &i, err +} + +const listActiveScheduledTasks = `-- name: ListActiveScheduledTasks :many +SELECT id, public_key, conversation_id, intent, context, next_run_at, interval_seconds, max_runs, run_count, status, created_at, updated_at FROM agent_scheduled_tasks +WHERE public_key = $1 AND status = 'active' +ORDER BY created_at DESC +` + +func (q *Queries) ListActiveScheduledTasks(ctx context.Context, publicKey string) ([]*AgentScheduledTask, error) { + rows, err := q.db.Query(ctx, listActiveScheduledTasks, publicKey) + if err != nil { + return nil, err + } + defer rows.Close() + items := []*AgentScheduledTask{} + for rows.Next() { + var i AgentScheduledTask + if err := rows.Scan( + &i.ID, + &i.PublicKey, + &i.ConversationID, + &i.Intent, + &i.Context, + &i.NextRunAt, + &i.IntervalSeconds, + &i.MaxRuns, + &i.RunCount, + &i.Status, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateScheduledTask = `-- name: UpdateScheduledTask :one +UPDATE agent_scheduled_tasks SET + intent = COALESCE($1, intent), + context = COALESCE($2, context), + next_run_at = COALESCE($3, next_run_at), + interval_seconds = COALESCE($4, interval_seconds), + max_runs = COALESCE($5, max_runs), + updated_at = NOW() +WHERE id = $6 AND public_key = $7 AND status = 'active' +RETURNING id, public_key, conversation_id, intent, context, next_run_at, interval_seconds, max_runs, run_count, status, created_at, updated_at +` + +type UpdateScheduledTaskParams struct { + Intent pgtype.Text `json:"intent"` + Context []byte `json:"context"` + NextRunAt pgtype.Timestamptz `json:"next_run_at"` + IntervalSeconds pgtype.Int4 `json:"interval_seconds"` + MaxRuns pgtype.Int4 `json:"max_runs"` + ID pgtype.UUID `json:"id"` + PublicKey string `json:"public_key"` +} + +func (q *Queries) UpdateScheduledTask(ctx context.Context, arg *UpdateScheduledTaskParams) (*AgentScheduledTask, error) { + row := q.db.QueryRow(ctx, updateScheduledTask, + arg.Intent, + arg.Context, + arg.NextRunAt, + arg.IntervalSeconds, + arg.MaxRuns, + arg.ID, + arg.PublicKey, + ) + var i AgentScheduledTask + err := row.Scan( + &i.ID, + &i.PublicKey, + &i.ConversationID, + &i.Intent, + &i.Context, + &i.NextRunAt, + &i.IntervalSeconds, + &i.MaxRuns, + &i.RunCount, + &i.Status, + &i.CreatedAt, + &i.UpdatedAt, + ) + return &i, err +} diff --git a/internal/storage/postgres/scheduled_task.go b/internal/storage/postgres/scheduled_task.go new file mode 100644 index 0000000..14f098c --- /dev/null +++ b/internal/storage/postgres/scheduled_task.go @@ -0,0 +1,183 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/vultisig/agent-backend/internal/storage/postgres/queries" + "github.com/vultisig/agent-backend/internal/types" +) + +// ScheduledTaskRepository handles database operations for scheduled tasks. +type ScheduledTaskRepository struct { + q *queries.Queries +} + +// NewScheduledTaskRepository creates a new ScheduledTaskRepository. +func NewScheduledTaskRepository(pool *pgxpool.Pool) *ScheduledTaskRepository { + return &ScheduledTaskRepository{q: queries.New(pool)} +} + +// Create creates a new scheduled task. +func (r *ScheduledTaskRepository) Create(ctx context.Context, task *types.ScheduledTask) (*types.ScheduledTask, error) { + convID := pgtype.UUID{Valid: false} + if task.ConversationID != nil { + convID = uuidToPgtype(*task.ConversationID) + } + result, err := r.q.CreateScheduledTask(ctx, &queries.CreateScheduledTaskParams{ + PublicKey: task.PublicKey, + ConversationID: convID, + Intent: task.Intent, + Context: []byte(task.Context), + NextRunAt: timeToPgtimestamptz(task.NextRunAt), + IntervalSeconds: int32PtrToPgint4(task.IntervalSeconds), + MaxRuns: int32PtrToPgint4(task.MaxRuns), + }) + if err != nil { + return nil, fmt.Errorf("create scheduled task: %w", err) + } + return scheduledTaskFromDB(result), nil +} + +// GetByID returns a scheduled task by ID with ownership check. +func (r *ScheduledTaskRepository) GetByID(ctx context.Context, id uuid.UUID, publicKey string) (*types.ScheduledTask, error) { + task, err := r.q.GetScheduledTaskByID(ctx, &queries.GetScheduledTaskByIDParams{ + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get scheduled task: %w", err) + } + return scheduledTaskFromDB(task), nil +} + +// ListActive returns all active scheduled tasks for a user. +func (r *ScheduledTaskRepository) ListActive(ctx context.Context, publicKey string) ([]types.ScheduledTask, error) { + tasks, err := r.q.ListActiveScheduledTasks(ctx, publicKey) + if err != nil { + return nil, fmt.Errorf("list active scheduled tasks: %w", err) + } + return scheduledTasksFromDB(tasks), nil +} + +// CountActive returns the count of active tasks for a user. +func (r *ScheduledTaskRepository) CountActive(ctx context.Context, publicKey string) (int64, error) { + count, err := r.q.CountActiveTasksByPublicKey(ctx, publicKey) + if err != nil { + return 0, fmt.Errorf("count active tasks: %w", err) + } + return count, nil +} + +// Update updates a scheduled task with partial fields. +func (r *ScheduledTaskRepository) Update(ctx context.Context, id uuid.UUID, publicKey string, params *UpdateScheduledTaskParams) (*types.ScheduledTask, error) { + result, err := r.q.UpdateScheduledTask(ctx, &queries.UpdateScheduledTaskParams{ + ID: uuidToPgtype(id), + PublicKey: publicKey, + Intent: stringPtrToPgtext(params.Intent), + Context: params.Context, + NextRunAt: timePtrToPgtimestamptz(params.NextRunAt), + IntervalSeconds: int32PtrToPgint4(params.IntervalSeconds), + MaxRuns: int32PtrToPgint4(params.MaxRuns), + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("update scheduled task: %w", err) + } + return scheduledTaskFromDB(result), nil +} + +// Cancel cancels a scheduled task. +func (r *ScheduledTaskRepository) Cancel(ctx context.Context, id uuid.UUID, publicKey string) error { + rowsAffected, err := r.q.CancelScheduledTask(ctx, &queries.CancelScheduledTaskParams{ + ID: uuidToPgtype(id), + PublicKey: publicKey, + }) + if err != nil { + return fmt.Errorf("cancel scheduled task: %w", err) + } + if rowsAffected == 0 { + return ErrNotFound + } + return nil +} + +// ClaimDueTasks atomically claims all due tasks for execution. +func (r *ScheduledTaskRepository) ClaimDueTasks(ctx context.Context) ([]types.ScheduledTask, error) { + tasks, err := r.q.ClaimDueTasks(ctx) + if err != nil { + return nil, fmt.Errorf("claim due tasks: %w", err) + } + return scheduledTasksFromDB(tasks), nil +} + +// AdvanceTask increments run count and updates next_run_at and status. +func (r *ScheduledTaskRepository) AdvanceTask(ctx context.Context, id uuid.UUID, nextRunAt pgtype.Timestamptz, status types.TaskStatus) error { + err := r.q.AdvanceTask(ctx, &queries.AdvanceTaskParams{ + ID: uuidToPgtype(id), + NextRunAt: nextRunAt, + Status: queries.AgentTaskStatus(status), + }) + if err != nil { + return fmt.Errorf("advance task: %w", err) + } + return nil +} + +// CreateTaskRun creates a new task run record. +func (r *ScheduledTaskRepository) CreateTaskRun(ctx context.Context, taskID uuid.UUID) (*types.TaskRun, error) { + run, err := r.q.CreateTaskRun(ctx, uuidToPgtype(taskID)) + if err != nil { + return nil, fmt.Errorf("create task run: %w", err) + } + return taskRunFromDB(run), nil +} + +// CompleteTaskRun completes a task run with result data. +func (r *ScheduledTaskRepository) CompleteTaskRun(ctx context.Context, id uuid.UUID, status types.TaskRunStatus, result []byte, runError *string, conversationID *uuid.UUID) error { + convID := pgtype.UUID{Valid: false} + if conversationID != nil { + convID = uuidToPgtype(*conversationID) + } + err := r.q.CompleteTaskRun(ctx, &queries.CompleteTaskRunParams{ + ID: uuidToPgtype(id), + Status: queries.AgentTaskRunStatus(status), + Result: result, + Error: stringPtrToPgtext(runError), + ConversationID: convID, + }) + if err != nil { + return fmt.Errorf("complete task run: %w", err) + } + return nil +} + +// UpdateScheduledTaskParams holds optional fields for partial updates. +type UpdateScheduledTaskParams struct { + Intent *string + Context []byte + NextRunAt *time.Time + IntervalSeconds *int32 + MaxRuns *int32 +} + +// Helper functions + +func timePtrToPgtimestamptz(t *time.Time) pgtype.Timestamptz { + if t == nil { + return pgtype.Timestamptz{Valid: false} + } + return pgtype.Timestamptz{Time: *t, Valid: true} +} diff --git a/internal/storage/postgres/schema/schema.sql b/internal/storage/postgres/schema/schema.sql index ece039b..dff7e80 100644 --- a/internal/storage/postgres/schema/schema.sql +++ b/internal/storage/postgres/schema/schema.sql @@ -9,6 +9,9 @@ CREATE TABLE agent_conversations ( title TEXT, summary TEXT, summary_up_to TIMESTAMPTZ, + ecdsa_public_key TEXT, + eddsa_public_key TEXT, + chaincode_hex TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), archived_at TIMESTAMPTZ @@ -34,3 +37,39 @@ CREATE TABLE agent_user_memories ( content TEXT NOT NULL DEFAULT '', updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); + +CREATE TYPE agent_task_status AS ENUM ('active', 'paused', 'completed', 'cancelled'); +CREATE TYPE agent_task_run_status AS ENUM ('running', 'success', 'failed'); + +CREATE TABLE agent_scheduled_tasks ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + public_key VARCHAR(66) NOT NULL, + conversation_id UUID REFERENCES agent_conversations(id) ON DELETE SET NULL, + intent TEXT NOT NULL, + context JSONB NOT NULL DEFAULT '{}', + next_run_at TIMESTAMPTZ NOT NULL, + interval_seconds INT, + max_runs INT, + run_count INT NOT NULL DEFAULT 0, + status agent_task_status NOT NULL DEFAULT 'active', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX idx_agent_scheduled_tasks_public_key ON agent_scheduled_tasks(public_key); +CREATE INDEX idx_agent_scheduled_tasks_due ON agent_scheduled_tasks(next_run_at) WHERE status = 'active'; + +CREATE TABLE agent_task_runs ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + task_id UUID NOT NULL REFERENCES agent_scheduled_tasks(id) ON DELETE CASCADE, + conversation_id UUID REFERENCES agent_conversations(id) ON DELETE SET NULL, + status agent_task_run_status NOT NULL DEFAULT 'running', + result JSONB, + error TEXT, + notified BOOLEAN NOT NULL DEFAULT FALSE, + started_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + finished_at TIMESTAMPTZ +); + +CREATE INDEX idx_agent_task_runs_task_id ON agent_task_runs(task_id); +CREATE INDEX idx_agent_task_runs_unnotified ON agent_task_runs(notified) WHERE notified = FALSE; diff --git a/internal/storage/postgres/sqlc/conversations.sql b/internal/storage/postgres/sqlc/conversations.sql index d6c3c9c..32b9473 100644 --- a/internal/storage/postgres/sqlc/conversations.sql +++ b/internal/storage/postgres/sqlc/conversations.sql @@ -37,3 +37,12 @@ WHERE id = $3 AND public_key = $4; -- name: GetConversationSummaryWithCursor :one SELECT summary, summary_up_to FROM agent_conversations WHERE id = $1 AND public_key = $2; + +-- name: UpdateVaultInfo :execrows +UPDATE agent_conversations +SET ecdsa_public_key = $1, eddsa_public_key = $2, chaincode_hex = $3, updated_at = NOW() +WHERE id = $4 AND public_key = $5 AND archived_at IS NULL; + +-- name: GetVaultInfo :one +SELECT ecdsa_public_key, eddsa_public_key, chaincode_hex FROM agent_conversations +WHERE id = $1 AND public_key = $2 AND archived_at IS NULL; diff --git a/internal/storage/postgres/sqlc/scheduled_tasks.sql b/internal/storage/postgres/sqlc/scheduled_tasks.sql new file mode 100644 index 0000000..e5adec1 --- /dev/null +++ b/internal/storage/postgres/sqlc/scheduled_tasks.sql @@ -0,0 +1,61 @@ +-- name: CreateScheduledTask :one +INSERT INTO agent_scheduled_tasks ( + public_key, conversation_id, intent, context, + next_run_at, interval_seconds, max_runs, status +) VALUES ( + $1, $2, $3, $4, $5, $6, $7, 'active' +) RETURNING *; + +-- name: GetScheduledTaskByID :one +SELECT * FROM agent_scheduled_tasks +WHERE id = $1 AND public_key = $2; + +-- name: ListActiveScheduledTasks :many +SELECT * FROM agent_scheduled_tasks +WHERE public_key = $1 AND status = 'active' +ORDER BY created_at DESC; + +-- name: CountActiveTasksByPublicKey :one +SELECT COUNT(*) FROM agent_scheduled_tasks +WHERE public_key = $1 AND status = 'active'; + +-- name: UpdateScheduledTask :one +UPDATE agent_scheduled_tasks SET + intent = COALESCE(sqlc.narg('intent'), intent), + context = COALESCE(sqlc.narg('context'), context), + next_run_at = COALESCE(sqlc.narg('next_run_at'), next_run_at), + interval_seconds = COALESCE(sqlc.narg('interval_seconds'), interval_seconds), + max_runs = COALESCE(sqlc.narg('max_runs'), max_runs), + updated_at = NOW() +WHERE id = @id AND public_key = @public_key AND status = 'active' +RETURNING *; + +-- name: CancelScheduledTask :execrows +UPDATE agent_scheduled_tasks +SET status = 'cancelled', updated_at = NOW() +WHERE id = $1 AND public_key = $2 AND status IN ('active', 'paused'); + +-- name: ClaimDueTasks :many +UPDATE agent_scheduled_tasks +SET status = 'paused', updated_at = NOW() +WHERE status = 'active' AND next_run_at <= NOW() +RETURNING *; + +-- name: AdvanceTask :exec +UPDATE agent_scheduled_tasks +SET run_count = run_count + 1, + next_run_at = $2, + status = $3, + updated_at = NOW() +WHERE id = $1; + +-- name: CreateTaskRun :one +INSERT INTO agent_task_runs (task_id, status) +VALUES ($1, 'running') +RETURNING *; + +-- name: CompleteTaskRun :exec +UPDATE agent_task_runs +SET status = $2, result = $3, error = $4, + conversation_id = $5, finished_at = NOW() +WHERE id = $1; diff --git a/internal/types/conversation.go b/internal/types/conversation.go index 08306dd..cfc62b7 100644 --- a/internal/types/conversation.go +++ b/internal/types/conversation.go @@ -23,11 +23,19 @@ type Conversation struct { Title *string `json:"title"` Summary *string `json:"summary,omitempty"` SummaryUpTo *time.Time `json:"summary_up_to,omitempty"` + VaultInfo *VaultInfo `json:"vault_info,omitempty"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` ArchivedAt *time.Time `json:"archived_at,omitempty"` } +// VaultInfo holds the cryptographic keys for a vault bound to a conversation. +type VaultInfo struct { + ECDSAPublicKey string `json:"ecdsa_public_key"` + EDDSAPublicKey string `json:"eddsa_public_key"` + ChaincodeHex string `json:"chaincode_hex"` +} + // Message represents a single message in a conversation. type Message struct { ID uuid.UUID `json:"id"` diff --git a/internal/types/scheduled_task.go b/internal/types/scheduled_task.go new file mode 100644 index 0000000..bdc7ce5 --- /dev/null +++ b/internal/types/scheduled_task.go @@ -0,0 +1,57 @@ +package types + +import ( + "encoding/json" + "time" + + "github.com/google/uuid" +) + +type TaskStatus string + +const ( + TaskStatusActive TaskStatus = "active" + TaskStatusPaused TaskStatus = "paused" + TaskStatusCompleted TaskStatus = "completed" + TaskStatusCancelled TaskStatus = "cancelled" +) + +type TaskRunStatus string + +const ( + TaskRunRunning TaskRunStatus = "running" + TaskRunSuccess TaskRunStatus = "success" + TaskRunFailed TaskRunStatus = "failed" +) + +type ScheduledTask struct { + ID uuid.UUID `json:"id"` + PublicKey string `json:"public_key"` + ConversationID *uuid.UUID `json:"conversation_id,omitempty"` + Intent string `json:"intent"` + Context json.RawMessage `json:"context"` + NextRunAt time.Time `json:"next_run_at"` + IntervalSeconds *int32 `json:"interval_seconds,omitempty"` + MaxRuns *int32 `json:"max_runs,omitempty"` + RunCount int32 `json:"run_count"` + Status TaskStatus `json:"status"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// IsRecurring returns true if the task has an interval (recurring schedule). +func (t *ScheduledTask) IsRecurring() bool { + return t.IntervalSeconds != nil && *t.IntervalSeconds > 0 +} + +type TaskRun struct { + ID uuid.UUID `json:"id"` + TaskID uuid.UUID `json:"task_id"` + ConversationID *uuid.UUID `json:"conversation_id,omitempty"` + Status TaskRunStatus `json:"status"` + Result json.RawMessage `json:"result,omitempty"` + Error *string `json:"error,omitempty"` + Notified bool `json:"notified"` + StartedAt time.Time `json:"started_at"` + FinishedAt *time.Time `json:"finished_at,omitempty"` +}