diff --git a/cmd/neona/daemon.go b/cmd/neona/daemon.go index b06f2e0..534d58e 100644 --- a/cmd/neona/daemon.go +++ b/cmd/neona/daemon.go @@ -13,6 +13,7 @@ import ( "github.com/fentz26/neona/internal/audit" "github.com/fentz26/neona/internal/connectors/localexec" "github.com/fentz26/neona/internal/controlplane" + "github.com/fentz26/neona/internal/mcp" "github.com/fentz26/neona/internal/scheduler" "github.com/fentz26/neona/internal/store" "github.com/spf13/cobra" @@ -60,6 +61,21 @@ func runDaemon(cmd *cobra.Command, args []string) error { schedulerCfg := scheduler.DefaultConfig() sched := scheduler.New(s, pdr, connector, schedulerCfg) + // Initialize MCP router + mcpConfig, err := mcp.LoadConfigFromHome() + if err != nil { + log.Printf("Warning: failed to load MCP config: %v (using defaults)", err) + mcpConfig = mcp.DefaultConfig() + } + registry := mcp.NewRegistry() + registry.RegisterDefaults() + mcpRouter := mcp.NewRouter(mcpConfig, registry) + log.Printf("MCP router initialized with %d servers", registry.Count()) + + // Wire MCP router to scheduler and server + sched.SetMCPRouter(mcpRouter) + server.SetMCPRouter(mcpRouter) + // Wire scheduler to server for /workers endpoint server.SetScheduler(sched) diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index eef2308..966a301 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/fentz26/neona/internal/mcp" "github.com/fentz26/neona/internal/models" "github.com/fentz26/neona/internal/store" ) @@ -20,6 +21,11 @@ type SchedulerStatsProvider interface { GetStats() map[string]interface{} } +// MCPRouter provides MCP routing for the /mcp/route endpoint. +type MCPRouter interface { + Route(ctx context.Context, task mcp.Task) (*mcp.RoutingResult, error) +} + // Server provides the HTTP API for Neona. type Server struct { service *Service @@ -27,6 +33,7 @@ type Server struct { addr string server *http.Server scheduler SchedulerStatsProvider + mcpRouter MCPRouter } // NewServer creates a new HTTP server. @@ -44,6 +51,12 @@ func (s *Server) SetScheduler(sched SchedulerStatsProvider) { s.scheduler = sched } +// SetMCPRouter sets the MCP router for the /mcp/route endpoint. +// Must be called before Start() - not safe for concurrent use. +func (s *Server) SetMCPRouter(router MCPRouter) { + s.mcpRouter = router +} + // Start starts the HTTP server. func (s *Server) Start() error { mux := http.NewServeMux() @@ -58,6 +71,9 @@ func (s *Server) Start() error { // Worker pool monitor endpoint mux.HandleFunc("/workers", s.handleWorkers) + // MCP routing endpoint + mux.HandleFunc("/mcp/route", s.handleMCPRoute) + // Health check with DB ping mux.HandleFunc("/health", s.handleHealth) @@ -410,3 +426,81 @@ func (s *Server) handleWorkers(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(stats) } + +// --- MCP Route Handlers --- + +// mcpRouteRequest represents the request body for /mcp/route +type mcpRouteRequest struct { + Title string `json:"title"` + Description string `json:"description"` +} + +// mcpRouteResponse represents the response for /mcp/route +type mcpRouteResponse struct { + SelectedMCPs []mcpServerInfo `json:"selected_mcps"` + MatchedRules []string `json:"matched_rules"` + TotalTools int `json:"total_tools"` + ToolBudget int `json:"tool_budget"` +} + +type mcpServerInfo struct { + Name string `json:"name"` + ToolCount int `json:"tool_count"` +} + +// handleMCPRoute handles POST /mcp/route +func (s *Server) handleMCPRoute(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if s.mcpRouter == nil { + http.Error(w, "MCP router not configured", http.StatusServiceUnavailable) + return + } + + var req mcpRouteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + + if req.Title == "" { + http.Error(w, "title is required", http.StatusBadRequest) + return + } + + task := mcp.Task{ + Title: req.Title, + Description: req.Description, + } + + result, err := s.mcpRouter.Route(r.Context(), task) + if err != nil { + log.Printf("MCP routing failed: %v", err) + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + + // Build response + mcps := make([]mcpServerInfo, len(result.SelectedMCPs)) + for i, m := range result.SelectedMCPs { + mcps[i] = mcpServerInfo{ + Name: m.Name, + ToolCount: m.ToolCount, + } + } + + resp := mcpRouteResponse{ + SelectedMCPs: mcps, + MatchedRules: result.MatchedRules, + TotalTools: result.TotalTools, + ToolBudget: 80, // Default budget + } +w.Header().Set("Content-Type", "application/json") +if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Printf("Failed to encode MCP route response: %v", err) +} + json.NewEncoder(w).Encode(resp) +} diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index da24e1a..a1122e2 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -10,6 +10,7 @@ import ( "github.com/fentz26/neona/internal/audit" "github.com/fentz26/neona/internal/connectors" + "github.com/fentz26/neona/internal/mcp" "github.com/fentz26/neona/internal/models" "github.com/fentz26/neona/internal/store" "github.com/google/uuid" @@ -33,6 +34,9 @@ type Scheduler struct { connector connectors.Connector config *Config + // MCP router for tool selection + mcpRouter *mcp.KeywordRouter + // Worker pool state mu sync.Mutex activeWorkers int @@ -69,6 +73,12 @@ func New(s *store.Store, pdr *audit.PDRWriter, conn connectors.Connector, cfg *C } } +// SetMCPRouter sets the MCP router for tool selection. +// Must be called before Start() - not safe for concurrent use. +func (sch *Scheduler) SetMCPRouter(router *mcp.KeywordRouter) { + sch.mcpRouter = router +} + // Start begins the scheduler loop. func (sch *Scheduler) Start() { sch.mu.Lock() @@ -149,6 +159,32 @@ func (sch *Scheduler) pollAndDispatch() { "connector": connectorName, }, "success", task.ID, fmt.Sprintf("Dispatched to worker %s", workerID)) + // Route MCPs for this task if router is configured + if sch.mcpRouter != nil { + mcpTask := mcp.Task{ + ID: task.ID, + Title: task.Title, + Description: task.Description, + } + result, err := sch.mcpRouter.Route(sch.ctx, mcpTask) + if err != nil { + log.Printf("MCP routing error for task %s: %v", task.ID, err) + } else { + // Log selected MCPs + mcpNames := make([]string, len(result.SelectedMCPs)) + for i, m := range result.SelectedMCPs { + mcpNames[i] = m.Name + } + sch.pdr.Record("task.mcp_route", map[string]interface{}{ + "task_id": task.ID, + "selected_mcps": mcpNames, + "total_tools": result.TotalTools, + "matched_rules": result.MatchedRules, + }, "success", task.ID, fmt.Sprintf("Routed to %d MCPs with %d tools", len(mcpNames), result.TotalTools)) + log.Printf("Task %s routed to MCPs: %v (%d tools)", task.ID, mcpNames, result.TotalTools) + } + } + log.Printf("Dispatched task %s (%s) to worker %s", task.ID, task.Title, workerID) // Increment worker counts and store worker info