diff --git a/cmd/server/main.go b/cmd/server/main.go index d2b177a..5d7aec8 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -19,7 +19,6 @@ import ( "github.com/vultisig/agent-backend/internal/config" "github.com/vultisig/agent-backend/internal/mcp" "github.com/vultisig/agent-backend/internal/service/agent" - mcpclient "github.com/vultisig/agent-backend/internal/service/mcp" "github.com/vultisig/agent-backend/internal/service/plugin" "github.com/vultisig/agent-backend/internal/service/verifier" "github.com/vultisig/agent-backend/internal/storage/postgres" @@ -79,21 +78,35 @@ func main() { cacheTTL := time.Duration(cfg.MCP.ToolCacheTTLSec) * time.Second mcpClient := mcp.NewClient(cfg.MCP.ServerURL, cacheTTL, logger) + const mcpMaxRetries = 10 + const mcpRetryInterval = 3 * time.Second + for attempt := 1; attempt <= mcpMaxRetries; attempt++ { + mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second) + err := mcpClient.Initialize(mcpCtx) + mcpCancel() + if err == nil { + break + } + logger.WithError(err).WithField("attempt", attempt).Warn("failed to initialize mcp client") + if attempt == mcpMaxRetries { + logger.Warn("mcp client init exhausted retries, continuing without mcp tools") + } else { + time.Sleep(mcpRetryInterval) + } + } + mcpCtx, mcpCancel := context.WithTimeout(ctx, 10*time.Second) defer mcpCancel() - if err := mcpClient.Initialize(mcpCtx); err != nil { - logger.WithError(err).Warn("failed to initialize mcp client, continuing without mcp tools") + tools, err := mcpClient.ListTools(mcpCtx) + if err != nil { + logger.WithError(err).Warn("failed to list mcp tools, continuing without mcp tools") } else { - tools, err := mcpClient.ListTools(mcpCtx) - if err != nil { - logger.WithError(err).Warn("failed to list mcp tools, continuing without mcp tools") - } else { - logger.WithField("tool_count", len(tools)).Info("mcp tools loaded") - mcpProvider = mcpClient - } + logger.WithField("tool_count", len(tools)).Info("mcp tools loaded") + mcpProvider = mcpClient + } - // Pre-warm skill cache (non-fatal) + if mcpProvider != nil { skills, err := mcpClient.ListSkills(mcpCtx) if err != nil { logger.WithError(err).Warn("failed to list mcp skills, continuing without skills") @@ -103,16 +116,11 @@ func main() { } } - // Initialize MCP swap tx builder (optional) - var swapTxBuilder agent.SwapTxBuilder - if cfg.MCP.URL != "" { - mcpCl := mcpclient.NewClient(cfg.MCP.URL) - swapTxBuilder = mcpclient.NewSwapTxAdapter(mcpCl) - logger.WithField("url", cfg.MCP.URL).Info("MCP swap tx builder enabled") - } + // Initialize action router for intercepting actions (e.g. search_token → MCP) + actionRouter := agent.ParseActionRoutes(cfg.ActionRoutes, mcpProvider, logger) // Initialize agent service - agentService := agent.NewAgentService(aiClient, msgRepo, convRepo, memRepo, redisClient, verifierClient, pluginService, mcpProvider, swapTxBuilder, logger, cfg.AI.SummaryModel, cfg.Context) + agentService := agent.NewAgentService(aiClient, msgRepo, convRepo, memRepo, redisClient, verifierClient, pluginService, mcpProvider, actionRouter, logger, cfg.AI.SummaryModel, cfg.Context) // Initialize API server server := api.NewServer( @@ -174,7 +182,6 @@ func main() { agentGroup.POST("/conversations/:id", server.GetConversation) agentGroup.DELETE("/conversations/:id", server.DeleteConversation) agentGroup.POST("/conversations/:id/messages", server.SendMessage, api.RateLimitMiddleware(redisClient)) - agentGroup.POST("/conversations/:id/tx/build", server.BuildTx) agentGroup.POST("/starters", server.GetStarters) // Start server diff --git a/internal/api/swap.go b/internal/api/swap.go deleted file mode 100644 index 330049c..0000000 --- a/internal/api/swap.go +++ /dev/null @@ -1,37 +0,0 @@ -package api - -import ( - "net/http" - - "github.com/google/uuid" - "github.com/labstack/echo/v4" - - "github.com/vultisig/agent-backend/internal/service/agent" -) - -func (s *Server) BuildTx(c echo.Context) error { - idStr := c.Param("id") - convID, err := uuid.Parse(idStr) - if err != nil { - return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid conversation id"}) - } - - var req agent.BuildTxRequest - err = c.Bind(&req) - if err != nil { - return c.JSON(http.StatusBadRequest, ErrorResponse{Error: "invalid request body"}) - } - - authPublicKey := GetPublicKey(c) - if req.PublicKey != authPublicKey { - return c.JSON(http.StatusForbidden, ErrorResponse{Error: "public key mismatch"}) - } - - resp, err := s.agentService.BuildTx(c.Request().Context(), convID, &req) - if err != nil { - s.logger.WithError(err).Error("failed to build tx") - return c.JSON(http.StatusInternalServerError, ErrorResponse{Error: "failed to build tx"}) - } - - return c.JSON(http.StatusOK, resp) -} diff --git a/internal/config/config.go b/internal/config/config.go index 4c231e6..77e3cad 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,15 +6,16 @@ import ( // Config holds all configuration for the agent-backend service. type Config struct { - LogFormat string `envconfig:"LOG_FORMAT" default:"json"` - Server ServerConfig - Database DatabaseConfig - Redis RedisConfig - AuthCache AuthCacheConfig - AI AIConfig - Context ContextConfig - Verifier VerifierConfig - MCP MCPConfig + LogFormat string `envconfig:"LOG_FORMAT" default:"json"` + ActionRoutes string `envconfig:"ACTION_ROUTES" default:""` + Server ServerConfig + Database DatabaseConfig + Redis RedisConfig + AuthCache AuthCacheConfig + AI AIConfig + Context ContextConfig + Verifier VerifierConfig + MCP MCPConfig } // ServerConfig holds HTTP server configuration. diff --git a/internal/service/agent/agent.go b/internal/service/agent/agent.go index 18ffcd2..3ff0ece 100644 --- a/internal/service/agent/agent.go +++ b/internal/service/agent/agent.go @@ -39,34 +39,6 @@ type MCPToolProvider interface { ReadSkill(ctx context.Context, slug string) (string, error) } -type SwapTxBuilder interface { - BuildSwapTx(ctx context.Context, req SwapTxBuildRequest) (*SwapTxBuildResponse, error) -} - -type SwapTxBuildRequest struct { - FromChain string `json:"from_chain"` - FromSymbol string `json:"from_symbol"` - FromAddress string `json:"from_address,omitempty"` - FromDecimals *int `json:"from_decimals,omitempty"` - ToChain string `json:"to_chain"` - ToSymbol string `json:"to_symbol"` - ToAddress string `json:"to_address,omitempty"` - ToDecimals *int `json:"to_decimals,omitempty"` - Amount string `json:"amount"` - Sender string `json:"sender"` - Destination string `json:"destination"` -} - -type SwapTxBuildResponse struct { - Provider string `json:"provider"` - ExpectedOutput string `json:"expected_output"` - MinimumOutput string `json:"minimum_output"` - NeedsApproval bool `json:"needs_approval"` - ApprovalTx json.RawMessage `json:"approval_tx,omitempty"` - SwapTx json.RawMessage `json:"swap_tx"` - Memo string `json:"memo,omitempty"` -} - type AgentService struct { ai *ai.Client msgRepo *postgres.MessageRepository @@ -76,7 +48,7 @@ type AgentService struct { verifier *verifier.Client pluginProvider PluginSkillsProvider mcpProvider MCPToolProvider - swapTxBuilder SwapTxBuilder + router *ActionRouter logger *logrus.Logger summaryModel string windowSize int @@ -99,7 +71,7 @@ func NewAgentService( verifierClient *verifier.Client, pluginProvider PluginSkillsProvider, mcpProvider MCPToolProvider, - swapTxBuilder SwapTxBuilder, + router *ActionRouter, logger *logrus.Logger, summaryModel string, ctxCfg config.ContextConfig, @@ -113,7 +85,7 @@ func NewAgentService( verifier: verifierClient, pluginProvider: pluginProvider, mcpProvider: mcpProvider, - swapTxBuilder: swapTxBuilder, + router: router, logger: logger, summaryModel: summaryModel, windowSize: ctxCfg.WindowSize, @@ -162,8 +134,14 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub } fullCtx := s.resolveContext(ctx, convID, req.Context) + resolvedTools := s.resolveTools(ctx, convID, req.Tools) + + var actionsTable string + if len(resolvedTools) > 0 { + actionsTable = BuildActionsTable(resolvedTools) + } - basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx), actionsTable) if conv.VaultInfo != nil { basePrompt += BuildVaultInfoSection(conv.VaultInfo) } @@ -189,7 +167,7 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub Content: userContent, }) - tools := agentTools() + tools := s.buildAgentTools(resolvedTools) tools = append(tools, s.memoryTools()...) if s.mcpProvider != nil { mcpTools := s.mcpProvider.GetTools(ctx) @@ -207,7 +185,6 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub } tools = append(tools, mcpTools...) - // Add get_skill tool if skills are available if s.mcpProvider.SkillSummary(ctx) != "" { tools = append(tools, GetSkillTool) } @@ -304,7 +281,9 @@ func (s *AgentService) ProcessMessage(ctx context.Context, convID uuid.UUID, pub if err != nil { return nil, err } - resp.Tokens = tokens + if resp.Tokens == nil { + resp.Tokens = tokens + } return resp, nil } if textContent != "" { @@ -364,8 +343,14 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI } fullCtx := s.resolveContext(ctx, convID, req.Context) + resolvedTools := s.resolveTools(ctx, convID, req.Tools) - basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx)) + var actionsTable string + if len(resolvedTools) > 0 { + actionsTable = BuildActionsTable(resolvedTools) + } + + basePrompt := BuildFullPrompt(fullCtx, s.getPluginSkills(ctx), actionsTable) if conv.VaultInfo != nil { basePrompt += BuildVaultInfoSection(conv.VaultInfo) } @@ -391,7 +376,7 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI Content: userContent, }) - tools := agentTools() + tools := s.buildAgentTools(resolvedTools) tools = append(tools, s.memoryTools()...) if s.mcpProvider != nil { mcpTools := s.mcpProvider.GetTools(ctx) @@ -463,6 +448,20 @@ func (s *AgentService) ProcessMessageStream(ctx context.Context, convID uuid.UUI continue } + if s.isMCPTool(tc.Function.Name) { + eventCh <- SSEEvent{ + Event: "actions", + Data: ActionsPayload{Actions: []Action{{ + ID: "mcp_" + tc.ID, + Type: "mcp_status", + Title: tc.Function.Name, + Params: map[string]any{ + "original_action": tc.Function.Name, + }, + }}}, + } + } + result, err := s.executeTool(ctx, convID, tc.Function.Name, json.RawMessage(tc.Function.Arguments), req) if err != nil { result = jsonError(err.Error()) @@ -537,51 +536,8 @@ func (s *AgentService) resolveContentType(req *SendMessageRequest) string { } func (s *AgentService) buildLoopResponse(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, toolResp *ToolResponse, window *conversationWindow) (*SendMessageResponse, error) { - var suggestions []Suggestion - for _, ts := range toolResp.Suggestions { - suggID := "sug_" + uuid.New().String() - sugg := Suggestion{ - ID: suggID, - PluginID: ts.PluginID, - Title: ts.Title, - Description: ts.Description, - } - suggestions = append(suggestions, sugg) - - suggJSON, err := json.Marshal(sugg) - if err != nil { - s.logger.WithError(err).Warn("failed to marshal suggestion") - continue - } - if err := s.redis.Set(ctx, suggID, string(suggJSON), suggestionTTL); err != nil { - s.logger.WithError(err).Warn("failed to store suggestion in redis") - } - } - - var actions []Action - for _, ta := range toolResp.Actions { - actID := "act_" + uuid.New().String() - act := Action{ - ID: actID, - Type: ta.Type, - Title: ta.Title, - Description: ta.Description, - Params: ta.Params, - AutoExecute: ta.AutoExecute, - } - actions = append(actions, act) - - actJSON, err := json.Marshal(act) - if err != nil { - s.logger.WithError(err).Warn("failed to marshal action") - continue - } - if err := s.redis.Set(ctx, actID, string(actJSON), actionTTL); err != nil { - s.logger.WithError(err).Warn("failed to store action in redis") - } - } - - txReady := s.interceptBuildSwapTx(ctx, convID, req, toolResp) + suggestions := s.processSuggestions(ctx, toolResp.Suggestions) + processed := s.processActions(ctx, toolResp, req.Context) metadataMap := map[string]any{ "intent": toolResp.Intent, @@ -589,8 +545,8 @@ func (s *AgentService) buildLoopResponse(ctx context.Context, convID uuid.UUID, if len(suggestions) > 0 { metadataMap["suggestions"] = suggestions } - if len(actions) > 0 { - metadataMap["actions"] = actions + if len(processed.actions) > 0 { + metadataMap["actions"] = processed.actions } metadata, _ := json.Marshal(metadataMap) @@ -618,15 +574,15 @@ func (s *AgentService) buildLoopResponse(ctx context.Context, convID uuid.UUID, Message: *assistantMsg, Title: titlePtr, Suggestions: suggestions, - Actions: actions, - TxReady: txReady, + Actions: processed.actions, + Tokens: processed.tokens, } if req.ActionResult != nil && req.ActionResult.Action == "install_plugin" && req.ActionResult.Success { s.autoContinueAfterInstall(ctx, convID, req, window, resp) } - s.logger.WithField("num_actions", len(actions)).Info("sending response to desktop") + s.logger.WithField("num_actions", len(processed.actions)).Info("sending response to desktop") return resp, nil } @@ -645,9 +601,78 @@ func (s *AgentService) buildTextResponse(ctx context.Context, convID uuid.UUID, }, nil } -func (s *AgentService) emitLoopResponse(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, toolResp *ToolResponse, window *conversationWindow, eventCh chan<- SSEEvent) { +type processedActions struct { + actions []Action + tokens *TokenSearchResult +} + +func (s *AgentService) processActions(ctx context.Context, toolResp *ToolResponse, msgCtx *MessageContext) processedActions { + var result processedActions + + hasBuildTx := false + for _, ta := range toolResp.Actions { + if ta.Type == "build_swap_tx" || ta.Type == "build_send_tx" || ta.Type == "build_custom_tx" { + hasBuildTx = true + break + } + } + + for _, ta := range toolResp.Actions { + if hasBuildTx && ta.Type == "sign_tx" { + s.logger.Warn("stripping sign_tx co-emitted with build action") + continue + } + if s.router != nil { + interceptResult, err := s.router.Route(ctx, ta, msgCtx) + if err != nil { + s.logger.WithError(err).WithField("action_type", ta.Type).Warn("action interceptor error, dropping action") + toolResp.Response += fmt.Sprintf("\n\n[Error processing %s: %s]", ta.Type, err.Error()) + continue + } + if interceptResult != nil { + if interceptResult.Tokens != nil { + result.tokens = interceptResult.Tokens + } + for _, replacement := range interceptResult.Actions { + act := s.storeAction(ctx, replacement) + result.actions = append(result.actions, act) + } + continue + } + } + + act := s.storeAction(ctx, ta) + result.actions = append(result.actions, act) + } + + return result +} + +func (s *AgentService) storeAction(ctx context.Context, ta ToolAction) Action { + actID := "act_" + uuid.New().String() + act := Action{ + ID: actID, + Type: ta.Type, + Title: ta.Title, + Description: ta.Description, + Params: ta.Params, + } + + actJSON, err := json.Marshal(act) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal action") + return act + } + err = s.redis.Set(ctx, actID, string(actJSON), actionTTL) + if err != nil { + s.logger.WithError(err).Warn("failed to store action in redis") + } + return act +} + +func (s *AgentService) processSuggestions(ctx context.Context, toolSuggestions []ToolSuggestion) []Suggestion { var suggestions []Suggestion - for _, ts := range toolResp.Suggestions { + for _, ts := range toolSuggestions { suggID := "sug_" + uuid.New().String() sugg := Suggestion{ ID: suggID, @@ -662,35 +687,17 @@ func (s *AgentService) emitLoopResponse(ctx context.Context, convID uuid.UUID, r s.logger.WithError(err).Warn("failed to marshal suggestion") continue } - if err := s.redis.Set(ctx, suggID, string(suggJSON), suggestionTTL); err != nil { - s.logger.WithError(err).Warn("failed to store suggestion in redis") - } - } - - var actions []Action - for _, ta := range toolResp.Actions { - actID := "act_" + uuid.New().String() - act := Action{ - ID: actID, - Type: ta.Type, - Title: ta.Title, - Description: ta.Description, - Params: ta.Params, - AutoExecute: ta.AutoExecute, - } - actions = append(actions, act) - - actJSON, err := json.Marshal(act) + err = s.redis.Set(ctx, suggID, string(suggJSON), suggestionTTL) if err != nil { - s.logger.WithError(err).Warn("failed to marshal action") - continue - } - if err := s.redis.Set(ctx, actID, string(actJSON), actionTTL); err != nil { - s.logger.WithError(err).Warn("failed to store action in redis") + s.logger.WithError(err).Warn("failed to store suggestion in redis") } } + return suggestions +} - txReady := s.interceptBuildSwapTx(ctx, convID, req, toolResp) +func (s *AgentService) emitLoopResponse(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, toolResp *ToolResponse, window *conversationWindow, eventCh chan<- SSEEvent) { + suggestions := s.processSuggestions(ctx, toolResp.Suggestions) + processed := s.processActions(ctx, toolResp, req.Context) metadataMap := map[string]any{ "intent": toolResp.Intent, @@ -698,8 +705,8 @@ func (s *AgentService) emitLoopResponse(ctx context.Context, convID uuid.UUID, r if len(suggestions) > 0 { metadataMap["suggestions"] = suggestions } - if len(actions) > 0 { - metadataMap["actions"] = actions + if len(processed.actions) > 0 { + metadataMap["actions"] = processed.actions } metadata, _ := json.Marshal(metadataMap) @@ -728,11 +735,11 @@ func (s *AgentService) emitLoopResponse(ctx context.Context, convID uuid.UUID, r if len(suggestions) > 0 { eventCh <- SSEEvent{Event: "suggestions", Data: SuggestionsPayload{Suggestions: suggestions}} } - if len(actions) > 0 { - eventCh <- SSEEvent{Event: "actions", Data: ActionsPayload{Actions: actions}} + if len(processed.actions) > 0 { + eventCh <- SSEEvent{Event: "actions", Data: ActionsPayload{Actions: processed.actions}} } - if txReady != nil { - eventCh <- SSEEvent{Event: "tx_ready", Data: txReady} + if processed.tokens != nil { + eventCh <- SSEEvent{Event: "tokens", Data: processed.tokens} } eventCh <- SSEEvent{Event: "message", Data: MessagePayload{Message: *assistantMsg}} @@ -846,163 +853,6 @@ func (s *AgentService) getPluginSkills(ctx context.Context) []PluginSkill { return s.pluginProvider.GetSkills(ctx) } -func (s *AgentService) BuildTx(ctx context.Context, convID uuid.UUID, req *BuildTxRequest) (*BuildTxResponse, error) { - if s.swapTxBuilder == nil { - return nil, fmt.Errorf("swap builder not configured") - } - - fullCtx := s.resolveContext(ctx, convID, req.Context) - - sendReq := &SendMessageRequest{ - PublicKey: req.PublicKey, - Context: fullCtx, - } - toolResp := &ToolResponse{ - Actions: []ToolAction{ - { - Type: "build_tx", - Params: req.Params, - }, - }, - } - - txReady := s.interceptBuildSwapTx(ctx, convID, sendReq, toolResp) - - resp := &BuildTxResponse{ - TxReady: txReady, - } - - for _, ta := range toolResp.Actions { - resp.Actions = append(resp.Actions, Action{ - ID: "act_" + uuid.New().String(), - Type: ta.Type, - Title: ta.Title, - Description: ta.Description, - Params: ta.Params, - AutoExecute: ta.AutoExecute, - }) - } - - if txReady == nil { - msg := strings.TrimSpace(toolResp.Response) - if len(resp.Actions) > 0 { - resp.Message = msg - } else { - resp.Error = msg - } - } - - return resp, nil -} - -func (s *AgentService) interceptBuildSwapTx(ctx context.Context, convID uuid.UUID, req *SendMessageRequest, toolResp *ToolResponse) *TxReady { - for i, action := range toolResp.Actions { - if action.Type != "build_tx" { - continue - } - - fromChain := getStringParam(action.Params, "from_chain") - fromSymbol := getStringParam(action.Params, "from_symbol") - toChain := getStringParam(action.Params, "to_chain") - toSymbol := getStringParam(action.Params, "to_symbol") - amount := getStringParam(action.Params, "amount") - - if fromChain == "" || fromSymbol == "" || toChain == "" || toSymbol == "" || amount == "" { - s.logger.Warn("build_tx missing required params") - toolResp.Actions = append(toolResp.Actions[:i], toolResp.Actions[i+1:]...) - toolResp.Response += "\n\nCould not build swap transaction: missing required parameters." - continue - } - - var sender, destination string - if req.Context != nil && req.Context.Addresses != nil { - sender = findAddress(req.Context.Addresses, fromChain) - destination = findAddress(req.Context.Addresses, toChain) - } - if sender == "" || destination == "" { - s.logger.WithField("from_chain", fromChain).WithField("to_chain", toChain). - WithField("addresses", req.Context.Addresses). - Warn("build_tx: missing sender/destination addresses") - toolResp.Actions = append(toolResp.Actions[:i], toolResp.Actions[i+1:]...) - toolResp.Response += "\n\nCould not build swap transaction: missing wallet addresses." - continue - } - - _, fromDecimals := resolveTokenFromContext(req.Context, fromChain, fromSymbol) - _, toDecimals := resolveTokenFromContext(req.Context, toChain, toSymbol) - - buildReq := SwapTxBuildRequest{ - FromChain: fromChain, - FromSymbol: fromSymbol, - FromDecimals: fromDecimals, - ToChain: toChain, - ToSymbol: toSymbol, - ToDecimals: toDecimals, - Amount: amount, - Sender: sender, - Destination: destination, - } - - buildResp, err := s.swapTxBuilder.BuildSwapTx(ctx, buildReq) - if err != nil { - s.logger.WithError(err).Warn("build_tx failed") - toolResp.Actions = append(toolResp.Actions[:i], toolResp.Actions[i+1:]...) - toolResp.Response += "\n\nFailed to build swap transaction: " + err.Error() - continue - } - - toolResp.Actions[i].Type = "sign_swap_tx" - - var txFields struct { - To string `json:"to"` - Value string `json:"value"` - Data string `json:"data"` - } - unmarshalErr := json.Unmarshal(buildResp.SwapTx, &txFields) - if unmarshalErr == nil && txFields.Data != "" { - toolResp.Actions = append(toolResp.Actions, ToolAction{ - Type: "scan_tx", - AutoExecute: true, - Params: map[string]any{ - "chain": fromChain, - "from": sender, - "to": txFields.To, - "value": txFields.Value, - "data": txFields.Data, - }, - }) - } - - fromDec := 18 - if fromDecimals != nil { - fromDec = *fromDecimals - } - toDec := 18 - if toDecimals != nil { - toDec = *toDecimals - } - - return &TxReady{ - Provider: buildResp.Provider, - ExpectedOutput: buildResp.ExpectedOutput, - MinimumOutput: buildResp.MinimumOutput, - NeedsApproval: buildResp.NeedsApproval, - ApprovalTx: buildResp.ApprovalTx, - SwapTx: buildResp.SwapTx, - FromChain: fromChain, - FromSymbol: fromSymbol, - FromDecimals: fromDec, - ToChain: toChain, - ToSymbol: toSymbol, - ToDecimals: toDec, - Amount: amount, - Sender: sender, - Destination: destination, - } - } - - return nil -} func truncateTitle(title string, maxLen int) string { if len(title) <= maxLen { @@ -1011,42 +861,6 @@ func truncateTitle(title string, maxLen int) string { return title[:maxLen-3] + "..." } -func resolveTokenFromContext(msgCtx *MessageContext, chain, symbol string) (string, *int) { - if msgCtx == nil { - return "", nil - } - for _, coin := range msgCtx.Coins { - if strings.EqualFold(coin.Chain, chain) && strings.EqualFold(coin.Ticker, symbol) { - dec := coin.Decimals - return coin.ContractAddress, &dec - } - } - return "", nil -} - -func findAddress(addresses map[string]string, chain string) string { - if addr := addresses[chain]; addr != "" { - return addr - } - for k, v := range addresses { - if strings.EqualFold(k, chain) { - return v - } - } - return "" -} - -func getStringParam(params map[string]any, key string) string { - v, ok := params[key] - if !ok || v == nil { - return "" - } - str, ok := v.(string) - if !ok { - return fmt.Sprintf("%v", v) - } - return str -} func (s *AgentService) getConversationWindow(ctx context.Context, convID uuid.UUID, publicKey string) (*conversationWindow, error) { summary, cursor, err := s.convRepo.GetSummaryWithCursor(ctx, convID, publicKey) diff --git a/internal/service/agent/context.go b/internal/service/agent/context.go index c7d127b..f867c04 100644 --- a/internal/service/agent/context.go +++ b/internal/service/agent/context.go @@ -68,6 +68,52 @@ func (s *AgentService) resolveContext(ctx context.Context, convID uuid.UUID, msg return mergeContext(cached, msgCtx) } +const toolContextTTL = 24 * time.Hour + +func toolContextKey(convID uuid.UUID) string { + return fmt.Sprintf("tool_ctx:%s", convID) +} + +func (s *AgentService) cacheToolContext(ctx context.Context, convID uuid.UUID, tools []AppTool) { + if len(tools) == 0 { + return + } + + data, err := json.Marshal(tools) + if err != nil { + s.logger.WithError(err).Warn("failed to marshal tool context for cache") + return + } + + err = s.redis.Set(ctx, toolContextKey(convID), string(data), toolContextTTL) + if err != nil { + s.logger.WithError(err).Warn("failed to cache tool context in redis") + } +} + +func (s *AgentService) loadCachedToolContext(ctx context.Context, convID uuid.UUID) []AppTool { + raw, err := s.redis.Get(ctx, toolContextKey(convID)) + if err != nil || raw == "" { + return nil + } + + var tools []AppTool + err = json.Unmarshal([]byte(raw), &tools) + if err != nil { + s.logger.WithError(err).Warn("failed to unmarshal cached tool context") + return nil + } + return tools +} + +func (s *AgentService) resolveTools(ctx context.Context, convID uuid.UUID, tools []AppTool) []AppTool { + if len(tools) > 0 { + s.cacheToolContext(ctx, convID, tools) + return tools + } + return s.loadCachedToolContext(ctx, convID) +} + func mergeContext(cached *cachedVaultContext, msg *MessageContext) *MessageContext { if msg == nil { msg = &MessageContext{} diff --git a/internal/service/agent/executor.go b/internal/service/agent/executor.go index 8bf2785..1104d4c 100644 --- a/internal/service/agent/executor.go +++ b/internal/service/agent/executor.go @@ -14,6 +14,18 @@ import ( const suggestionTTL = 1 * time.Hour +func (s *AgentService) isMCPTool(name string) bool { + if s.mcpProvider == nil { + return false + } + for _, mcpName := range s.mcpProvider.ToolNames() { + if mcpName == name { + return true + } + } + return false +} + // executeTool dispatches a tool call to the appropriate handler. // Returns a JSON string result for Claude. Errors are returned as JSON {"error": "..."} so the LLM can communicate them naturally. func (s *AgentService) executeTool(ctx context.Context, convID uuid.UUID, name string, input json.RawMessage, req *SendMessageRequest) (string, error) { diff --git a/internal/service/agent/interceptor.go b/internal/service/agent/interceptor.go new file mode 100644 index 0000000..d99cf84 --- /dev/null +++ b/internal/service/agent/interceptor.go @@ -0,0 +1,155 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/sirupsen/logrus" +) + +type ActionInterceptor interface { + Intercept(ctx context.Context, action ToolAction, msgCtx *MessageContext) (*InterceptResult, error) +} + +type InterceptResult struct { + Actions []ToolAction + Tokens *TokenSearchResult +} + +type ActionRouter struct { + interceptors map[string]ActionInterceptor + logger *logrus.Logger +} + +func NewActionRouter(interceptors map[string]ActionInterceptor, logger *logrus.Logger) *ActionRouter { + return &ActionRouter{ + interceptors: interceptors, + logger: logger, + } +} + +func (r *ActionRouter) Route(ctx context.Context, action ToolAction, msgCtx *MessageContext) (*InterceptResult, error) { + interceptor, ok := r.interceptors[action.Type] + if !ok { + return nil, nil + } + + r.logger.WithField("action_type", action.Type).Debug("intercepting action") + return interceptor.Intercept(ctx, action, msgCtx) +} + +var mcpToolMapping = map[string]string{ + "search_token": "search_token", +} + +type MCPActionInterceptor struct { + mcpProvider MCPToolProvider + logger *logrus.Logger +} + +func NewMCPActionInterceptor(mcpProvider MCPToolProvider, logger *logrus.Logger) *MCPActionInterceptor { + return &MCPActionInterceptor{ + mcpProvider: mcpProvider, + logger: logger, + } +} + +func (m *MCPActionInterceptor) Intercept(ctx context.Context, action ToolAction, msgCtx *MessageContext) (*InterceptResult, error) { + route, ok := mcpToolMapping[action.Type] + if !ok { + return nil, fmt.Errorf("no MCP tool mapping for action type: %s", action.Type) + } + + params, err := json.Marshal(action.Params) + if err != nil { + return nil, fmt.Errorf("marshal action params: %w", err) + } + + m.logger.WithFields(logrus.Fields{ + "action_type": action.Type, + "mcp_tool": route, + }).Debug("calling MCP tool for intercepted action") + + result, err := m.mcpProvider.CallTool(ctx, route, params) + if err != nil { + return nil, fmt.Errorf("MCP tool %s failed: %w", route, err) + } + + statusAction := ToolAction{ + Type: "mcp_status", + Title: fmt.Sprintf("Processing %s via server", action.Type), + Params: map[string]any{ + "original_action": action.Type, + "mcp_tool": route, + "status": "completed", + }, + } + + switch action.Type { + case "search_token": + tokens := extractTokens(result) + if tokens == nil { + return nil, fmt.Errorf("failed to parse token results from MCP tool %s", route) + } + m.logger.WithField("token_count", len(tokens.Tokens)).Info("tokens extracted via action interceptor") + return &InterceptResult{ + Actions: []ToolAction{statusAction}, + Tokens: tokens, + }, nil + default: + return nil, fmt.Errorf("no result parser for action type: %s", action.Type) + } +} + +func ParseActionRoutes(raw string, mcpProvider MCPToolProvider, logger *logrus.Logger) *ActionRouter { + interceptors := make(map[string]ActionInterceptor) + + if raw == "" { + return NewActionRouter(interceptors, logger) + } + + var mcpInterceptor *MCPActionInterceptor + for _, pair := range strings.Split(raw, ",") { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + parts := strings.SplitN(pair, ":", 2) + if len(parts) != 2 { + logger.WithField("entry", pair).Warn("invalid ACTION_ROUTES entry, expected action_type:target") + continue + } + + actionType := strings.TrimSpace(parts[0]) + target := strings.TrimSpace(parts[1]) + + switch target { + case "mcp": + if mcpProvider == nil { + logger.WithField("action_type", actionType).Warn("cannot route to MCP: no MCP provider configured") + continue + } + if _, ok := mcpToolMapping[actionType]; !ok { + logger.WithField("action_type", actionType).Warn("no MCP tool mapping for action type") + continue + } + if mcpInterceptor == nil { + mcpInterceptor = NewMCPActionInterceptor(mcpProvider, logger) + } + interceptors[actionType] = mcpInterceptor + logger.WithFields(logrus.Fields{ + "action_type": actionType, + "mcp_tool": mcpToolMapping[actionType], + }).Info("action routed to MCP") + default: + logger.WithFields(logrus.Fields{ + "action_type": actionType, + "target": target, + }).Warn("unknown route target, ignoring") + } + } + + return NewActionRouter(interceptors, logger) +} diff --git a/internal/service/agent/prompt.go b/internal/service/agent/prompt.go index d15ca02..86f67ac 100644 --- a/internal/service/agent/prompt.go +++ b/internal/service/agent/prompt.go @@ -1,6 +1,7 @@ package agent import ( + "encoding/json" "fmt" "strings" @@ -10,32 +11,31 @@ import ( // ActionsTable is the shared actions reference used by both the system prompt and starters. // Adding a new action here automatically makes it available to conversation starters. -const ActionsTable = `| Type | Params | auto_execute | -|---|---|---| -| get_market_price | asset, fiat (default "usd") | true | -| get_balances | chain (optional) | true | -| get_portfolio | fiat (default "usd") | true | -| add_chain | chain | true | -| add_coin | chain, ticker, contract_address (optional) | true | -| search_token | query, chain (optional) | true | -| remove_coin | chain, ticker | true | -| remove_chain | chain | true | -| build_send_tx | chain, symbol, address, amount, memo (optional) | true | -| build_custom_tx | tx_type, chain, symbol, amount, memo, contract_address, function_name, params, value, execute_msg, funds | true | -| build_tx | from_chain, from_symbol, from_contract, from_decimals, to_chain, to_symbol, to_contract, to_decimals, amount | true | -| sign_tx | (no params needed — app fills from stored tx) | true | -| list_vaults | | true | -| plugin_install | plugin_id | true | -| create_policy | plugin_id, configuration | true | -| delete_policy | policy_id | true | -| address_book_add | name, chain, address | true | -| address_book_remove | name, chain | true | -| read_evm_contract | chain, contract_address, function_name, params, output_types | true | -| scan_tx | chain, from, to, value, data | true | -| thorchain_query | query_type, asset (optional) | true |` - -// SystemPrompt is the base system prompt for the Vultisig AI assistant. -var SystemPrompt = `You are the Vultisig AI assistant inside the Vultisig mobile wallet app. +const ActionsTable = `| Type | Params | +|---|---| +| get_market_price | asset, fiat (default "usd") | +| get_balances | chain (optional) | +| get_portfolio | fiat (default "usd") | +| add_chain | chain | +| add_coin | chain, ticker, contract_address (optional) | +| search_token | query, chain (optional) | +| remove_coin | chain, ticker | +| remove_chain | chain | +| build_swap_tx | from_chain, from_symbol, from_contract, from_decimals, to_chain, to_symbol, to_contract, to_decimals, amount | +| build_send_tx | chain, symbol, address, amount, memo (optional) | +| build_custom_tx | tx_type, chain, symbol, amount, memo, contract_address, function_name, params, value, execute_msg, funds | +| sign_tx | (no params needed — app fills from stored tx) | +| list_vaults | | +| plugin_install | plugin_id | +| create_policy | plugin_id, configuration | +| delete_policy | policy_id | +| address_book_add | name, chain, address | +| address_book_remove | name, chain | +| read_evm_contract | chain, contract_address, function_name, params, output_types | +| scan_tx | chain, from, to, value, data | +| thorchain_query | query_type, asset (optional) |` + +const systemPromptPrefix = `You are the Vultisig AI assistant inside the Vultisig mobile wallet app. Vultisig is a self-custodial, seedless crypto wallet using Threshold Signature Scheme (TSS) — no seed phrases, multi-device signing (e.g. 2-of-3), one vault across many chains. @@ -45,30 +45,32 @@ Supported chains: Ethereum, Arbitrum, Avalanche, BNB Chain, Base, Blast, Optimis Include actions when the user's request requires the app to do something. You can return MULTIPLE actions in a single response — the app executes them all in parallel. Batch related actions together (e.g. get_market_price + get_balances, or multiple add_coin calls). -` + ActionsTable + ` +` -auto_execute=true actions run immediately in parallel. auto_execute=false actions render as tappable cards for user confirmation. +const systemPromptAfterActions = ` + +All actions are executed immediately by the app. Actions requiring signing will prompt the user for their password. ## Swap Transaction Building -For swaps, ALWAYS use build_tx. This builds the actual unsigned transaction so the user can review exact output amounts and sign directly. +For swaps, ALWAYS use build_swap_tx. This builds the actual unsigned transaction so the user can review exact output amounts and sign directly. -CRITICAL RULES for build_tx params: +CRITICAL RULES for build_swap_tx params: - from_contract and from_decimals: copy EXACTLY from "Coins in Vault" context. The "contract" field is from_contract, the "decimals" field is from_decimals. For native tokens (contract = "native"), set from_contract to empty string "". -- to_contract and to_decimals: copy EXACTLY from "Coins in Vault" if the destination token is there. If not in vault, omit both fields (server resolves automatically). +- to_contract and to_decimals: copy EXACTLY from "Coins in Vault" if the destination token is there. If not in vault, omit both fields (app resolves automatically). - NEVER fabricate or guess contract addresses. Only use values copied verbatim from context. - If the source token is NOT in "Coins in Vault", the user doesn't have it — tell them. - If the destination token is NOT in "Coins in Vault", omit to_contract and to_decimals. -- amount is in HUMAN-READABLE units (e.g. "10" for 10 USDC, NOT base units). The server handles conversion. +- amount is in HUMAN-READABLE units (e.g. "10" for 10 USDC, NOT base units). The app handles conversion. - When the user specifies a fiat/dollar amount (e.g., "$10 of ETH", "100 USD worth of BTC"), do NOT put the fiat number in the amount field. Instead: - 1. First call get_market_price for the source token (auto_execute=true) to get the current price. + 1. First call get_market_price for the source token to get the current price. 2. After receiving the price result, calculate: token_amount = fiat_amount / price. - 3. Compare the calculated token_amount against the user's balance in Balances context. If insufficient, tell the user (e.g. "You only have 0.899 ETH (~$1,786), which isn't enough for a $10,000 swap.") and do NOT call build_tx. - 4. If sufficient, call build_tx with the calculated token amount in the amount field. + 3. Compare the calculated token_amount against the user's balance in Balances context. If insufficient, tell the user (e.g. "You only have 0.899 ETH (~$1,786), which isn't enough for a $10,000 swap.") and do NOT call build_swap_tx. + 4. If sufficient, call build_swap_tx with the calculated token amount in the amount field. -## Transaction Confirmation (build_tx result handling) +## Swap Transaction Confirmation (build_swap_tx result handling) -When you receive a successful build_tx action result, you MUST respond with EXACTLY this template and NOTHING else. Copy it verbatim, only replacing the bracketed placeholders. +When you receive a successful build_swap_tx action result, you MUST respond with EXACTLY this template and NOTHING else. Copy it verbatim, only replacing the bracketed placeholders. TEMPLATE: "Swap [amount] [FROM] for ~[expected_output] [TO] via [provider][CROSS_CHAIN]. [APPROVAL_LINE]Ready to swap?" @@ -88,8 +90,8 @@ ABSOLUTE RULES: - Do NOT use exclamation marks, bullet points, contract addresses, gas limits, or technical details. - Do NOT explain what will happen. Just state the swap and ask "Ready to swap?" -When the user confirms (yes, confirm, do it, go, etc.) → return sign_tx action with auto_execute=true and empty params. -When the user wants changes → adjust and call build_tx again. +When the user confirms (yes, confirm, do it, go, etc.) → return sign_tx action with empty params. +When the user wants changes → adjust and call build_swap_tx again. When the user cancels → acknowledge briefly, no sign_tx. ## Send Transaction Building @@ -123,7 +125,7 @@ ABSOLUTE RULES: - Do NOT add any words before or after the template. - Do NOT use exclamation marks, bullet points, or technical details. -When the user confirms → return sign_tx action with auto_execute=true and empty params. +When the user confirms → return sign_tx action with empty params. When the user cancels → acknowledge briefly, no sign_tx. ## Custom Transaction Building @@ -166,7 +168,7 @@ TEMPLATE for deposit: "[ACTION] [amount] [SYMBOL] on [chain]. Ready to execute?" TEMPLATE for evm_contract: "Call [function_name] on [truncated_contract] on [chain]. Ready to execute?" TEMPLATE for wasm_execute: "Execute contract [truncated_contract] on [chain]. Ready to execute?" -When the user confirms → return sign_tx action with auto_execute=true and empty params. +When the user confirms → return sign_tx action with empty params. When the user cancels → acknowledge briefly, no sign_tx. ## THORChain Position Queries @@ -218,12 +220,12 @@ Never auto-add tokens from search results — always let the user review and con ## Guidelines -- Be extremely concise — 1-2 short sentences max. Users are on mobile. No narration, no restating what the user said, no "I'll set up..." preamble. Just act. +- CONCISENESS IS CRITICAL. Maximum 5 sentences or 5 bullet points per response. Users are on mobile. No narration, no restating what the user said, no preamble. Just act. - When you can fulfill the request, do it immediately with minimal commentary. Only explain if something is wrong (insufficient balance, ambiguity). - Fix obvious typos silently (e.g. "USDDC" → USDC). Don't ask for confirmation on obvious corrections. -- ALWAYS check the user's Balances context before calling build_tx or build_send_tx. Use the balance for the specific chain (e.g. "USDT on Ethereum"), NOT the sum across all chains. If insufficient, tell the user and do NOT build. Warn if source balance is under ~$5 (DEX minimums for swaps). +- ALWAYS check the user's Balances context before calling build_swap_tx or build_send_tx. Use the balance for the specific chain (e.g. "USDT on Ethereum"), NOT the sum across all chains. If insufficient, tell the user and do NOT build. Warn if source balance is under ~$5 (DEX minimums for swaps). - When a token exists on multiple chains, auto-select the chain where the user has sufficient balance. Only ask which chain if MULTIPLE chains have enough balance for the requested amount. -- Use auto_execute=true actions for price/balance/portfolio queries. +- Use actions for price/balance/portfolio queries. - Use suggestions for recurring automation (DCA, scheduled swaps) to guide plugin-based policy flow. - Only ask clarifying questions when genuinely ambiguous — not for typos or when a reasonable default exists. - Don't fabricate Vultisig-specific facts (tokenomics, roadmap, etc.) — suggest checking official channels. @@ -238,11 +240,86 @@ You have the user's address book in context. A contact may have multiple entries - Not found at all → tell the user the contact wasn't found, ask for the address For vault-to-vault sends ("send to my other vault"): -- Use list_vaults action (auto_execute=true) to look up vault addresses +- Use list_vaults action to look up vault addresses - After receiving vault data, determine the correct chain address and call build_send_tx If any required param (coin, address, amount) is missing, ask the user for it — do NOT call build_send_tx until all params are known.` +func buildSystemPrompt(actionsTable string) string { + return systemPromptPrefix + actionsTable + systemPromptAfterActions +} + +// BuildActionsTable generates a markdown actions table from app-declared tools. +func BuildActionsTable(tools []AppTool) string { + var sb strings.Builder + sb.WriteString("| Type | Params |\n|---|---|\n") + for _, t := range tools { + sb.WriteString("| ") + sb.WriteString(t.Name) + sb.WriteString(" | ") + sb.WriteString(t.Params) + sb.WriteString(" |\n") + } + return strings.TrimRight(sb.String(), "\n") +} + +// BuildRespondToUserTool creates a respond_to_user tool with action type enum +// derived from app-declared tools instead of the hardcoded list. +func BuildRespondToUserTool(tools []AppTool) ai.Tool { + names := make([]string, len(tools)) + for i, t := range tools { + names[i] = t.Name + } + + srcSchema, ok := RespondToUserTool.InputSchema.(map[string]any) + if !ok { + return RespondToUserTool + } + schema := copySchema(srcSchema) + + props, ok := schema["properties"].(map[string]any) + if !ok { + return RespondToUserTool + } + actions, ok := props["actions"].(map[string]any) + if !ok { + return RespondToUserTool + } + items, ok := actions["items"].(map[string]any) + if !ok { + return RespondToUserTool + } + itemProps, ok := items["properties"].(map[string]any) + if !ok { + return RespondToUserTool + } + typeProp, ok := itemProps["type"].(map[string]any) + if !ok { + return RespondToUserTool + } + + typeProp["enum"] = names + + return ai.Tool{ + Name: RespondToUserTool.Name, + Description: RespondToUserTool.Description, + InputSchema: schema, + } +} + +func copySchema(src map[string]any) map[string]any { + data, err := json.Marshal(src) + if err != nil { + return src + } + var dst map[string]any + err = json.Unmarshal(data, &dst) + if err != nil { + return src + } + return dst +} + // RespondToUserTool is the tool definition for responding to users. var RespondToUserTool = ai.Tool{ Name: "respond_to_user", @@ -296,7 +373,7 @@ var RespondToUserTool = ai.Tool{ "enum": []string{ "get_market_price", "get_balances", "get_portfolio", "add_chain", "add_coin", "search_token", "remove_coin", "remove_chain", - "build_send_tx", "build_custom_tx", "build_tx", "sign_tx", "scan_tx", + "build_swap_tx", "build_send_tx", "build_custom_tx", "sign_tx", "scan_tx", "read_evm_contract", "thorchain_query", "plugin_install", "create_policy", "delete_policy", @@ -318,10 +395,6 @@ var RespondToUserTool = ai.Tool{ "description": "Parameters for the action. Keys depend on the action type.", "additionalProperties": true, }, - "auto_execute": map[string]any{ - "type": "boolean", - "description": "If true, the app executes this action immediately without user interaction. Use for read-only actions (get_market_price, get_balances, get_portfolio). Default false.", - }, }, "required": []string{"type", "title"}, }, @@ -486,9 +559,13 @@ func writeWalletContext(sb *strings.Builder, msgCtx *MessageContext, opts wallet } // BuildFullPrompt constructs the complete system prompt with context and plugin skills. -func BuildFullPrompt(msgCtx *MessageContext, plugins []PluginSkill) string { +// If actionsTable is empty, the hardcoded ActionsTable is used. +func BuildFullPrompt(msgCtx *MessageContext, plugins []PluginSkill, actionsTable string) string { + if actionsTable == "" { + actionsTable = ActionsTable + } var sb strings.Builder - sb.WriteString(SystemPrompt) + sb.WriteString(buildSystemPrompt(actionsTable)) if len(plugins) > 0 { sb.WriteString("\n\n## Available Plugins\n\n") diff --git a/internal/service/agent/prompt_test.go b/internal/service/agent/prompt_test.go index 6f0ce62..1a41fa8 100644 --- a/internal/service/agent/prompt_test.go +++ b/internal/service/agent/prompt_test.go @@ -7,7 +7,7 @@ import ( func TestBuildFullPrompt(t *testing.T) { t.Run("nil context", func(t *testing.T) { - got := BuildFullPrompt(nil, nil) + got := BuildFullPrompt(nil, nil, "") if !strings.Contains(got, "Vultisig AI assistant") { t.Error("expected system prompt") } @@ -19,7 +19,7 @@ func TestBuildFullPrompt(t *testing.T) { {Chain: "Ethereum", Symbol: "ETH", Amount: "1.5"}, }, } - got := BuildFullPrompt(ctx, nil) + got := BuildFullPrompt(ctx, nil, "") if !strings.Contains(got, "ETH on Ethereum: 1.5") { t.Errorf("expected balance in prompt, got:\n%s", got) } @@ -29,7 +29,7 @@ func TestBuildFullPrompt(t *testing.T) { plugins := []PluginSkill{ {PluginID: "dca-plugin", Name: "DCA", Skills: "Dollar cost average into tokens"}, } - got := BuildFullPrompt(nil, plugins) + got := BuildFullPrompt(nil, plugins, "") if !strings.Contains(got, "DCA (dca-plugin)") { t.Errorf("expected plugin in prompt, got:\n%s", got) } @@ -41,7 +41,7 @@ func TestBuildFullPrompt(t *testing.T) { {Title: "Alice", Address: "0xabc", Chain: "Ethereum"}, }, } - got := BuildFullPrompt(ctx, nil) + got := BuildFullPrompt(ctx, nil, "") if !strings.Contains(got, "Alice: 0xabc (Ethereum)") { t.Errorf("expected address book entry in prompt") } @@ -49,6 +49,66 @@ func TestBuildFullPrompt(t *testing.T) { t.Errorf("expected address book hint in full prompt") } }) + + t.Run("with custom actions table", func(t *testing.T) { + tools := []AppTool{ + {Name: "get_balances", Params: "chain (optional)"}, + {Name: "build_swap_tx", Params: "from, to, amount"}, + } + table := BuildActionsTable(tools) + got := BuildFullPrompt(nil, nil, table) + if !strings.Contains(got, "get_balances") { + t.Error("expected custom action in prompt") + } + if !strings.Contains(got, "build_swap_tx") { + t.Error("expected custom action in prompt") + } + }) +} + +func TestBuildActionsTable(t *testing.T) { + tools := []AppTool{ + {Name: "get_balances", Params: "chain (optional)"}, + {Name: "sign_tx", Params: ""}, + } + got := BuildActionsTable(tools) + if !strings.Contains(got, "| get_balances | chain (optional) |") { + t.Errorf("expected get_balances row, got:\n%s", got) + } + if !strings.Contains(got, "| sign_tx | |") { + t.Errorf("expected sign_tx row, got:\n%s", got) + } +} + +func TestBuildRespondToUserTool(t *testing.T) { + tools := []AppTool{ + {Name: "get_balances"}, + {Name: "send_tx"}, + } + tool := BuildRespondToUserTool(tools) + if tool.Name != "respond_to_user" { + t.Errorf("expected respond_to_user, got %s", tool.Name) + } + + schema, ok := tool.InputSchema.(map[string]any) + if !ok { + t.Fatal("expected map schema") + } + props := schema["properties"].(map[string]any) + actions := props["actions"].(map[string]any) + items := actions["items"].(map[string]any) + itemProps := items["properties"].(map[string]any) + typeProp := itemProps["type"].(map[string]any) + enum, ok := typeProp["enum"].([]string) + if !ok { + t.Fatal("expected enum array") + } + if len(enum) != 2 { + t.Errorf("expected 2 enum values, got %d", len(enum)) + } + if enum[0] != "get_balances" || enum[1] != "send_tx" { + t.Errorf("unexpected enum values: %v", enum) + } } func TestBuildSystemPromptWithSummary(t *testing.T) { diff --git a/internal/service/agent/tools.go b/internal/service/agent/tools.go index 3db6bce..799ddd0 100644 --- a/internal/service/agent/tools.go +++ b/internal/service/agent/tools.go @@ -161,3 +161,19 @@ func agentTools() []ai.Tool { SetVaultTool, } } + +func (s *AgentService) buildAgentTools(appTools []AppTool) []ai.Tool { + if len(appTools) == 0 { + return agentTools() + } + + return []ai.Tool{ + BuildRespondToUserTool(appTools), + CheckPluginInstalledTool, + CheckBillingStatusTool, + GetRecipeSchemaTool, + SuggestPolicyTool, + CreateSuggestionTool, + SetVaultTool, + } +} diff --git a/internal/service/agent/types.go b/internal/service/agent/types.go index b7ed301..e863e63 100644 --- a/internal/service/agent/types.go +++ b/internal/service/agent/types.go @@ -1,17 +1,22 @@ package agent import ( - "encoding/json" - "github.com/vultisig/agent-backend/internal/types" ) +// AppTool declares a tool the app supports, sent on first message per conversation. +type AppTool struct { + Name string `json:"name"` + Params string `json:"params"` +} + // SendMessageRequest is the request body for sending a message. type SendMessageRequest struct { PublicKey string `json:"public_key"` Content string `json:"content"` Model string `json:"model,omitempty"` Context *MessageContext `json:"context,omitempty"` + Tools []AppTool `json:"tools,omitempty"` SelectedSuggestionID *string `json:"selected_suggestion_id,omitempty"` ActionResult *ActionResult `json:"action_result,omitempty"` AccessToken string `json:"-"` // Populated by API layer, not from JSON @@ -58,7 +63,6 @@ type Action struct { Title string `json:"title"` Description string `json:"description,omitempty"` Params map[string]any `json:"params,omitempty"` - AutoExecute bool `json:"auto_execute"` } // ActionResult contains the result of a user action. @@ -78,7 +82,6 @@ type SendMessageResponse struct { Actions []Action `json:"actions,omitempty"` PolicyReady *PolicyReady `json:"policy_ready,omitempty"` InstallRequired *InstallRequired `json:"install_required,omitempty"` - TxReady *TxReady `json:"tx_ready,omitempty"` Transactions []Transaction `json:"transactions,omitempty"` Tokens *TokenSearchResult `json:"tokens,omitempty"` } @@ -95,23 +98,6 @@ type Transaction struct { TxDetails map[string]string `json:"tx_details"` } -type TxReady struct { - Provider string `json:"provider"` - ExpectedOutput string `json:"expected_output"` - MinimumOutput string `json:"minimum_output"` - NeedsApproval bool `json:"needs_approval"` - ApprovalTx json.RawMessage `json:"approval_tx,omitempty"` - SwapTx json.RawMessage `json:"swap_tx"` - FromChain string `json:"from_chain"` - FromSymbol string `json:"from_symbol"` - FromDecimals int `json:"from_decimals"` - ToChain string `json:"to_chain"` - ToSymbol string `json:"to_symbol"` - ToDecimals int `json:"to_decimals"` - Amount string `json:"amount"` - Sender string `json:"sender"` - Destination string `json:"destination"` -} // InstallRequired signals that a plugin must be installed before proceeding. type InstallRequired struct { @@ -135,18 +121,6 @@ type Suggestion struct { Description string `json:"description"` } -type BuildTxRequest struct { - PublicKey string `json:"public_key"` - Params map[string]any `json:"params"` - Context *MessageContext `json:"context,omitempty"` -} - -type BuildTxResponse struct { - TxReady *TxReady `json:"tx_ready,omitempty"` - Actions []Action `json:"actions,omitempty"` - Message string `json:"message,omitempty"` - Error string `json:"error,omitempty"` -} type SSEEvent struct { Event string `json:"event"` @@ -208,7 +182,6 @@ type ToolAction struct { Title string `json:"title"` Description string `json:"description,omitempty"` Params map[string]any `json:"params,omitempty"` - AutoExecute bool `json:"auto_execute"` } // TokenSearchResult contains tokens returned by the find_token MCP tool. diff --git a/internal/service/mcp/swap_adapter.go b/internal/service/mcp/swap_adapter.go deleted file mode 100644 index 8d66436..0000000 --- a/internal/service/mcp/swap_adapter.go +++ /dev/null @@ -1,78 +0,0 @@ -package mcp - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/vultisig/agent-backend/internal/service/agent" -) - -type SwapTxAdapter struct { - client *Client -} - -func NewSwapTxAdapter(client *Client) *SwapTxAdapter { - return &SwapTxAdapter{client: client} -} - -func (a *SwapTxAdapter) BuildSwapTx(ctx context.Context, req agent.SwapTxBuildRequest) (*agent.SwapTxBuildResponse, error) { - args := map[string]any{ - "from_chain": req.FromChain, - "from_symbol": req.FromSymbol, - "to_chain": req.ToChain, - "to_symbol": req.ToSymbol, - "amount": req.Amount, - "sender": req.Sender, - "destination": req.Destination, - } - - if req.FromAddress != "" { - args["from_address"] = req.FromAddress - } - if req.FromDecimals != nil { - args["from_decimals"] = *req.FromDecimals - } - if req.ToAddress != "" { - args["to_address"] = req.ToAddress - } - if req.ToDecimals != nil { - args["to_decimals"] = *req.ToDecimals - } - - resultText, err := a.client.CallTool(ctx, "build_swap_tx", args) - if err != nil { - return nil, fmt.Errorf("MCP build_swap_tx: %w", err) - } - - var mcpResp mcpSwapResponse - err = json.Unmarshal([]byte(resultText), &mcpResp) - if err != nil { - return nil, fmt.Errorf("unmarshal MCP swap response: %w", err) - } - - resp := &agent.SwapTxBuildResponse{ - Provider: mcpResp.Provider, - ExpectedOutput: mcpResp.ExpectedOutput, - MinimumOutput: mcpResp.MinimumOutput, - NeedsApproval: mcpResp.NeedsApproval, - SwapTx: mcpResp.SwapTx, - Memo: mcpResp.Memo, - } - - if mcpResp.NeedsApproval && mcpResp.ApprovalTx != nil { - resp.ApprovalTx = mcpResp.ApprovalTx - } - - return resp, nil -} - -type mcpSwapResponse struct { - Provider string `json:"provider"` - ExpectedOutput string `json:"expected_output"` - MinimumOutput string `json:"minimum_output"` - NeedsApproval bool `json:"needs_approval"` - ApprovalTx json.RawMessage `json:"approval_tx,omitempty"` - SwapTx json.RawMessage `json:"swap_tx"` - Memo string `json:"memo,omitempty"` -}