From ee35ec487da5d33bbc6896cb7ec81fee2f2f255f Mon Sep 17 00:00:00 2001 From: bilalbayram Date: Sat, 21 Mar 2026 13:41:45 +0300 Subject: [PATCH 1/4] feat: add managed router daemon for OpenWrt --- .github/workflows/release.yml | 12 +- Makefile | 5 +- README.md | 16 + cmd/opensnitch-web/main.go | 35 +- cmd/opensnitchd-router/firewall.go | 190 ++++ cmd/opensnitchd-router/grpc.go | 137 +++ cmd/opensnitchd-router/main.go | 167 ++++ cmd/opensnitchd-router/monitor_forward.go | 176 ++++ cmd/opensnitchd-router/monitor_local.go | 364 +++++++ cmd/opensnitchd-router/rules.go | 520 ++++++++++ cmd/opensnitchd-router/stats.go | 128 +++ config.yaml.example | 1 + internal/api/handlers_dns.go | 24 + internal/api/handlers_nodes.go | 33 + internal/api/handlers_prompt.go | 77 +- internal/api/handlers_routers.go | 215 ++++- internal/api/handlers_routers_test.go | 12 + internal/api/handlers_rules.go | 12 + internal/api/handlers_templates.go | 53 +- internal/api/router.go | 8 +- internal/api/router_managed.go | 62 ++ internal/api/router_managed_test.go | 225 +++++ internal/api/template_router_managed.go | 133 +++ internal/config/config.go | 14 +- internal/db/routers.go | 98 +- internal/db/sqlite.go | 6 +- internal/grpcserver/identity.go | 68 ++ internal/grpcserver/server.go | 15 + internal/grpcserver/service.go | 111 ++- internal/grpcserver/service_test.go | 80 ++ internal/prompter/prompter.go | 6 + internal/router/capabilities.go | 143 +++ internal/router/daemon_initd.sh | 11 + internal/router/daemon_templates.go | 30 + internal/router/provisioner.go | 410 +++++++- internal/router/provisioner_test.go | 10 +- internal/rules/router_managed.go | 93 ++ internal/templatesync/service.go | 15 + web/src/components/prompt/dialog.tsx | 8 +- web/src/components/rule-editor-sheet.tsx | 60 +- web/src/lib/api.ts | 74 ++ web/src/lib/rule-helpers.ts | 15 + web/src/pages/connections.tsx | 14 +- web/src/pages/nodes.tsx | 1057 +++++++++++++-------- web/src/pages/rules.tsx | 20 + web/src/stores/app-store.ts | 1 + 46 files changed, 4421 insertions(+), 543 deletions(-) create mode 100644 cmd/opensnitchd-router/firewall.go create mode 100644 cmd/opensnitchd-router/grpc.go create mode 100644 cmd/opensnitchd-router/main.go create mode 100644 cmd/opensnitchd-router/monitor_forward.go create mode 100644 cmd/opensnitchd-router/monitor_local.go create mode 100644 cmd/opensnitchd-router/rules.go create mode 100644 cmd/opensnitchd-router/stats.go create mode 100644 internal/api/router_managed.go create mode 100644 internal/api/router_managed_test.go create mode 100644 internal/api/template_router_managed.go create mode 100644 internal/grpcserver/identity.go create mode 100644 internal/router/capabilities.go create mode 100644 internal/router/daemon_initd.sh create mode 100644 internal/router/daemon_templates.go create mode 100644 internal/rules/router_managed.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 646169a..5218f65 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -81,11 +81,16 @@ jobs: BUILD_TIME=$(date -u '+%Y-%m-%dT%H:%M:%SZ') LDFLAGS="-X github.com/bilalbayram/opensnitch-web/internal/version.Version=${VERSION} -X github.com/bilalbayram/opensnitch-web/internal/version.BuildTime=${BUILD_TIME}" go build -ldflags "${LDFLAGS}" -o opensnitch-web-${{ matrix.goos }}-${{ matrix.goarch }} ./cmd/opensnitch-web + if [ "${{ matrix.goos }}" = "linux" ] && [ "${{ matrix.goarch }}" = "arm64" ]; then + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags "${LDFLAGS}" -o opensnitchd-router-linux-arm64 ./cmd/opensnitchd-router + fi - uses: actions/upload-artifact@v4 with: name: binary-${{ matrix.goos }}-${{ matrix.goarch }} - path: opensnitch-web-${{ matrix.goos }}-${{ matrix.goarch }} + path: | + opensnitch-web-${{ matrix.goos }}-${{ matrix.goarch }} + opensnitchd-router-* release: name: Create Release @@ -102,14 +107,15 @@ jobs: merge-multiple: true - name: Generate checksums - run: sha256sum opensnitch-web-* > checksums.txt + run: sha256sum opensnitch-web-* opensnitchd-router-* > checksums.txt - name: Create GitHub Release uses: softprops/action-gh-release@v2 with: - prerelease: ${{ contains(github.ref_name, '-beta') || contains(github.ref_name, '-alpha') || contains(github.ref_name, '-rc') }} + prerelease: ${{ contains(github.ref_name, 'beta') || contains(github.ref_name, 'alpha') || contains(github.ref_name, 'rc') }} generate_release_notes: true files: | opensnitch-web-* + opensnitchd-router-* checksums.txt deploy/opensnitch-web.service diff --git a/Makefile b/Makefile index b4d519c..7aa508b 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build proto frontend clean run dev embed verify-embed install uninstall +.PHONY: all build proto frontend clean run dev embed verify-embed install uninstall daemon-router-arm64 VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") BUILD_TIME ?= $(shell date -u '+%Y-%m-%dT%H:%M:%SZ') @@ -30,6 +30,9 @@ verify-embed: embed build: embed CGO_ENABLED=1 go build -ldflags '$(LDFLAGS)' -o bin/opensnitch-web ./cmd/opensnitch-web +daemon-router-arm64: + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -ldflags '$(LDFLAGS)' -o bin/opensnitchd-router-linux-arm64 ./cmd/opensnitchd-router + # Run the server (dev mode — serves from web/dist) run: CGO_ENABLED=1 go run -ldflags '$(LDFLAGS)' ./cmd/opensnitch-web diff --git a/README.md b/README.md index 6c4535b..1caa6c1 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,22 @@ Replace `192.168.1.0/24` with your router subnet and `8080` with your configured The provisioner performs a connectivity check after deployment and will show a warning in the UI if the router cannot reach the server. +## Managed router daemon v1 + +OpenWrt routers can also be upgraded from the legacy HTTP ingest agent to the managed `opensnitchd-router` runtime. v1 is intentionally limited: + +- Router-local processes use the normal gRPC `AskRule` prompt flow. +- Forwarded LAN traffic is observed continuously and can be enforced with explicit rules, but unknown forwarded flows are allowed by default. +- Forwarded traffic never opens live prompts in v1. Use generated or manually created device-scoped rules instead. +- Only `aarch64` OpenWrt targets are supported in v1. + +The managed runtime connects to gRPC with the existing router API key in the `x-router-api-key` metadata header. If the server cannot infer a LAN-reachable gRPC endpoint automatically, set `server.grpc_public_addr` before upgrading routers. + +```bash +# Build the OpenWrt router daemon artifact +make daemon-router-arm64 +``` + ## Development Built with Go 1.22 (Chi, gRPC, SQLite) and React 19 (Vite, TypeScript, Tailwind CSS 4). diff --git a/cmd/opensnitch-web/main.go b/cmd/opensnitch-web/main.go index ca0c1c6..0eb0313 100644 --- a/cmd/opensnitch-web/main.go +++ b/cmd/opensnitch-web/main.go @@ -55,22 +55,27 @@ func main() { // Wire up prompter → WebSocket broadcasts p.OnNewPrompt = func(prompt *prompter.PendingPrompt) { conn := prompt.Connection + routerManaged := false + if router, err := database.GetRouterByLinkedNodeAddr(prompt.NodeAddr); err == nil { + routerManaged = router.DaemonMode == db.RouterDaemonModeRouterDaemon + } hub.BroadcastEvent(ws.EventPromptRequest, map[string]interface{}{ - "id": prompt.ID, - "node_addr": prompt.NodeAddr, - "created_at": prompt.CreatedAt.Format("2006-01-02 15:04:05"), - "process": conn.GetProcessPath(), - "dst_host": conn.GetDstHost(), - "dst_ip": conn.GetDstIp(), - "dst_port": conn.GetDstPort(), - "protocol": conn.GetProtocol(), - "src_ip": conn.GetSrcIp(), - "src_port": conn.GetSrcPort(), - "uid": conn.GetUserId(), - "pid": conn.GetProcessId(), - "args": conn.GetProcessArgs(), - "cwd": conn.GetProcessCwd(), - "checksums": conn.GetProcessChecksums(), + "id": prompt.ID, + "node_addr": prompt.NodeAddr, + "created_at": prompt.CreatedAt.Format("2006-01-02 15:04:05"), + "router_managed": routerManaged, + "process": conn.GetProcessPath(), + "dst_host": conn.GetDstHost(), + "dst_ip": conn.GetDstIp(), + "dst_port": conn.GetDstPort(), + "protocol": conn.GetProtocol(), + "src_ip": conn.GetSrcIp(), + "src_port": conn.GetSrcPort(), + "uid": conn.GetUserId(), + "pid": conn.GetProcessId(), + "args": conn.GetProcessArgs(), + "cwd": conn.GetProcessCwd(), + "checksums": conn.GetProcessChecksums(), }) } diff --git a/cmd/opensnitchd-router/firewall.go b/cmd/opensnitchd-router/firewall.go new file mode 100644 index 0000000..b834fae --- /dev/null +++ b/cmd/opensnitchd-router/firewall.go @@ -0,0 +1,190 @@ +package main + +import ( + "bytes" + "fmt" + "os/exec" + "strconv" + "strings" +) + +const ( + nftTable = "opensnitch-router" + nftOutputChain = "output" + nftForwardChain = "forward" +) + +func (d *daemon) ensureFirewall() error { + if err := d.nftRun("add", "table", "inet", nftTable); err != nil && !nftAlreadyExists(err) { + return err + } + if err := d.nftRun("add", "chain", "inet", nftTable, nftOutputChain, "{", "type", "filter", "hook", "output", "priority", "0;", "policy", "accept;", "}"); err != nil && !nftAlreadyExists(err) { + return err + } + if err := d.nftRun("add", "chain", "inet", nftTable, nftForwardChain, "{", "type", "filter", "hook", "forward", "priority", "0;", "policy", "accept;", "}"); err != nil && !nftAlreadyExists(err) { + return err + } + return nil +} + +func (d *daemon) disableFirewall() error { + err := d.nftRun("delete", "table", "inet", nftTable) + if err != nil && strings.Contains(err.Error(), "No such file or directory") { + return nil + } + return err +} + +func (d *daemon) reloadFirewallState() error { + if err := d.ensureFirewall(); err != nil { + return err + } + if err := d.flushChain(nftOutputChain); err != nil { + return err + } + return d.rebuildForwardRules() +} + +func (d *daemon) rebuildForwardRules() error { + if !d.isFirewallEnabled() { + return nil + } + if err := d.ensureFirewall(); err != nil { + return err + } + if err := d.flushChain(nftForwardChain); err != nil { + return err + } + + for _, rule := range d.snapshotRules() { + spec := compileForwardRule(rule) + if spec == nil { + continue + } + if err := d.installForwardRule(spec); err != nil { + return err + } + } + return nil +} + +func (d *daemon) installLocalDecision(flow *localFlow, action string) error { + if !d.isFirewallEnabled() { + return nil + } + if err := d.ensureFirewall(); err != nil { + return err + } + + args := []string{"add", "rule", "inet", nftTable, nftOutputChain} + if flow.SrcIP != "" { + args = append(args, sourceAddressMatch(flow.SrcIP)...) + } + if flow.DstIP != "" { + args = append(args, nftAddressMatch(flow.DstIP)...) + } + if flow.Protocol != "" { + args = append(args, protocolMatch(flow.Protocol)...) + } + if flow.SrcPort > 0 { + args = append(args, portMatch(flow.Protocol, "sport", int(flow.SrcPort))...) + } + if flow.DstPort > 0 { + args = append(args, portMatch(flow.Protocol, "dport", int(flow.DstPort))...) + } + args = append(args, nftVerdict(action)) + return d.nftRun(args...) +} + +func (d *daemon) installForwardRule(spec *forwardRuleSpec) error { + args := []string{"add", "rule", "inet", nftTable, nftForwardChain} + args = append(args, sourceAddressMatch(spec.SourceIP)...) + if spec.DestIP != "" { + args = append(args, nftAddressMatch(spec.DestIP)...) + } + if spec.Protocol != "" { + args = append(args, protocolMatch(spec.Protocol)...) + } + if spec.Port != "" { + args = append(args, portMatch(spec.Protocol, "dport", mustAtoi(spec.Port))...) + } + args = append(args, nftVerdict(spec.Action)) + return d.nftRun(args...) +} + +func protocolMatch(protocol string) []string { + switch strings.ToLower(strings.TrimSpace(protocol)) { + case "tcp", "tcp6": + return []string{"meta", "l4proto", "tcp"} + case "udp", "udp6": + return []string{"meta", "l4proto", "udp"} + case "icmp", "icmp6": + return []string{"meta", "l4proto", "icmp"} + default: + return nil + } +} + +func sourceAddressMatch(ip string) []string { + if strings.Contains(ip, ":") { + return []string{"ip6", "saddr", ip} + } + return []string{"ip", "saddr", ip} +} + +func nftAddressMatch(ip string) []string { + if strings.Contains(ip, ":") { + return []string{"ip6", "daddr", ip} + } + return []string{"ip", "daddr", ip} +} + +func portMatch(protocol, side string, port int) []string { + if port <= 0 { + return nil + } + switch strings.ToLower(strings.TrimSpace(protocol)) { + case "udp", "udp6": + return []string{"udp", side, fmt.Sprintf("%d", port)} + default: + return []string{"tcp", side, fmt.Sprintf("%d", port)} + } +} + +func nftVerdict(action string) string { + switch strings.ToLower(strings.TrimSpace(action)) { + case "deny": + return "drop" + case "reject": + return "reject" + default: + return "accept" + } +} + +func (d *daemon) flushChain(chain string) error { + err := d.nftRun("flush", "chain", "inet", nftTable, chain) + if err != nil && strings.Contains(err.Error(), "No such file or directory") { + return nil + } + return err +} + +func (d *daemon) nftRun(args ...string) error { + cmd := exec.Command("nft", args...) + var stderr bytes.Buffer + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("nft %s: %w (%s)", strings.Join(args, " "), err, strings.TrimSpace(stderr.String())) + } + return nil +} + +func nftAlreadyExists(err error) bool { + return err != nil && strings.Contains(err.Error(), "File exists") +} + +func mustAtoi(value string) int { + n, _ := strconv.Atoi(strings.TrimSpace(value)) + return n +} diff --git a/cmd/opensnitchd-router/grpc.go b/cmd/opensnitchd-router/grpc.go new file mode 100644 index 0000000..f953fb7 --- /dev/null +++ b/cmd/opensnitchd-router/grpc.go @@ -0,0 +1,137 @@ +package main + +import ( + "context" + "fmt" + "strings" + "time" + + pb "github.com/bilalbayram/opensnitch-web/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" +) + +type routerAPICredentials struct { + apiKey string +} + +func (c routerAPICredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { + return map[string]string{"x-router-api-key": c.apiKey}, nil +} + +func (c routerAPICredentials) RequireTransportSecurity() bool { + return false +} + +var _ credentials.PerRPCCredentials = (*routerAPICredentials)(nil) + +func (d *daemon) connect(ctx context.Context) error { + conn, err := grpc.DialContext( + ctx, + d.cfg.GRPCAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithPerRPCCredentials(routerAPICredentials{apiKey: d.cfg.APIKey}), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 10 * time.Second, + Timeout: 20 * time.Second, + PermitWithoutStream: true, + }), + ) + if err != nil { + return fmt.Errorf("dial grpc: %w", err) + } + + d.conn = conn + d.client = pb.NewUIClient(conn) + + if err := d.subscribe(ctx); err != nil { + return err + } + + stream, err := d.client.Notifications(ctx) + if err != nil { + return fmt.Errorf("open notifications stream: %w", err) + } + d.notif = stream + d.logger.Printf("connected to %s as %s", d.cfg.GRPCAddr, d.cfg.NodeName) + return nil +} + +func (d *daemon) subscribe(ctx context.Context) error { + _, err := d.client.Subscribe(ctx, &pb.ClientConfig{ + Name: d.cfg.NodeName, + Version: daemonVersion(), + IsFirewallRunning: d.isFirewallEnabled(), + Config: d.configJSON, + Rules: d.snapshotRules(), + }) + if err != nil { + return fmt.Errorf("subscribe: %w", err) + } + return nil +} + +func (d *daemon) askRule(ctx context.Context, flow *localFlow) (*pb.Rule, error) { + conn := flow.toProto() + rule, err := d.client.AskRule(ctx, conn) + if err != nil { + return nil, err + } + if rule == nil { + return nil, fmt.Errorf("AskRule returned no rule") + } + if strings.TrimSpace(rule.GetAction()) == "" { + return nil, fmt.Errorf("AskRule returned rule without action") + } + return rule, nil +} + +func (d *daemon) runPingLoop(ctx context.Context) error { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + var seq uint64 = 1 + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + stats := d.stats.snapshot(len(d.snapshotRules())) + if _, err := d.client.Ping(ctx, &pb.PingRequest{ + Id: seq, + Stats: stats, + }); err != nil { + return fmt.Errorf("ping: %w", err) + } + seq++ + } + } +} + +func (d *daemon) runNotifications(ctx context.Context) error { + for { + notif, err := d.notif.Recv() + if err != nil { + return fmt.Errorf("notifications recv: %w", err) + } + + reply := &pb.NotificationReply{Id: notif.GetId(), Code: pb.NotificationReplyCode_OK} + if err := d.handleNotification(notif); err != nil { + reply.Code = pb.NotificationReplyCode_ERROR + reply.Data = err.Error() + d.logger.Printf("notification %d failed: %v", notif.GetId(), err) + } + + if err := d.notif.Send(reply); err != nil { + return fmt.Errorf("notifications send reply: %w", err) + } + + select { + case <-ctx.Done(): + return nil + default: + } + } +} diff --git a/cmd/opensnitchd-router/main.go b/cmd/opensnitchd-router/main.go new file mode 100644 index 0000000..a3c2e76 --- /dev/null +++ b/cmd/opensnitchd-router/main.go @@ -0,0 +1,167 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/bilalbayram/opensnitch-web/internal/router" + "github.com/bilalbayram/opensnitch-web/internal/version" + pb "github.com/bilalbayram/opensnitch-web/proto" + "google.golang.org/grpc" +) + +const defaultConfigPath = "/etc/opensnitchd-router/config.json" + +// opensnitchd-router is the managed OpenWrt router runtime for v1. +// It only runs live AskRule prompts for router-local processes. Forwarded LAN +// flows are observed continuously and only enforced when an explicit device +// rule already exists. +type daemon struct { + cfg router.DaemonConfig + configJSON string + logger *log.Logger + + pollInterval time.Duration + rulesPath string + defaultAction string + + conn *grpc.ClientConn + client pb.UIClient + notif pb.UI_NotificationsClient + notifClose sync.Once + + stateMu sync.RWMutex + interceptionEnabled bool + firewallEnabled bool + + rulesMu sync.RWMutex + rules map[string]*ruleEntry + + stats *statsCollector +} + +func main() { + configPath := flag.String("config", defaultConfigPath, "Path to router-daemon config") + flag.Parse() + + cfg, rawConfig, err := loadConfig(*configPath) + if err != nil { + log.Fatalf("[router-daemon] load config: %v", err) + } + + logger := log.New(os.Stdout, "[router-daemon] ", log.LstdFlags|log.Lmsgprefix) + d := &daemon{ + cfg: cfg, + configJSON: rawConfig, + logger: logger, + pollInterval: time.Duration(cfg.PollIntervalMS) * time.Millisecond, + rulesPath: filepath.Join(filepath.Dir(*configPath), "rules.json"), + defaultAction: strings.ToLower(strings.TrimSpace(cfg.DefaultAction)), + interceptionEnabled: true, + firewallEnabled: true, + rules: make(map[string]*ruleEntry), + stats: newStatsCollector(), + } + + if err := d.loadRules(); err != nil { + logger.Fatalf("load cached rules: %v", err) + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := d.run(ctx); err != nil && ctx.Err() == nil { + logger.Fatalf("runtime failed: %v", err) + } +} + +func loadConfig(path string) (router.DaemonConfig, string, error) { + data, err := os.ReadFile(filepath.Clean(path)) + if err != nil { + return router.DaemonConfig{}, "", err + } + + var cfg router.DaemonConfig + if err := json.Unmarshal(data, &cfg); err != nil { + return router.DaemonConfig{}, "", err + } + + if strings.TrimSpace(cfg.GRPCAddr) == "" { + return router.DaemonConfig{}, "", fmt.Errorf("grpc_addr is required") + } + if strings.TrimSpace(cfg.APIKey) == "" { + return router.DaemonConfig{}, "", fmt.Errorf("api_key is required") + } + if strings.TrimSpace(cfg.NodeName) == "" { + return router.DaemonConfig{}, "", fmt.Errorf("node_name is required") + } + if strings.TrimSpace(cfg.DefaultAction) == "" { + cfg.DefaultAction = "deny" + } + if cfg.PollIntervalMS <= 0 { + cfg.PollIntervalMS = 1000 + } + if strings.TrimSpace(cfg.FirewallBackend) == "" { + cfg.FirewallBackend = "nft" + } + if strings.TrimSpace(cfg.FirewallBackend) != "nft" { + return router.DaemonConfig{}, "", fmt.Errorf("firewall_backend %q is not supported in v1", cfg.FirewallBackend) + } + + return cfg, string(data), nil +} + +func (d *daemon) run(ctx context.Context) error { + if err := d.connect(ctx); err != nil { + return err + } + defer d.close() + + if err := d.ensureFirewall(); err != nil { + return err + } + if err := d.rebuildForwardRules(); err != nil { + return err + } + + errCh := make(chan error, 2) + go func() { errCh <- d.runPingLoop(ctx) }() + go func() { errCh <- d.runNotifications(ctx) }() + go d.runLocalMonitor(ctx) + go d.runForwardMonitor(ctx) + + select { + case <-ctx.Done(): + return nil + case err := <-errCh: + return err + } +} + +func (d *daemon) close() { + d.notifClose.Do(func() { + if d.notif != nil { + _ = d.notif.CloseSend() + } + }) + if d.conn != nil { + _ = d.conn.Close() + } +} + +func daemonVersion() string { + if strings.TrimSpace(version.Version) == "" { + return "opensnitchd-router" + } + return "opensnitchd-router/" + version.Version +} diff --git a/cmd/opensnitchd-router/monitor_forward.go b/cmd/opensnitchd-router/monitor_forward.go new file mode 100644 index 0000000..288acff --- /dev/null +++ b/cmd/opensnitchd-router/monitor_forward.go @@ -0,0 +1,176 @@ +package main + +import ( + "bufio" + "context" + "net" + "os/exec" + "strconv" + "strings" + "time" + + pb "github.com/bilalbayram/opensnitch-web/proto" +) + +type forwardFlow struct { + Protocol string + SrcIP string + SrcPort uint32 + DstIP string + DstPort uint32 +} + +func (f *forwardFlow) key() string { + return strings.Join([]string{ + f.Protocol, + f.SrcIP, + strconv.Itoa(int(f.SrcPort)), + f.DstIP, + strconv.Itoa(int(f.DstPort)), + }, "|") +} + +func (f *forwardFlow) toProto() *pb.Connection { + return &pb.Connection{ + Protocol: f.Protocol, + SrcIp: f.SrcIP, + SrcPort: f.SrcPort, + DstIp: f.DstIP, + DstPort: f.DstPort, + UserId: 0, + ProcessId: 0, + ProcessPath: "device:" + f.SrcIP, + } +} + +func (d *daemon) runForwardMonitor(ctx context.Context) { + ticker := time.NewTicker(d.pollInterval) + defer ticker.Stop() + + seen := make(map[string]time.Time) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if !d.isInterceptionEnabled() { + continue + } + flows, err := discoverForwardFlows() + if err != nil { + d.logger.Printf("discover forward flows: %v", err) + continue + } + + localIPs := localIPSet() + now := time.Now() + live := make(map[string]struct{}, len(flows)) + for _, flow := range flows { + if flow == nil { + continue + } + if _, ok := localIPs[flow.SrcIP]; ok { + continue + } + + key := flow.key() + live[key] = struct{}{} + if _, ok := seen[key]; ok { + continue + } + seen[key] = now + d.handleForwardFlow(flow) + } + + for key, lastSeen := range seen { + if _, ok := live[key]; ok { + continue + } + if now.Sub(lastSeen) > 2*d.pollInterval { + delete(seen, key) + } + } + } + } +} + +func (d *daemon) handleForwardFlow(flow *forwardFlow) { + rule, matched := d.evaluateForwardFlow(flow) + if !matched { + rule = &pb.Rule{ + Action: "allow", + Duration: "once", + Enabled: true, + } + } + d.stats.record(flow.toProto(), rule, matched) +} + +func discoverForwardFlows() ([]*forwardFlow, error) { + cmd := exec.Command("conntrack", "-L") + out, err := cmd.Output() + if err != nil { + return nil, err + } + + flows := make([]*forwardFlow, 0) + scanner := bufio.NewScanner(strings.NewReader(string(out))) + for scanner.Scan() { + if flow := parseConntrackLine(scanner.Text()); flow != nil { + flows = append(flows, flow) + } + } + return flows, scanner.Err() +} + +func parseConntrackLine(line string) *forwardFlow { + fields := strings.Fields(strings.TrimSpace(line)) + if len(fields) < 5 { + return nil + } + + protocol := strings.ToLower(strings.TrimSpace(fields[0])) + flow := &forwardFlow{Protocol: protocol} + + for _, field := range fields { + switch { + case strings.HasPrefix(field, "src=") && flow.SrcIP == "": + flow.SrcIP = strings.TrimPrefix(field, "src=") + case strings.HasPrefix(field, "dst=") && flow.DstIP == "": + flow.DstIP = strings.TrimPrefix(field, "dst=") + case strings.HasPrefix(field, "sport=") && flow.SrcPort == 0: + if value, err := strconv.ParseUint(strings.TrimPrefix(field, "sport="), 10, 32); err == nil { + flow.SrcPort = uint32(value) + } + case strings.HasPrefix(field, "dport=") && flow.DstPort == 0: + if value, err := strconv.ParseUint(strings.TrimPrefix(field, "dport="), 10, 32); err == nil { + flow.DstPort = uint32(value) + } + } + if flow.SrcIP != "" && flow.DstIP != "" && flow.DstPort > 0 { + break + } + } + + if flow.SrcIP == "" || flow.DstIP == "" || flow.DstPort == 0 { + return nil + } + return flow +} + +func localIPSet() map[string]struct{} { + result := make(map[string]struct{}) + addrs, err := net.InterfaceAddrs() + if err != nil { + return result + } + for _, addr := range addrs { + switch value := addr.(type) { + case *net.IPNet: + result[value.IP.String()] = struct{}{} + case *net.IPAddr: + result[value.IP.String()] = struct{}{} + } + } + return result +} diff --git a/cmd/opensnitchd-router/monitor_local.go b/cmd/opensnitchd-router/monitor_local.go new file mode 100644 index 0000000..27d9a6b --- /dev/null +++ b/cmd/opensnitchd-router/monitor_local.go @@ -0,0 +1,364 @@ +package main + +import ( + "context" + "encoding/hex" + "fmt" + "net" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "time" + + pb "github.com/bilalbayram/opensnitch-web/proto" +) + +type localFlow struct { + Protocol string + SrcIP string + SrcPort uint32 + DstIP string + DstPort uint32 + UID uint32 + PID int + ProcessPath string + ProcessCwd string + ProcessArgs []string + Inode string +} + +func (f *localFlow) key() string { + return strings.Join([]string{ + f.Protocol, + f.SrcIP, + strconv.Itoa(int(f.SrcPort)), + f.DstIP, + strconv.Itoa(int(f.DstPort)), + strconv.Itoa(f.PID), + f.Inode, + }, "|") +} + +func (f *localFlow) toProto() *pb.Connection { + return &pb.Connection{ + Protocol: f.Protocol, + SrcIp: f.SrcIP, + SrcPort: f.SrcPort, + DstIp: f.DstIP, + DstPort: f.DstPort, + UserId: f.UID, + ProcessId: uint32(f.PID), + ProcessPath: f.ProcessPath, + ProcessCwd: f.ProcessCwd, + ProcessArgs: f.ProcessArgs, + } +} + +type procSocket struct { + Protocol string + SrcIP string + SrcPort uint32 + DstIP string + DstPort uint32 + UID uint32 + Inode string +} + +type procMeta struct { + PID int + ProcessPath string + ProcessCwd string + ProcessArgs []string +} + +func (d *daemon) runLocalMonitor(ctx context.Context) { + ticker := time.NewTicker(d.pollInterval) + defer ticker.Stop() + + seen := make(map[string]time.Time) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if !d.isInterceptionEnabled() { + continue + } + flows, err := discoverLocalFlows() + if err != nil { + d.logger.Printf("discover local flows: %v", err) + continue + } + + now := time.Now() + live := make(map[string]struct{}, len(flows)) + for _, flow := range flows { + if flow == nil || flow.PID == os.Getpid() || strings.Contains(flow.ProcessPath, "opensnitchd-router") { + continue + } + key := flow.key() + live[key] = struct{}{} + if _, ok := seen[key]; ok { + continue + } + seen[key] = now + + flowCopy := *flow + go d.handleLocalFlow(ctx, &flowCopy) + } + + for key, lastSeen := range seen { + if _, ok := live[key]; ok { + continue + } + if now.Sub(lastSeen) > 2*d.pollInterval { + delete(seen, key) + } + } + } + } +} + +func (d *daemon) handleLocalFlow(ctx context.Context, flow *localFlow) { + rule, matched, err := d.evaluateLocalFlow(flow) + if err != nil { + d.logger.Printf("evaluate local flow %s: %v", flow.key(), err) + return + } + + if !matched { + rule, err = d.askRule(ctx, flow) + if err != nil { + d.logger.Printf("AskRule failed for %s: %v", flow.key(), err) + rule = &pb.Rule{ + Name: fmt.Sprintf("fallback-%d", time.Now().UnixNano()), + Action: d.defaultAction, + Duration: "once", + Enabled: true, + Operator: &pb.Operator{ + Type: "simple", + Operand: "process.path", + Data: flow.ProcessPath, + }, + } + } else if !strings.EqualFold(rule.GetDuration(), "once") { + if err := d.upsertRule(rule, ruleTimestamp(rule)); err != nil { + d.logger.Printf("cache AskRule decision %s: %v", rule.GetName(), err) + } + } + } + + if err := d.installLocalDecision(flow, rule.GetAction()); err != nil { + d.logger.Printf("install local decision %s: %v", flow.key(), err) + } + d.stats.record(flow.toProto(), rule, matched) +} + +func discoverLocalFlows() ([]*localFlow, error) { + sockets := make(map[string]*procSocket) + for _, spec := range []struct { + path string + protocol string + }{ + {path: "/proc/net/tcp", protocol: "tcp"}, + {path: "/proc/net/tcp6", protocol: "tcp6"}, + {path: "/proc/net/udp", protocol: "udp"}, + {path: "/proc/net/udp6", protocol: "udp6"}, + } { + entries, err := readProcNet(spec.path, spec.protocol) + if err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + for _, entry := range entries { + sockets[entry.Inode] = entry + } + } + + if len(sockets) == 0 { + return nil, nil + } + + metaByInode := resolveProcMetadata(sockets) + flows := make([]*localFlow, 0, len(metaByInode)) + for inode, socket := range sockets { + meta, ok := metaByInode[inode] + if !ok || meta.ProcessPath == "" { + continue + } + flows = append(flows, &localFlow{ + Protocol: socket.Protocol, + SrcIP: socket.SrcIP, + SrcPort: socket.SrcPort, + DstIP: socket.DstIP, + DstPort: socket.DstPort, + UID: socket.UID, + PID: meta.PID, + ProcessPath: meta.ProcessPath, + ProcessCwd: meta.ProcessCwd, + ProcessArgs: meta.ProcessArgs, + Inode: inode, + }) + } + + sort.Slice(flows, func(i, j int) bool { + return flows[i].key() < flows[j].key() + }) + return flows, nil +} + +func readProcNet(path, protocol string) ([]*procSocket, error) { + data, err := os.ReadFile(filepath.Clean(path)) + if err != nil { + return nil, err + } + + lines := strings.Split(string(data), "\n") + result := make([]*procSocket, 0, len(lines)) + for _, line := range lines[1:] { + fields := strings.Fields(strings.TrimSpace(line)) + if len(fields) < 10 { + continue + } + + srcIP, srcPort, err := decodeProcAddr(fields[1]) + if err != nil { + continue + } + dstIP, dstPort, err := decodeProcAddr(fields[2]) + if err != nil || dstPort == 0 || dstIP == "" { + continue + } + state := fields[3] + if strings.EqualFold(state, "0A") { + continue + } + + uid, err := strconv.ParseUint(fields[7], 10, 32) + if err != nil { + continue + } + inode := fields[9] + if inode == "" { + continue + } + + result = append(result, &procSocket{ + Protocol: protocol, + SrcIP: srcIP, + SrcPort: srcPort, + DstIP: dstIP, + DstPort: dstPort, + UID: uint32(uid), + Inode: inode, + }) + } + return result, nil +} + +func resolveProcMetadata(targets map[string]*procSocket) map[string]procMeta { + result := make(map[string]procMeta, len(targets)) + entries, err := os.ReadDir("/proc") + if err != nil { + return result + } + + for _, entry := range entries { + pid, err := strconv.Atoi(entry.Name()) + if err != nil { + continue + } + fdEntries, err := os.ReadDir(filepath.Join("/proc", entry.Name(), "fd")) + if err != nil { + continue + } + + for _, fdEntry := range fdEntries { + link, err := os.Readlink(filepath.Join("/proc", entry.Name(), "fd", fdEntry.Name())) + if err != nil || !strings.HasPrefix(link, "socket:[") { + continue + } + inode := strings.TrimSuffix(strings.TrimPrefix(link, "socket:["), "]") + if _, ok := targets[inode]; !ok { + continue + } + if _, exists := result[inode]; exists { + continue + } + result[inode] = procMeta{ + PID: pid, + ProcessPath: readProcLink(filepath.Join("/proc", entry.Name(), "exe")), + ProcessCwd: readProcLink(filepath.Join("/proc", entry.Name(), "cwd")), + ProcessArgs: readCmdline(filepath.Join("/proc", entry.Name(), "cmdline")), + } + } + } + + return result +} + +func readProcLink(path string) string { + value, err := os.Readlink(filepath.Clean(path)) + if err != nil { + return "" + } + return value +} + +func readCmdline(path string) []string { + data, err := os.ReadFile(filepath.Clean(path)) + if err != nil || len(data) == 0 { + return nil + } + parts := strings.Split(strings.TrimRight(string(data), "\x00"), "\x00") + result := make([]string, 0, len(parts)) + for _, part := range parts { + if trimmed := strings.TrimSpace(part); trimmed != "" { + result = append(result, trimmed) + } + } + return result +} + +func decodeProcAddr(value string) (string, uint32, error) { + parts := strings.Split(value, ":") + if len(parts) != 2 { + return "", 0, fmt.Errorf("invalid proc address %q", value) + } + + portValue, err := strconv.ParseUint(parts[1], 16, 32) + if err != nil { + return "", 0, err + } + + raw, err := hex.DecodeString(parts[0]) + if err != nil { + return "", 0, err + } + switch len(raw) { + case 4: + reverse(raw) + case 16: + for idx := 0; idx < len(raw); idx += 4 { + reverse(raw[idx : idx+4]) + } + default: + return "", 0, fmt.Errorf("unsupported proc address width %d", len(raw)) + } + + ip := net.IP(raw).String() + if ip == "" || ip == "::" || ip == "0.0.0.0" { + return "", 0, nil + } + return ip, uint32(portValue), nil +} + +func reverse(data []byte) { + for left, right := 0, len(data)-1; left < right; left, right = left+1, right-1 { + data[left], data[right] = data[right], data[left] + } +} diff --git a/cmd/opensnitchd-router/rules.go b/cmd/opensnitchd-router/rules.go new file mode 100644 index 0000000..e5b679b --- /dev/null +++ b/cmd/opensnitchd-router/rules.go @@ -0,0 +1,520 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "slices" + "sort" + "strconv" + "strings" + "time" + + pb "github.com/bilalbayram/opensnitch-web/proto" +) + +type ruleEntry struct { + Rule *pb.Rule `json:"rule"` + AddedAt time.Time `json:"added_at"` +} + +func (d *daemon) loadRules() error { + data, err := os.ReadFile(filepath.Clean(d.rulesPath)) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var stored []*ruleEntry + if err := json.Unmarshal(data, &stored); err != nil { + return err + } + + now := time.Now() + d.rulesMu.Lock() + defer d.rulesMu.Unlock() + for _, entry := range stored { + if entry == nil || entry.Rule == nil || strings.TrimSpace(entry.Rule.GetName()) == "" { + continue + } + if ruleExpired(entry, now) { + continue + } + d.rules[entry.Rule.GetName()] = entry + } + return nil +} + +func (d *daemon) snapshotRules() []*pb.Rule { + d.rulesMu.Lock() + defer d.rulesMu.Unlock() + d.pruneExpiredRulesLocked(time.Now()) + + names := make([]string, 0, len(d.rules)) + for name := range d.rules { + names = append(names, name) + } + sort.Slice(names, func(i, j int) bool { + left := d.rules[names[i]] + right := d.rules[names[j]] + switch { + case left.Rule.GetPrecedence() != right.Rule.GetPrecedence(): + return left.Rule.GetPrecedence() + case !left.AddedAt.Equal(right.AddedAt): + return left.AddedAt.Before(right.AddedAt) + default: + return names[i] < names[j] + } + }) + + result := make([]*pb.Rule, 0, len(names)) + for _, name := range names { + result = append(result, cloneRule(d.rules[name].Rule)) + } + return result +} + +func (d *daemon) upsertRule(rule *pb.Rule, addedAt time.Time) error { + if rule == nil || strings.TrimSpace(rule.GetName()) == "" { + return nil + } + if addedAt.IsZero() { + addedAt = time.Now() + } + + d.rulesMu.Lock() + defer d.rulesMu.Unlock() + d.rules[rule.GetName()] = &ruleEntry{ + Rule: cloneRule(rule), + AddedAt: addedAt, + } + d.pruneExpiredRulesLocked(time.Now()) + return d.saveRulesLocked() +} + +func (d *daemon) deleteRule(name string) error { + name = strings.TrimSpace(name) + if name == "" { + return nil + } + d.rulesMu.Lock() + defer d.rulesMu.Unlock() + delete(d.rules, name) + return d.saveRulesLocked() +} + +func (d *daemon) saveRulesLocked() error { + stored := make([]*ruleEntry, 0, len(d.rules)) + for _, entry := range d.rules { + if !ruleShouldPersist(entry.Rule) { + continue + } + stored = append(stored, &ruleEntry{ + Rule: cloneRule(entry.Rule), + AddedAt: entry.AddedAt, + }) + } + sort.Slice(stored, func(i, j int) bool { + return stored[i].Rule.GetName() < stored[j].Rule.GetName() + }) + + data, err := json.MarshalIndent(stored, "", " ") + if err != nil { + return err + } + return os.WriteFile(d.rulesPath, append(data, '\n'), 0600) +} + +func (d *daemon) pruneExpiredRulesLocked(now time.Time) { + dirty := false + for name, entry := range d.rules { + if ruleExpired(entry, now) { + delete(d.rules, name) + dirty = true + } + } + if dirty { + _ = d.saveRulesLocked() + } +} + +func (d *daemon) evaluateLocalFlow(flow *localFlow) (*pb.Rule, bool, error) { + d.rulesMu.Lock() + defer d.rulesMu.Unlock() + d.pruneExpiredRulesLocked(time.Now()) + + names := d.sortedRuleNamesLocked() + for _, name := range names { + entry := d.rules[name] + if !ruleEnabled(entry.Rule) || !matchesLocalRule(entry.Rule, flow) { + continue + } + rule := cloneRule(entry.Rule) + if strings.EqualFold(rule.GetDuration(), "once") { + delete(d.rules, name) + _ = d.saveRulesLocked() + } + return rule, true, nil + } + return nil, false, nil +} + +func (d *daemon) evaluateForwardFlow(flow *forwardFlow) (*pb.Rule, bool) { + d.rulesMu.Lock() + defer d.rulesMu.Unlock() + d.pruneExpiredRulesLocked(time.Now()) + + names := d.sortedRuleNamesLocked() + for _, name := range names { + entry := d.rules[name] + if !ruleEnabled(entry.Rule) || !matchesForwardRule(entry.Rule, flow) { + continue + } + rule := cloneRule(entry.Rule) + if strings.EqualFold(rule.GetDuration(), "once") { + delete(d.rules, name) + _ = d.saveRulesLocked() + } + return rule, true + } + return nil, false +} + +func (d *daemon) sortedRuleNamesLocked() []string { + names := make([]string, 0, len(d.rules)) + for name := range d.rules { + names = append(names, name) + } + sort.Slice(names, func(i, j int) bool { + left := d.rules[names[i]] + right := d.rules[names[j]] + switch { + case left.Rule.GetPrecedence() != right.Rule.GetPrecedence(): + return left.Rule.GetPrecedence() + case !left.AddedAt.Equal(right.AddedAt): + return left.AddedAt.Before(right.AddedAt) + default: + return names[i] < names[j] + } + }) + return names +} + +func (d *daemon) handleNotification(notif *pb.Notification) error { + switch notif.GetType() { + case pb.Action_CHANGE_RULE: + for _, rule := range notif.GetRules() { + if err := d.upsertRule(rule, ruleTimestamp(rule)); err != nil { + return err + } + } + return d.rebuildForwardRules() + case pb.Action_DELETE_RULE: + for _, rule := range notif.GetRules() { + if err := d.deleteRule(rule.GetName()); err != nil { + return err + } + } + return d.rebuildForwardRules() + case pb.Action_ENABLE_RULE, pb.Action_DISABLE_RULE: + enabled := notif.GetType() == pb.Action_ENABLE_RULE + for _, rule := range notif.GetRules() { + if err := d.setRuleEnabled(rule.GetName(), enabled); err != nil { + return err + } + } + return d.rebuildForwardRules() + case pb.Action_ENABLE_INTERCEPTION: + d.setInterceptionEnabled(true) + return nil + case pb.Action_DISABLE_INTERCEPTION: + d.setInterceptionEnabled(false) + return nil + case pb.Action_ENABLE_FIREWALL: + d.setFirewallEnabled(true) + return d.rebuildForwardRules() + case pb.Action_DISABLE_FIREWALL: + d.setFirewallEnabled(false) + return d.disableFirewall() + case pb.Action_RELOAD_FW_RULES: + return d.reloadFirewallState() + case pb.Action_CHANGE_CONFIG: + return d.applyConfigChange(notif.GetData()) + case pb.Action_STOP: + return fmt.Errorf("received stop notification") + default: + return nil + } +} + +func (d *daemon) applyConfigChange(raw string) error { + if strings.TrimSpace(raw) == "" { + return nil + } + + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return err + } + + if value, ok := payload["default_action"].(string); ok && strings.TrimSpace(value) != "" { + d.defaultAction = strings.ToLower(strings.TrimSpace(value)) + } + return nil +} + +func (d *daemon) setRuleEnabled(name string, enabled bool) error { + name = strings.TrimSpace(name) + if name == "" { + return nil + } + + d.rulesMu.Lock() + defer d.rulesMu.Unlock() + + entry := d.rules[name] + if entry == nil || entry.Rule == nil { + return nil + } + entry.Rule.Enabled = enabled + return d.saveRulesLocked() +} + +func (d *daemon) setInterceptionEnabled(enabled bool) { + d.stateMu.Lock() + defer d.stateMu.Unlock() + d.interceptionEnabled = enabled +} + +func (d *daemon) setFirewallEnabled(enabled bool) { + d.stateMu.Lock() + defer d.stateMu.Unlock() + d.firewallEnabled = enabled +} + +func (d *daemon) isInterceptionEnabled() bool { + d.stateMu.RLock() + defer d.stateMu.RUnlock() + return d.interceptionEnabled +} + +func (d *daemon) isFirewallEnabled() bool { + d.stateMu.RLock() + defer d.stateMu.RUnlock() + return d.firewallEnabled +} + +func ruleEnabled(rule *pb.Rule) bool { + return rule != nil && rule.GetEnabled() +} + +func ruleShouldPersist(rule *pb.Rule) bool { + switch strings.ToLower(strings.TrimSpace(rule.GetDuration())) { + case "once", "until restart": + return false + default: + return true + } +} + +func ruleTimestamp(rule *pb.Rule) time.Time { + if rule != nil && rule.GetCreated() > 0 { + return time.Unix(rule.GetCreated(), 0) + } + return time.Now() +} + +func ruleExpired(entry *ruleEntry, now time.Time) bool { + if entry == nil || entry.Rule == nil { + return true + } + + addedAt := entry.AddedAt + if addedAt.IsZero() { + addedAt = time.Now() + } + + switch strings.ToLower(strings.TrimSpace(entry.Rule.GetDuration())) { + case "", "always", "until restart", "once": + return false + case "5m": + return now.After(addedAt.Add(5 * time.Minute)) + case "15m": + return now.After(addedAt.Add(15 * time.Minute)) + case "30m": + return now.After(addedAt.Add(30 * time.Minute)) + case "1h": + return now.After(addedAt.Add(time.Hour)) + default: + return false + } +} + +func cloneRule(rule *pb.Rule) *pb.Rule { + if rule == nil { + return nil + } + data, err := json.Marshal(rule) + if err != nil { + return rule + } + var cloned pb.Rule + if err := json.Unmarshal(data, &cloned); err != nil { + return rule + } + return &cloned +} + +func matchesLocalRule(rule *pb.Rule, flow *localFlow) bool { + return matchRule(rule, func(operand, value string) bool { + switch operand { + case "process.path": + return strings.TrimSpace(flow.ProcessPath) == value + case "dest.ip": + return strings.TrimSpace(flow.DstIP) == value + case "dest.port": + return strconv.Itoa(int(flow.DstPort)) == value + case "protocol": + return strings.EqualFold(flow.Protocol, value) + case "user.id": + return strconv.Itoa(int(flow.UID)) == value + default: + return false + } + }) +} + +func matchesForwardRule(rule *pb.Rule, flow *forwardFlow) bool { + return matchRule(rule, func(operand, value string) bool { + switch operand { + case "process.path": + return value == "device:"+flow.SrcIP + case "dest.ip": + return strings.TrimSpace(flow.DstIP) == value + case "dest.port": + return strconv.Itoa(int(flow.DstPort)) == value + case "protocol": + return strings.EqualFold(flow.Protocol, value) + case "user.id": + return value == "0" + default: + return false + } + }) +} + +func matchRule(rule *pb.Rule, match func(operand, value string) bool) bool { + if rule == nil || rule.GetOperator() == nil { + return false + } + return matchOperator(rule.GetOperator(), match) +} + +func matchOperator(operator *pb.Operator, match func(operand, value string) bool) bool { + if operator == nil { + return false + } + + operatorType := strings.ToLower(strings.TrimSpace(operator.GetType())) + if len(operator.GetList()) > 0 && operatorType == "" { + operatorType = "list" + } + + switch operatorType { + case "", "simple": + return match(strings.TrimSpace(operator.GetOperand()), strings.TrimSpace(operator.GetData())) + case "list": + for _, item := range operator.GetList() { + if !matchOperator(item, match) { + return false + } + } + return len(operator.GetList()) > 0 + default: + return false + } +} + +type forwardRuleSpec struct { + Name string + SourceIP string + DestIP string + Port string + Protocol string + Action string +} + +func compileForwardRule(rule *pb.Rule) *forwardRuleSpec { + if rule == nil || !ruleEnabled(rule) { + return nil + } + + spec := &forwardRuleSpec{ + Name: strings.TrimSpace(rule.GetName()), + Action: strings.ToLower(strings.TrimSpace(rule.GetAction())), + } + + var operands []struct { + operand string + data string + } + collectOperators(rule.GetOperator(), &operands) + for _, item := range operands { + switch item.operand { + case "process.path": + if !strings.HasPrefix(item.data, "device:") { + return nil + } + spec.SourceIP = strings.TrimPrefix(item.data, "device:") + case "dest.ip": + spec.DestIP = item.data + case "dest.port": + spec.Port = item.data + case "protocol": + spec.Protocol = strings.ToLower(item.data) + case "user.id": + default: + return nil + } + } + + if spec.SourceIP == "" { + return nil + } + if !slices.Contains([]string{"allow", "deny", "reject"}, spec.Action) { + return nil + } + return spec +} + +func collectOperators(operator *pb.Operator, out *[]struct { + operand string + data string +}) { + if operator == nil { + return + } + + operatorType := strings.ToLower(strings.TrimSpace(operator.GetType())) + if len(operator.GetList()) > 0 && operatorType == "" { + operatorType = "list" + } + + if operatorType == "list" { + for _, item := range operator.GetList() { + collectOperators(item, out) + } + return + } + + *out = append(*out, struct { + operand string + data string + }{ + operand: strings.TrimSpace(operator.GetOperand()), + data: strings.TrimSpace(operator.GetData()), + }) +} diff --git a/cmd/opensnitchd-router/stats.go b/cmd/opensnitchd-router/stats.go new file mode 100644 index 0000000..88d737e --- /dev/null +++ b/cmd/opensnitchd-router/stats.go @@ -0,0 +1,128 @@ +package main + +import ( + "strconv" + "strings" + "sync" + "time" + + pb "github.com/bilalbayram/opensnitch-web/proto" +) + +type statsCollector struct { + mu sync.Mutex + + daemonStarted time.Time + connections uint64 + accepted uint64 + dropped uint64 + ruleHits uint64 + ruleMisses uint64 + + byProto map[string]uint64 + byAddress map[string]uint64 + byHost map[string]uint64 + byPort map[string]uint64 + byUID map[string]uint64 + byExecutable map[string]uint64 + events []*pb.Event +} + +func newStatsCollector() *statsCollector { + return &statsCollector{ + daemonStarted: time.Now(), + byProto: make(map[string]uint64), + byAddress: make(map[string]uint64), + byHost: make(map[string]uint64), + byPort: make(map[string]uint64), + byUID: make(map[string]uint64), + byExecutable: make(map[string]uint64), + } +} + +func (s *statsCollector) record(conn *pb.Connection, rule *pb.Rule, matched bool) { + if conn == nil { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.connections++ + if matched { + s.ruleHits++ + } else { + s.ruleMisses++ + } + + action := strings.ToLower(strings.TrimSpace(rule.GetAction())) + switch action { + case "deny", "reject": + s.dropped++ + default: + s.accepted++ + } + + proto := strings.ToLower(strings.TrimSpace(conn.GetProtocol())) + if proto != "" { + s.byProto[proto]++ + } + if dstIP := strings.TrimSpace(conn.GetDstIp()); dstIP != "" { + s.byAddress[dstIP]++ + } + if dstHost := strings.TrimSpace(conn.GetDstHost()); dstHost != "" { + s.byHost[dstHost]++ + } + if conn.GetDstPort() > 0 { + s.byPort[strconv.Itoa(int(conn.GetDstPort()))]++ + } + s.byUID[strconv.Itoa(int(conn.GetUserId()))]++ + if process := strings.TrimSpace(conn.GetProcessPath()); process != "" { + s.byExecutable[process]++ + } + + s.events = append(s.events, &pb.Event{ + Time: time.Now().Format("2006-01-02 15:04:05"), + Connection: conn, + Rule: cloneRule(rule), + Unixnano: time.Now().UnixNano(), + }) + if len(s.events) > 256 { + s.events = s.events[len(s.events)-256:] + } +} + +func (s *statsCollector) snapshot(ruleCount int) *pb.Statistics { + s.mu.Lock() + defer s.mu.Unlock() + + events := make([]*pb.Event, len(s.events)) + copy(events, s.events) + s.events = s.events[:0] + + return &pb.Statistics{ + DaemonVersion: daemonVersion(), + Rules: uint64(ruleCount), + Uptime: uint64(time.Since(s.daemonStarted).Seconds()), + Connections: s.connections, + Accepted: s.accepted, + Dropped: s.dropped, + RuleHits: s.ruleHits, + RuleMisses: s.ruleMisses, + ByProto: cloneCounterMap(s.byProto), + ByAddress: cloneCounterMap(s.byAddress), + ByHost: cloneCounterMap(s.byHost), + ByPort: cloneCounterMap(s.byPort), + ByUid: cloneCounterMap(s.byUID), + ByExecutable: cloneCounterMap(s.byExecutable), + Events: events, + } +} + +func cloneCounterMap(src map[string]uint64) map[string]uint64 { + out := make(map[string]uint64, len(src)) + for key, value := range src { + out[key] = value + } + return out +} diff --git a/config.yaml.example b/config.yaml.example index e5a297b..60b7d7a 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -1,6 +1,7 @@ server: http_addr: ":8080" grpc_addr: "0.0.0.0:50051" + grpc_public_addr: "" grpc_unix: "/tmp/osui.sock" database: diff --git a/internal/api/handlers_dns.go b/internal/api/handlers_dns.go index d93d80b..fd3472b 100644 --- a/internal/api/handlers_dns.go +++ b/internal/api/handlers_dns.go @@ -157,6 +157,14 @@ func (a *API) handleCreateDNSServerRules(w http.ResponseWriter, r *http.Request) }, }, } + if err := a.validateRouterManagedRuleTarget(req.Node, protoRule); err != nil { + status := http.StatusInternalServerError + if ruleutil.IsRouterManagedRuleError(err) { + status = http.StatusConflict + } + writeJSON(w, status, map[string]string{"error": err.Error()}) + return + } dbRule, err := ruleutil.ProtoToDBRule(req.Node, now, protoRule) if err != nil { @@ -194,6 +202,14 @@ func (a *API) handleCreateDNSServerRules(w http.ResponseWriter, r *http.Request) Data: "53", }, } + if err := a.validateRouterManagedRuleTarget(req.Node, denyRule); err != nil { + status := http.StatusInternalServerError + if ruleutil.IsRouterManagedRuleError(err) { + status = http.StatusConflict + } + writeJSON(w, status, map[string]string{"error": err.Error()}) + return + } dbDeny, err := ruleutil.ProtoToDBRule(req.Node, now, denyRule) if err != nil { @@ -351,6 +367,14 @@ func (a *API) handleSetDNSPolicy(w http.ResponseWriter, r *http.Request) { } protoRules := dnspolicy.BuildRules(policy) + if err := a.validateRouterManagedRuleSet(req.Node, protoRules); err != nil { + status := http.StatusInternalServerError + if ruleutil.IsRouterManagedRuleError(err) { + status = http.StatusConflict + } + writeJSON(w, status, map[string]string{"error": err.Error()}) + return + } now := time.Now() dbRules := make([]*db.DBRule, 0, len(protoRules)) diff --git a/internal/api/handlers_nodes.go b/internal/api/handlers_nodes.go index 3be2992..584050d 100644 --- a/internal/api/handlers_nodes.go +++ b/internal/api/handlers_nodes.go @@ -1,6 +1,7 @@ package api import ( + "database/sql" "encoding/json" "net/http" "time" @@ -46,6 +47,18 @@ func (a *API) handleGetNodes(w http.ResponseWriter, r *http.Request) { // Enrich with live status from node manager liveNodes := a.nodes.GetAllNodes() + routers, err := a.db.GetRouters() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + linkedRouters := make(map[string]db.Router, len(routers)) + for _, router := range routers { + if router.DaemonMode == db.RouterDaemonModeRouterDaemon && router.LinkedNodeAddr != "" { + linkedRouters[router.LinkedNodeAddr] = router + } + } + type enrichedNode struct { Addr string `json:"addr"` Hostname string `json:"hostname"` @@ -59,6 +72,8 @@ func (a *API) handleGetNodes(w http.ResponseWriter, r *http.Request) { Online bool `json:"online"` Mode string `json:"mode"` SourceType string `json:"source_type"` + RouterManaged bool `json:"router_managed"` + LinkedRouterAddr string `json:"linked_router_addr"` Tags []string `json:"tags"` TemplateSyncPending bool `json:"template_sync_pending"` TemplateSyncError string `json:"template_sync_error"` @@ -87,6 +102,7 @@ func (a *API) handleGetNodes(w http.ResponseWriter, r *http.Request) { tags = []string{} } syncState := syncStates[n.Addr] + linkedRouter := linkedRouters[n.Addr] result[i] = enrichedNode{ Addr: n.Addr, Hostname: n.Hostname, @@ -100,6 +116,8 @@ func (a *API) handleGetNodes(w http.ResponseWriter, r *http.Request) { Online: online, Mode: n.Mode, SourceType: sourceType, + RouterManaged: linkedRouter.Addr != "", + LinkedRouterAddr: linkedRouter.Addr, Tags: tags, TemplateSyncPending: syncState.Pending, TemplateSyncError: syncState.Error, @@ -134,6 +152,16 @@ func (a *API) handleGetNode(w http.ResponseWriter, r *http.Request) { if sourceType == "" { sourceType = "opensnitch" } + linkedRouter, err := a.db.GetRouterByLinkedNodeAddr(addr) + if err != nil && err != sql.ErrNoRows { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + routerManaged := linkedRouter != nil && linkedRouter.DaemonMode == db.RouterDaemonModeRouterDaemon + linkedRouterAddr := "" + if linkedRouter != nil { + linkedRouterAddr = linkedRouter.Addr + } writeJSON(w, http.StatusOK, map[string]interface{}{ "addr": node.Addr, "hostname": node.Hostname, @@ -147,6 +175,8 @@ func (a *API) handleGetNode(w http.ResponseWriter, r *http.Request) { "last_connection": node.LastConn, "mode": node.Mode, "source_type": sourceType, + "router_managed": routerManaged, + "linked_router_addr": linkedRouterAddr, "tags": tags, "template_sync_pending": syncState.Pending, "template_sync_error": syncState.Error, @@ -176,6 +206,9 @@ func (a *API) handleReplaceNodeTags(w http.ResponseWriter, r *http.Request) { if a.templateSync != nil { if err := a.templateSync.ReconcileNode(addr); err != nil { + if writeRouterManagedSyncError(w, err) { + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } diff --git a/internal/api/handlers_prompt.go b/internal/api/handlers_prompt.go index 5da29ba..f2e8564 100644 --- a/internal/api/handlers_prompt.go +++ b/internal/api/handlers_prompt.go @@ -21,42 +21,49 @@ func (a *API) handleGetPendingPrompts(w http.ResponseWriter, r *http.Request) { pending := a.prompter.GetPending() type promptResponse struct { - ID string `json:"id"` - NodeAddr string `json:"node_addr"` - CreatedAt string `json:"created_at"` - Process string `json:"process"` - DstHost string `json:"dst_host"` - DstIP string `json:"dst_ip"` - DstPort uint32 `json:"dst_port"` - Protocol string `json:"protocol"` - SrcIP string `json:"src_ip"` - SrcPort uint32 `json:"src_port"` - UID uint32 `json:"uid"` - PID uint32 `json:"pid"` - Args []string `json:"args"` - Cwd string `json:"cwd"` - Checksums interface{} `json:"checksums"` + ID string `json:"id"` + NodeAddr string `json:"node_addr"` + CreatedAt string `json:"created_at"` + Process string `json:"process"` + DstHost string `json:"dst_host"` + DstIP string `json:"dst_ip"` + DstPort uint32 `json:"dst_port"` + Protocol string `json:"protocol"` + SrcIP string `json:"src_ip"` + SrcPort uint32 `json:"src_port"` + UID uint32 `json:"uid"` + PID uint32 `json:"pid"` + Args []string `json:"args"` + Cwd string `json:"cwd"` + Checksums interface{} `json:"checksums"` + RouterManaged bool `json:"router_managed"` } result := make([]promptResponse, 0, len(pending)) for _, p := range pending { conn := p.Connection + routerManaged, err := a.isRouterManagedNode(p.NodeAddr) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } result = append(result, promptResponse{ - ID: p.ID, - NodeAddr: p.NodeAddr, - CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"), - Process: conn.GetProcessPath(), - DstHost: conn.GetDstHost(), - DstIP: conn.GetDstIp(), - DstPort: conn.GetDstPort(), - Protocol: conn.GetProtocol(), - SrcIP: conn.GetSrcIp(), - SrcPort: conn.GetSrcPort(), - UID: conn.GetUserId(), - PID: conn.GetProcessId(), - Args: conn.GetProcessArgs(), - Cwd: conn.GetProcessCwd(), - Checksums: conn.GetProcessChecksums(), + ID: p.ID, + NodeAddr: p.NodeAddr, + CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"), + Process: conn.GetProcessPath(), + DstHost: conn.GetDstHost(), + DstIP: conn.GetDstIp(), + DstPort: conn.GetDstPort(), + Protocol: conn.GetProtocol(), + SrcIP: conn.GetSrcIp(), + SrcPort: conn.GetSrcPort(), + UID: conn.GetUserId(), + PID: conn.GetProcessId(), + Args: conn.GetProcessArgs(), + Cwd: conn.GetProcessCwd(), + Checksums: conn.GetProcessChecksums(), + RouterManaged: routerManaged, }) } @@ -65,6 +72,11 @@ func (a *API) handleGetPendingPrompts(w http.ResponseWriter, r *http.Request) { func (a *API) handlePromptReply(w http.ResponseWriter, r *http.Request) { promptID := chi.URLParam(r, "id") + pending := a.prompter.GetPendingPrompt(promptID) + if pending == nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "prompt not found or expired"}) + return + } var req promptReplyRequest if err := readJSON(r, &req); err != nil { @@ -89,6 +101,11 @@ func (a *API) handlePromptReply(w http.ResponseWriter, r *http.Request) { }, } + if err := a.validateRouterManagedRuleTarget(pending.NodeAddr, rule); err != nil { + writeJSON(w, http.StatusConflict, map[string]string{"error": err.Error()}) + return + } + if err := a.prompter.Reply(promptID, rule); err != nil { writeJSON(w, http.StatusNotFound, map[string]string{"error": err.Error()}) return diff --git a/internal/api/handlers_routers.go b/internal/api/handlers_routers.go index cb41cdc..9d829ea 100644 --- a/internal/api/handlers_routers.go +++ b/internal/api/handlers_routers.go @@ -15,6 +15,26 @@ import ( "github.com/bilalbayram/opensnitch-web/internal/ws" ) +type routerSSHRequest struct { + SSHPass string `json:"ssh_pass"` + SSHKey string `json:"ssh_key"` +} + +type routerUpgradeRequest struct { + SSHPass string `json:"ssh_pass"` + SSHKey string `json:"ssh_key"` + NodeName string `json:"node_name"` + DefaultAction string `json:"default_action"` + PollIntervalMS int `json:"poll_interval_ms"` + FirewallBackend string `json:"firewall_backend"` +} + +type routerDowngradeRequest struct { + SSHPass string `json:"ssh_pass"` + SSHKey string `json:"ssh_key"` + ServerURL string `json:"server_url"` +} + func (a *API) handleConnectRouter(w http.ResponseWriter, r *http.Request) { var req router.ConnectRequest if err := readJSON(r, &req); err != nil { @@ -98,19 +118,46 @@ func (a *API) handleGetRouters(w http.ResponseWriter, r *http.Request) { } // Enrich with node online status + type linkedNodeSummary struct { + Addr string `json:"addr"` + Online bool `json:"online"` + Mode string `json:"mode"` + DaemonVersion string `json:"daemon_version"` + DaemonRules int64 `json:"daemon_rules"` + Cons int64 `json:"cons"` + ConsDropped int64 `json:"cons_dropped"` + LastConnection string `json:"last_connection"` + } type enrichedRouter struct { db.Router - Online bool `json:"online"` + Online bool `json:"online"` + LinkedNode *linkedNodeSummary `json:"linked_node"` } result := make([]enrichedRouter, len(routers)) for i, rt := range routers { - node, err := a.db.GetNode(rt.Addr) + nodeAddr := rt.Addr + if rt.DaemonMode == db.RouterDaemonModeRouterDaemon && strings.TrimSpace(rt.LinkedNodeAddr) != "" { + nodeAddr = rt.LinkedNodeAddr + } + + node, err := a.db.GetNode(nodeAddr) online := false + var linkedNode *linkedNodeSummary if err == nil && node != nil { online = routerOnlineFromLastConn(node.LastConn) + linkedNode = &linkedNodeSummary{ + Addr: nodeAddr, + Online: online, + Mode: node.Mode, + DaemonVersion: node.DaemonVersion, + DaemonRules: node.DaemonRules, + Cons: node.Cons, + ConsDropped: node.ConsDropped, + LastConnection: node.LastConn, + } } - result[i] = enrichedRouter{Router: rt, Online: online} + result[i] = enrichedRouter{Router: rt, Online: online, LinkedNode: linkedNode} } writeJSON(w, http.StatusOK, result) @@ -215,3 +262,165 @@ func (a *API) handleSuggestServerURL(w http.ResponseWriter, r *http.Request) { } writeJSON(w, http.StatusOK, resp) } + +func (a *API) handleRouterCapabilities(w http.ResponseWriter, r *http.Request) { + addr := routerAddrParam(r) + + var req routerSSHRequest + if err := readJSON(r, &req); err != nil || (req.SSHPass == "" && req.SSHKey == "") { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "ssh_pass or ssh_key is required"}) + return + } + + rt, err := a.db.GetRouterByAddr(addr) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "router not found"}) + return + } + + result, err := a.routerProv.CheckCapabilities(r.Context(), rt.Addr, rt.SSHPort, rt.SSHUser, req.SSHPass, req.SSHKey) + if err != nil { + var capabilities any + var steps any + if result != nil { + capabilities = result.Capabilities + steps = result.Steps + } + writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": err.Error(), + "capabilities": capabilities, + "steps": steps, + }) + return + } + + writeJSON(w, http.StatusOK, result) +} + +func (a *API) handleUpgradeRouter(w http.ResponseWriter, r *http.Request) { + addr := routerAddrParam(r) + + var req routerUpgradeRequest + if err := readJSON(r, &req); err != nil || (req.SSHPass == "" && req.SSHKey == "") { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "ssh_pass or ssh_key is required"}) + return + } + + rt, err := a.db.GetRouterByAddr(addr) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "router not found"}) + return + } + if rt.DaemonMode == db.RouterDaemonModeRouterDaemon { + writeJSON(w, http.StatusConflict, map[string]string{"error": "router is already running router-daemon"}) + return + } + + result, err := a.routerProv.ProvisionDaemon(r.Context(), router.DaemonRequest{ + Addr: rt.Addr, + SSHPort: rt.SSHPort, + SSHUser: rt.SSHUser, + SSHPass: req.SSHPass, + SSHKey: req.SSHKey, + NodeName: req.NodeName, + DefaultAction: req.DefaultAction, + PollIntervalMS: req.PollIntervalMS, + FirewallBackend: req.FirewallBackend, + }) + if err != nil { + var capabilities any + var steps any + if result != nil { + capabilities = result.Capabilities + steps = result.Steps + } + writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": err.Error(), + "capabilities": capabilities, + "steps": steps, + }) + return + } + + writeJSON(w, http.StatusOK, result) +} + +func (a *API) handleDowngradeRouter(w http.ResponseWriter, r *http.Request) { + addr := routerAddrParam(r) + + var req routerDowngradeRequest + if err := readJSON(r, &req); err != nil || (req.SSHPass == "" && req.SSHKey == "") { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "ssh_pass or ssh_key is required"}) + return + } + + rt, err := a.db.GetRouterByAddr(addr) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "router not found"}) + return + } + if rt.DaemonMode != db.RouterDaemonModeRouterDaemon { + writeJSON(w, http.StatusConflict, map[string]string{"error": "router is not running router-daemon"}) + return + } + + deprovisionSteps, err := a.routerProv.DeprovisionDaemon(r.Context(), rt.Addr, rt.SSHPort, rt.SSHUser, req.SSHPass, req.SSHKey) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": err.Error(), + "steps": deprovisionSteps, + }) + return + } + + serverURL := strings.TrimSpace(req.ServerURL) + if serverURL == "" { + if lanURL, _ := router.ResolveServerURL(rt.Addr, a.cfg.Server.HTTPAddr); lanURL != "" { + serverURL = lanURL + } else { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + serverURL = fmt.Sprintf("%s://%s", scheme, r.Host) + } + } + + result, provisionErr := a.routerProv.Provision(r.Context(), router.ConnectRequest{ + Addr: rt.Addr, + SSHPort: rt.SSHPort, + SSHUser: rt.SSHUser, + SSHPass: req.SSHPass, + SSHKey: req.SSHKey, + Name: rt.Name, + LANSubnet: rt.LANSubnet, + ServerURL: serverURL, + }) + if provisionErr != nil { + steps := append([]router.ProvisionStep{}, deprovisionSteps...) + if result != nil { + steps = append(steps, result.Steps...) + } + writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": provisionErr.Error(), + "steps": steps, + }) + return + } + + if err := a.db.UpdateRouterRuntime(rt.Addr, db.RouterDaemonModeConntrackAgent, ""); err != nil { + log.Printf("[router] Failed to update router runtime for %s: %v", rt.Addr, err) + } + _ = a.db.UpsertRouterNode(rt.Addr, rt.Name, "conntrack-agent", db.NodeStatusOnline, time.Now().Format("2006-01-02 15:04:05")) + + result.Steps = append(deprovisionSteps, result.Steps...) + writeJSON(w, http.StatusOK, result) +} + +func routerAddrParam(r *http.Request) string { + addr := chi.URLParam(r, "addr") + if decoded, err := url.QueryUnescape(addr); err == nil { + return decoded + } + return addr +} diff --git a/internal/api/handlers_routers_test.go b/internal/api/handlers_routers_test.go index ace81f5..35e0b4c 100644 --- a/internal/api/handlers_routers_test.go +++ b/internal/api/handlers_routers_test.go @@ -29,6 +29,18 @@ func (s *stubRouterProvisioner) Deprovision(ctx context.Context, addr string, ss return s.steps, s.err } +func (s *stubRouterProvisioner) CheckCapabilities(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) (*routerpkg.CapabilityCheckResult, error) { + return &routerpkg.CapabilityCheckResult{}, s.err +} + +func (s *stubRouterProvisioner) ProvisionDaemon(ctx context.Context, req routerpkg.DaemonRequest) (*routerpkg.ProvisionResult, error) { + return &routerpkg.ProvisionResult{Steps: s.steps}, s.err +} + +func (s *stubRouterProvisioner) DeprovisionDaemon(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) ([]routerpkg.ProvisionStep, error) { + return s.steps, s.err +} + func seedRouterRecord(t *testing.T, env *apiTestEnv, addr string, lastConn time.Time) { t.Helper() diff --git a/internal/api/handlers_rules.go b/internal/api/handlers_rules.go index fafe921..f1ced25 100644 --- a/internal/api/handlers_rules.go +++ b/internal/api/handlers_rules.go @@ -129,6 +129,10 @@ func (a *API) handleCreateRule(w http.ResponseWriter, r *http.Request) { } protoRule := buildSimpleRule(req) + if err := a.validateRouterManagedRuleTarget(req.Node, protoRule); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) + return + } dbRule, err := ruleutil.ProtoToDBRule(req.Node, time.Now(), protoRule) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) @@ -192,6 +196,10 @@ func (a *API) handleUpdateRule(w http.ResponseWriter, r *http.Request) { } protoRule := buildSimpleRule(req) + if err := a.validateRouterManagedRuleTarget(req.Node, protoRule); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) + return + } dbRule, err := ruleutil.ProtoToDBRule(req.Node, time.Now(), protoRule) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) @@ -285,6 +293,10 @@ func (a *API) handleGenerateRulesApply(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } + if err := a.validateRouterManagedRuleTarget(filters.Node, protoRule); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) + return + } dbRule, err := ruleutil.ProtoToDBRule(filters.Node, time.Now(), protoRule) if err != nil { diff --git a/internal/api/handlers_templates.go b/internal/api/handlers_templates.go index a40c8e9..f12b241 100644 --- a/internal/api/handlers_templates.go +++ b/internal/api/handlers_templates.go @@ -262,8 +262,17 @@ func (a *API) handleCreateTemplateRule(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request"}) return } + candidate := buildTemplateRuleRecord(templateID, 0, req) + if err := a.validateTemplateRuleForManagedTargets(templateID, candidate); err != nil { + status := http.StatusInternalServerError + if ruleutil.IsRouterManagedRuleError(err) { + status = http.StatusConflict + } + writeJSON(w, status, map[string]string{"error": err.Error()}) + return + } - templateRule, err := a.db.CreateTemplateRule(buildTemplateRuleRecord(templateID, 0, req)) + templateRule, err := a.db.CreateTemplateRule(candidate) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return @@ -271,6 +280,9 @@ func (a *API) handleCreateTemplateRule(w http.ResponseWriter, r *http.Request) { if a.templateSync != nil { if err := a.templateSync.ReconcileTemplate(templateID); err != nil { + if writeRouterManagedSyncError(w, err) { + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } @@ -321,6 +333,14 @@ func (a *API) handleUpdateTemplateRule(w http.ResponseWriter, r *http.Request) { } templateRule := buildTemplateRuleRecord(templateID, ruleID, req) + if err := a.validateTemplateRuleForManagedTargets(templateID, templateRule); err != nil { + status := http.StatusInternalServerError + if ruleutil.IsRouterManagedRuleError(err) { + status = http.StatusConflict + } + writeJSON(w, status, map[string]string{"error": err.Error()}) + return + } if err := a.db.UpdateTemplateRule(templateRule); err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return @@ -328,6 +348,9 @@ func (a *API) handleUpdateTemplateRule(w http.ResponseWriter, r *http.Request) { if a.templateSync != nil { if err := a.templateSync.ReconcileTemplate(templateID); err != nil { + if writeRouterManagedSyncError(w, err) { + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } @@ -375,6 +398,9 @@ func (a *API) handleDeleteTemplateRule(w http.ResponseWriter, r *http.Request) { if a.templateSync != nil { if err := a.templateSync.ReconcileTemplate(templateID); err != nil { + if writeRouterManagedSyncError(w, err) { + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } @@ -407,6 +433,14 @@ func (a *API) handleCreateTemplateAttachment(w http.ResponseWriter, r *http.Requ writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } + if err := a.validateTemplateAttachmentForManagedTargets(templateID, req.TargetType, req.TargetRef); err != nil { + status := http.StatusInternalServerError + if ruleutil.IsRouterManagedRuleError(err) { + status = http.StatusConflict + } + writeJSON(w, status, map[string]string{"error": err.Error()}) + return + } attachment, err := a.db.CreateTemplateAttachment(&db.TemplateAttachment{ TemplateID: templateID, @@ -421,6 +455,9 @@ func (a *API) handleCreateTemplateAttachment(w http.ResponseWriter, r *http.Requ if a.templateSync != nil { if err := a.templateSync.ReconcileTemplate(templateID); err != nil { + if writeRouterManagedSyncError(w, err) { + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } @@ -467,6 +504,14 @@ func (a *API) handleUpdateTemplateAttachment(w http.ResponseWriter, r *http.Requ writeJSON(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } + if err := a.validateTemplateAttachmentForManagedTargets(templateID, req.TargetType, req.TargetRef); err != nil { + status := http.StatusInternalServerError + if ruleutil.IsRouterManagedRuleError(err) { + status = http.StatusConflict + } + writeJSON(w, status, map[string]string{"error": err.Error()}) + return + } beforeNodes := []string{} if a.templateSync != nil { @@ -487,6 +532,9 @@ func (a *API) handleUpdateTemplateAttachment(w http.ResponseWriter, r *http.Requ if a.templateSync != nil { afterNodes, _ := a.templateSync.AffectedNodesForTemplates([]int64{templateID}) if err := a.templateSync.ReconcileNodes(append(beforeNodes, afterNodes...)); err != nil { + if writeRouterManagedSyncError(w, err) { + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } @@ -534,6 +582,9 @@ func (a *API) handleDeleteTemplateAttachment(w http.ResponseWriter, r *http.Requ if a.templateSync != nil { if err := a.templateSync.ReconcileNodes(affectedNodes); err != nil { + if writeRouterManagedSyncError(w, err) { + return + } writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } diff --git a/internal/api/router.go b/internal/api/router.go index 058528f..0f22bb4 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -25,6 +25,9 @@ import ( type routerProvisioner interface { Provision(ctx context.Context, req router.ConnectRequest) (*router.ProvisionResult, error) Deprovision(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) ([]router.ProvisionStep, error) + CheckCapabilities(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) (*router.CapabilityCheckResult, error) + ProvisionDaemon(ctx context.Context, req router.DaemonRequest) (*router.ProvisionResult, error) + DeprovisionDaemon(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) ([]router.ProvisionStep, error) } type API struct { @@ -50,7 +53,7 @@ func NewRouter(cfg *config.Config, database *db.Database, nodes *nodemanager.Man templateSync: templateSync, fetcher: blocklist.NewFetcher(), geoResolver: geo, - routerProv: router.NewProvisioner(database), + routerProv: router.NewProvisioner(database).WithRuntime(nodes, cfg), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, }, @@ -166,6 +169,9 @@ func NewRouter(cfg *config.Config, database *db.Database, nodes *nodemanager.Man r.Post("/api/v1/routers/suggest-url", api.handleSuggestServerURL) r.Post("/api/v1/routers/connect", api.handleConnectRouter) r.Get("/api/v1/routers", api.handleGetRouters) + r.Post("/api/v1/routers/{addr}/capabilities", api.handleRouterCapabilities) + r.Post("/api/v1/routers/{addr}/upgrade", api.handleUpgradeRouter) + r.Post("/api/v1/routers/{addr}/downgrade", api.handleDowngradeRouter) r.Post("/api/v1/routers/{addr}/disconnect", api.handleDisconnectRouter) // Prompts diff --git a/internal/api/router_managed.go b/internal/api/router_managed.go new file mode 100644 index 0000000..c98cdbb --- /dev/null +++ b/internal/api/router_managed.go @@ -0,0 +1,62 @@ +package api + +import ( + "database/sql" + "strings" + + "github.com/bilalbayram/opensnitch-web/internal/db" + ruleutil "github.com/bilalbayram/opensnitch-web/internal/rules" + pb "github.com/bilalbayram/opensnitch-web/proto" +) + +func (a *API) isRouterManagedNode(addr string) (bool, error) { + addr = strings.TrimSpace(addr) + if addr == "" { + return false, nil + } + + router, err := a.db.GetRouterByLinkedNodeAddr(addr) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + + return router.DaemonMode == db.RouterDaemonModeRouterDaemon, nil +} + +func (a *API) hasManagedRouterTargets(node string) (bool, error) { + node = strings.TrimSpace(node) + if node != "" { + return a.isRouterManagedNode(node) + } + + routers, err := a.db.GetRouters() + if err != nil { + return false, err + } + for _, router := range routers { + if router.DaemonMode == db.RouterDaemonModeRouterDaemon && strings.TrimSpace(router.LinkedNodeAddr) != "" { + return true, nil + } + } + return false, nil +} + +func (a *API) validateRouterManagedRuleTarget(node string, rule *pb.Rule) error { + needsValidation, err := a.hasManagedRouterTargets(node) + if err != nil || !needsValidation { + return err + } + return ruleutil.ValidateRouterManagedRule(rule) +} + +func (a *API) validateRouterManagedRuleSet(node string, rules []*pb.Rule) error { + for _, rule := range rules { + if err := a.validateRouterManagedRuleTarget(node, rule); err != nil { + return err + } + } + return nil +} diff --git a/internal/api/router_managed_test.go b/internal/api/router_managed_test.go new file mode 100644 index 0000000..32e6ee9 --- /dev/null +++ b/internal/api/router_managed_test.go @@ -0,0 +1,225 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi/v5" + + "github.com/bilalbayram/opensnitch-web/internal/db" + "github.com/bilalbayram/opensnitch-web/internal/prompter" + pb "github.com/bilalbayram/opensnitch-web/proto" +) + +func performJSONRequestWithPromptID(t *testing.T, handler http.HandlerFunc, method, target, promptID string, payload any) *httptest.ResponseRecorder { + t.Helper() + + var body bytes.Buffer + if payload != nil { + if err := json.NewEncoder(&body).Encode(payload); err != nil { + t.Fatalf("encode request body: %v", err) + } + } + + req := httptest.NewRequest(method, target, &body) + req.Header.Set("Content-Type", "application/json") + + routeCtx := chi.NewRouteContext() + routeCtx.URLParams.Add("id", promptID) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, routeCtx)) + + rec := httptest.NewRecorder() + handler(rec, req) + return rec +} + +func seedManagedRouterNode(t *testing.T, env *apiTestEnv, addr string) { + t.Helper() + + env.seedNode(t, addr, true) + if err := env.database.InsertRouter(&db.Router{ + Name: addr, + Addr: addr, + SSHPort: 22, + SSHUser: "root", + APIKey: "router-key-" + addr, + LANSubnet: "192.168.1.0/24", + DaemonMode: db.RouterDaemonModeRouterDaemon, + LinkedNodeAddr: addr, + Status: "active", + }); err != nil { + t.Fatalf("insert router: %v", err) + } +} + +func TestHandleGetPendingPromptsMarksRouterManaged(t *testing.T) { + env := newAPITestEnv(t) + seedManagedRouterNode(t, env, "router-a") + + promptCh := make(chan *prompter.PendingPrompt, 1) + env.prompter.OnNewPrompt = func(prompt *prompter.PendingPrompt) { + promptCh <- prompt + } + + go func() { + _, _ = env.prompter.AskUser("router-a", &pb.Connection{ + ProcessPath: "/usr/bin/curl", + DstHost: "example.com", + DstIp: "93.184.216.34", + DstPort: 443, + Protocol: "tcp", + }) + }() + + prompt := <-promptCh + rec := performJSONRequest(t, env.api.handleGetPendingPrompts, http.MethodGet, "/api/v1/prompts/pending", nil) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + response := decodeJSON[[]struct { + ID string `json:"id"` + RouterManaged bool `json:"router_managed"` + }](t, rec) + if len(response) != 1 { + t.Fatalf("expected 1 prompt, got %d", len(response)) + } + if response[0].ID != prompt.ID || !response[0].RouterManaged { + t.Fatalf("unexpected prompt payload: %+v", response[0]) + } + + if err := env.prompter.Reply(prompt.ID, &pb.Rule{ + Name: "cleanup", + Action: "allow", + Duration: "once", + Enabled: true, + Operator: &pb.Operator{Type: "simple", Operand: "process.path", Data: "/usr/bin/curl"}, + }); err != nil { + t.Fatalf("cleanup reply: %v", err) + } +} + +func TestHandlePromptReplyRejectsUnsupportedRouterManagedOperand(t *testing.T) { + env := newAPITestEnv(t) + seedManagedRouterNode(t, env, "router-a") + + promptCh := make(chan *prompter.PendingPrompt, 1) + env.prompter.OnNewPrompt = func(prompt *prompter.PendingPrompt) { + promptCh <- prompt + } + + go func() { + _, _ = env.prompter.AskUser("router-a", &pb.Connection{ + ProcessPath: "/usr/bin/curl", + DstHost: "example.com", + DstIp: "93.184.216.34", + DstPort: 443, + Protocol: "tcp", + }) + }() + + prompt := <-promptCh + rec := performJSONRequestWithPromptID(t, env.api.handlePromptReply, http.MethodPost, "/api/v1/prompts/"+prompt.ID+"/reply", prompt.ID, promptReplyRequest{ + Action: "deny", + Duration: "always", + Name: "router-rule", + Operand: "dest.host", + Data: "example.com", + Operator: "simple", + }) + if rec.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", rec.Code, rec.Body.String()) + } + + if err := env.prompter.Reply(prompt.ID, &pb.Rule{ + Name: "cleanup", + Action: "allow", + Duration: "once", + Enabled: true, + Operator: &pb.Operator{Type: "simple", Operand: "process.path", Data: "/usr/bin/curl"}, + }); err != nil { + t.Fatalf("cleanup reply: %v", err) + } +} + +func TestHandleSetDNSPolicyRejectsUnsupportedRouterManagedHostRules(t *testing.T) { + env := newAPITestEnv(t) + seedManagedRouterNode(t, env, "router-a") + + rec := performJSONRequest(t, env.api.handleSetDNSPolicy, http.MethodPost, "/api/v1/dns/policy", dnsPolicyRequest{ + Node: "router-a", + Enabled: true, + AllowedResolvers: []string{"1.1.1.1"}, + BlockDoHHostnames: true, + }) + if rec.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestHandleCreateRuleRejectsUnsupportedRouterManagedOperand(t *testing.T) { + env := newAPITestEnv(t) + seedManagedRouterNode(t, env, "router-a") + + rec := performJSONRequest(t, env.api.handleCreateRule, http.MethodPost, "/api/v1/rules", ruleRequest{ + Name: "router-host-rule", + Node: "router-a", + Enabled: true, + Action: "deny", + Duration: "always", + OperatorType: "simple", + OperatorOperand: "dest.host", + OperatorData: "example.com", + }) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestHandlePromptReplyAllowsSupportedRouterManagedOperand(t *testing.T) { + env := newAPITestEnv(t) + seedManagedRouterNode(t, env, "router-a") + + promptCh := make(chan *prompter.PendingPrompt, 1) + resultCh := make(chan *prompter.AskResult, 1) + env.prompter.OnNewPrompt = func(prompt *prompter.PendingPrompt) { + promptCh <- prompt + } + + go func() { + result, _ := env.prompter.AskUser("router-a", &pb.Connection{ + ProcessPath: "/usr/bin/curl", + DstIp: "93.184.216.34", + DstPort: 443, + Protocol: "tcp", + }) + resultCh <- result + }() + + prompt := <-promptCh + rec := performJSONRequestWithPromptID(t, env.api.handlePromptReply, http.MethodPost, "/api/v1/prompts/"+prompt.ID+"/reply", prompt.ID, promptReplyRequest{ + Action: "deny", + Duration: "always", + Name: "router-ip-rule", + Operand: "dest.ip", + Data: "93.184.216.34", + Operator: "simple", + }) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + select { + case result := <-resultCh: + if result == nil || result.Rule == nil || result.Rule.GetOperator().GetOperand() != "dest.ip" { + t.Fatalf("unexpected prompt result: %+v", result) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for prompt result") + } +} diff --git a/internal/api/template_router_managed.go b/internal/api/template_router_managed.go new file mode 100644 index 0000000..fcf7c10 --- /dev/null +++ b/internal/api/template_router_managed.go @@ -0,0 +1,133 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/bilalbayram/opensnitch-web/internal/db" + ruleutil "github.com/bilalbayram/opensnitch-web/internal/rules" +) + +func (a *API) validateTemplateRuleForManagedTargets(templateID int64, candidate *db.TemplateRule) error { + managedNodes, err := a.managedNodesForTemplate(templateID) + if err != nil || len(managedNodes) == 0 { + return err + } + + protoRule, err := ruleutil.TemplateRuleToProto(candidate) + if err != nil { + return err + } + return ruleutil.ValidateRouterManagedRule(protoRule) +} + +func (a *API) validateTemplateAttachmentForManagedTargets(templateID int64, targetType, targetRef string) error { + targetNodes, err := a.resolveAttachmentTargetNodes(targetType, targetRef) + if err != nil { + return err + } + + managedNodes, err := a.filterManagedNodes(targetNodes) + if err != nil || len(managedNodes) == 0 { + return err + } + + rules, err := a.db.GetTemplateRules(templateID) + if err != nil { + return err + } + for i := range rules { + protoRule, err := ruleutil.TemplateRuleToProto(&rules[i]) + if err != nil { + return err + } + if err := ruleutil.ValidateRouterManagedRule(protoRule); err != nil { + return err + } + } + return nil +} + +func (a *API) managedNodesForTemplate(templateID int64) ([]string, error) { + attachments, err := a.db.GetTemplateAttachments(templateID) + if err != nil { + return nil, err + } + + targetNodes := make([]string, 0, len(attachments)) + for _, attachment := range attachments { + nodes, err := a.resolveAttachmentTargetNodes(attachment.TargetType, attachment.TargetRef) + if err != nil { + return nil, err + } + targetNodes = append(targetNodes, nodes...) + } + + return a.filterManagedNodes(targetNodes) +} + +func (a *API) resolveAttachmentTargetNodes(targetType, targetRef string) ([]string, error) { + targetType = strings.TrimSpace(targetType) + targetRef = strings.TrimSpace(targetRef) + + switch targetType { + case "node": + if targetRef == "" { + return nil, nil + } + return []string{targetRef}, nil + case "tag": + allTags, err := a.db.GetAllNodeTags() + if err != nil { + return nil, err + } + + result := make([]string, 0) + for node, tags := range allTags { + for _, tag := range tags { + if tag == targetRef { + result = append(result, node) + break + } + } + } + return result, nil + default: + return nil, nil + } +} + +func (a *API) filterManagedNodes(nodes []string) ([]string, error) { + seen := make(map[string]struct{}, len(nodes)) + result := make([]string, 0, len(nodes)) + for _, node := range nodes { + node = strings.TrimSpace(node) + if node == "" { + continue + } + if _, ok := seen[node]; ok { + continue + } + seen[node] = struct{}{} + + managed, err := a.isRouterManagedNode(node) + if err != nil { + return nil, err + } + if managed { + result = append(result, node) + } + } + return result, nil +} + +func writeRouterManagedSyncError(w http.ResponseWriter, err error) bool { + if err == nil { + return false + } + if ruleutil.IsRouterManagedRuleError(err) { + writeJSON(w, http.StatusConflict, map[string]string{"error": err.Error()}) + return true + } + return false +} diff --git a/internal/config/config.go b/internal/config/config.go index 3536d15..9f9f29e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,9 +21,10 @@ type Config struct { } type ServerConfig struct { - HTTPAddr string `yaml:"http_addr"` - GRPCAddr string `yaml:"grpc_addr"` - GRPCUnix string `yaml:"grpc_unix"` + HTTPAddr string `yaml:"http_addr"` + GRPCAddr string `yaml:"grpc_addr"` + GRPCPublicAddr string `yaml:"grpc_public_addr"` + GRPCUnix string `yaml:"grpc_unix"` } type DatabaseConfig struct { @@ -51,9 +52,10 @@ type GeoIPConfig struct { func DefaultConfig() *Config { return &Config{ Server: ServerConfig{ - HTTPAddr: ":8080", - GRPCAddr: "0.0.0.0:50051", - GRPCUnix: "/tmp/osui.sock", + HTTPAddr: ":8080", + GRPCAddr: "0.0.0.0:50051", + GRPCPublicAddr: "", + GRPCUnix: "/tmp/osui.sock", }, Database: DatabaseConfig{ Path: "./opensnitch-web.db", diff --git a/internal/db/routers.go b/internal/db/routers.go index d92fbf0..fd5d9c3 100644 --- a/internal/db/routers.go +++ b/internal/db/routers.go @@ -5,21 +5,29 @@ import ( "database/sql" "encoding/hex" "fmt" + "strings" ) type Router struct { - ID int64 `json:"id"` - Name string `json:"name"` - Addr string `json:"addr"` - SSHPort int `json:"ssh_port"` - SSHUser string `json:"ssh_user"` - APIKey string `json:"-"` - LANSubnet string `json:"lan_subnet"` - Status string `json:"status"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Addr string `json:"addr"` + SSHPort int `json:"ssh_port"` + SSHUser string `json:"ssh_user"` + APIKey string `json:"-"` + LANSubnet string `json:"lan_subnet"` + DaemonMode string `json:"daemon_mode"` + LinkedNodeAddr string `json:"linked_node_addr"` + Status string `json:"status"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } +const ( + RouterDaemonModeConntrackAgent = "conntrack-agent" + RouterDaemonModeRouterDaemon = "router-daemon" +) + func GenerateAPIKey() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { @@ -32,10 +40,14 @@ func (d *Database) InsertRouter(r *Router) error { d.mu.Lock() defer d.mu.Unlock() + if strings.TrimSpace(r.DaemonMode) == "" { + r.DaemonMode = RouterDaemonModeConntrackAgent + } + res, err := d.db.Exec(` - INSERT INTO routers (name, addr, ssh_port, ssh_user, api_key, lan_subnet, status) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - r.Name, r.Addr, r.SSHPort, r.SSHUser, r.APIKey, r.LANSubnet, r.Status, + INSERT INTO routers (name, addr, ssh_port, ssh_user, api_key, lan_subnet, daemon_mode, linked_node_addr, status) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + r.Name, r.Addr, r.SSHPort, r.SSHUser, r.APIKey, r.LANSubnet, r.DaemonMode, r.LinkedNodeAddr, r.Status, ) if err != nil { return err @@ -49,26 +61,32 @@ func (d *Database) UpsertRouter(r *Router) error { d.mu.Lock() defer d.mu.Unlock() + if strings.TrimSpace(r.DaemonMode) == "" { + r.DaemonMode = RouterDaemonModeConntrackAgent + } + _, err := d.db.Exec(` - INSERT INTO routers (name, addr, ssh_port, ssh_user, api_key, lan_subnet, status) - VALUES (?, ?, ?, ?, ?, ?, ?) + INSERT INTO routers (name, addr, ssh_port, ssh_user, api_key, lan_subnet, daemon_mode, linked_node_addr, status) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(addr) DO UPDATE SET name=excluded.name, ssh_port=excluded.ssh_port, ssh_user=excluded.ssh_user, lan_subnet=excluded.lan_subnet, + daemon_mode=excluded.daemon_mode, + linked_node_addr=excluded.linked_node_addr, status=excluded.status, updated_at=datetime('now')`, - r.Name, r.Addr, r.SSHPort, r.SSHUser, r.APIKey, r.LANSubnet, r.Status, + r.Name, r.Addr, r.SSHPort, r.SSHUser, r.APIKey, r.LANSubnet, r.DaemonMode, r.LinkedNodeAddr, r.Status, ) if err != nil { return err } err = d.db.QueryRow(` - SELECT id, created_at, updated_at, api_key + SELECT id, created_at, updated_at, api_key, daemon_mode, linked_node_addr FROM routers WHERE addr = ?`, r.Addr). - Scan(&r.ID, &r.CreatedAt, &r.UpdatedAt, &r.APIKey) + Scan(&r.ID, &r.CreatedAt, &r.UpdatedAt, &r.APIKey, &r.DaemonMode, &r.LinkedNodeAddr) if err != nil { return err } @@ -82,9 +100,9 @@ func (d *Database) GetRouterByAPIKey(key string) (*Router, error) { var r Router err := d.db.QueryRow(` - SELECT id, name, addr, ssh_port, ssh_user, api_key, lan_subnet, status, created_at, updated_at + SELECT id, name, addr, ssh_port, ssh_user, api_key, lan_subnet, daemon_mode, linked_node_addr, status, created_at, updated_at FROM routers WHERE api_key = ?`, key). - Scan(&r.ID, &r.Name, &r.Addr, &r.SSHPort, &r.SSHUser, &r.APIKey, &r.LANSubnet, &r.Status, &r.CreatedAt, &r.UpdatedAt) + Scan(&r.ID, &r.Name, &r.Addr, &r.SSHPort, &r.SSHUser, &r.APIKey, &r.LANSubnet, &r.DaemonMode, &r.LinkedNodeAddr, &r.Status, &r.CreatedAt, &r.UpdatedAt) if err != nil { return nil, err } @@ -97,9 +115,24 @@ func (d *Database) GetRouterByAddr(addr string) (*Router, error) { var r Router err := d.db.QueryRow(` - SELECT id, name, addr, ssh_port, ssh_user, api_key, lan_subnet, status, created_at, updated_at + SELECT id, name, addr, ssh_port, ssh_user, api_key, lan_subnet, daemon_mode, linked_node_addr, status, created_at, updated_at FROM routers WHERE addr = ?`, addr). - Scan(&r.ID, &r.Name, &r.Addr, &r.SSHPort, &r.SSHUser, &r.APIKey, &r.LANSubnet, &r.Status, &r.CreatedAt, &r.UpdatedAt) + Scan(&r.ID, &r.Name, &r.Addr, &r.SSHPort, &r.SSHUser, &r.APIKey, &r.LANSubnet, &r.DaemonMode, &r.LinkedNodeAddr, &r.Status, &r.CreatedAt, &r.UpdatedAt) + if err != nil { + return nil, err + } + return &r, nil +} + +func (d *Database) GetRouterByLinkedNodeAddr(nodeAddr string) (*Router, error) { + d.mu.RLock() + defer d.mu.RUnlock() + + var r Router + err := d.db.QueryRow(` + SELECT id, name, addr, ssh_port, ssh_user, api_key, lan_subnet, daemon_mode, linked_node_addr, status, created_at, updated_at + FROM routers WHERE linked_node_addr = ?`, nodeAddr). + Scan(&r.ID, &r.Name, &r.Addr, &r.SSHPort, &r.SSHUser, &r.APIKey, &r.LANSubnet, &r.DaemonMode, &r.LinkedNodeAddr, &r.Status, &r.CreatedAt, &r.UpdatedAt) if err != nil { return nil, err } @@ -111,7 +144,7 @@ func (d *Database) GetRouters() ([]Router, error) { defer d.mu.RUnlock() rows, err := d.db.Query(` - SELECT id, name, addr, ssh_port, ssh_user, api_key, lan_subnet, status, created_at, updated_at + SELECT id, name, addr, ssh_port, ssh_user, api_key, lan_subnet, daemon_mode, linked_node_addr, status, created_at, updated_at FROM routers ORDER BY created_at DESC`) if err != nil { return nil, err @@ -121,7 +154,7 @@ func (d *Database) GetRouters() ([]Router, error) { var routers []Router for rows.Next() { var r Router - if err := rows.Scan(&r.ID, &r.Name, &r.Addr, &r.SSHPort, &r.SSHUser, &r.APIKey, &r.LANSubnet, &r.Status, &r.CreatedAt, &r.UpdatedAt); err != nil { + if err := rows.Scan(&r.ID, &r.Name, &r.Addr, &r.SSHPort, &r.SSHUser, &r.APIKey, &r.LANSubnet, &r.DaemonMode, &r.LinkedNodeAddr, &r.Status, &r.CreatedAt, &r.UpdatedAt); err != nil { return nil, err } routers = append(routers, r) @@ -140,6 +173,23 @@ func (d *Database) UpdateRouterStatus(addr, status string) error { return err } +func (d *Database) UpdateRouterRuntime(addr, daemonMode, linkedNodeAddr string) error { + d.mu.Lock() + defer d.mu.Unlock() + + if strings.TrimSpace(daemonMode) == "" { + daemonMode = RouterDaemonModeConntrackAgent + } + + _, err := d.db.Exec( + "UPDATE routers SET daemon_mode = ?, linked_node_addr = ?, updated_at = datetime('now') WHERE addr = ?", + daemonMode, + strings.TrimSpace(linkedNodeAddr), + addr, + ) + return err +} + func (d *Database) DeleteRouter(addr string) error { d.mu.Lock() defer d.mu.Unlock() diff --git a/internal/db/sqlite.go b/internal/db/sqlite.go index e80071a..2ffedd1 100644 --- a/internal/db/sqlite.go +++ b/internal/db/sqlite.go @@ -43,7 +43,6 @@ func (d *Database) Close() error { return d.db.Close() } - func (d *Database) migrate() error { // Legacy-safe pre-migrations for older DBs. // "duplicate column name" errors are expected and ignored. @@ -57,6 +56,8 @@ func (d *Database) migrate() error { "ALTER TABLE seen_flows ADD COLUMN source_rule_name TEXT NOT NULL DEFAULT ''", "ALTER TABLE seen_flows ADD COLUMN expires_at TEXT NOT NULL DEFAULT ''", "ALTER TABLE nodes ADD COLUMN source_type TEXT NOT NULL DEFAULT 'opensnitch'", + "ALTER TABLE routers ADD COLUMN daemon_mode TEXT NOT NULL DEFAULT 'conntrack-agent'", + "ALTER TABLE routers ADD COLUMN linked_node_addr TEXT NOT NULL DEFAULT ''", } for _, stmt := range legacyAlters { if _, err := d.db.Exec(stmt); err != nil { @@ -344,10 +345,13 @@ func (d *Database) migrate() error { ssh_user TEXT NOT NULL DEFAULT 'root', api_key TEXT NOT NULL UNIQUE, lan_subnet TEXT NOT NULL DEFAULT '', + daemon_mode TEXT NOT NULL DEFAULT 'conntrack-agent', + linked_node_addr TEXT NOT NULL DEFAULT '', status TEXT NOT NULL DEFAULT 'pending', created_at TEXT NOT NULL DEFAULT (datetime('now')), updated_at TEXT NOT NULL DEFAULT (datetime('now')) ); + CREATE INDEX IF NOT EXISTS idx_routers_linked_node_addr ON routers(linked_node_addr); ` _, err := d.db.Exec(schema) diff --git a/internal/grpcserver/identity.go b/internal/grpcserver/identity.go new file mode 100644 index 0000000..f555bc4 --- /dev/null +++ b/internal/grpcserver/identity.go @@ -0,0 +1,68 @@ +package grpcserver + +import ( + "context" + "database/sql" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +type contextKey string + +const ( + routerAPIKeyHeader = "x-router-api-key" + resolvedNodeKey contextKey = "resolved_node_addr" +) + +type wrappedServerStream struct { + grpc.ServerStream + ctx context.Context +} + +func (w *wrappedServerStream) Context() context.Context { + return w.ctx +} + +func withResolvedNodeAddr(ctx context.Context, addr string) context.Context { + addr = strings.TrimSpace(addr) + if addr == "" { + return ctx + } + return context.WithValue(ctx, resolvedNodeKey, addr) +} + +func resolvedNodeAddrFromContext(ctx context.Context) string { + addr, _ := ctx.Value(resolvedNodeKey).(string) + return strings.TrimSpace(addr) +} + +func (s *UIService) resolveRouterNodeContext(ctx context.Context) (context.Context, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx, nil + } + + values := md.Get(routerAPIKeyHeader) + if len(values) == 0 { + return ctx, nil + } + + apiKey := strings.TrimSpace(values[0]) + if apiKey == "" { + return nil, status.Error(codes.Unauthenticated, "missing router api key") + } + + router, err := s.db.GetRouterByAPIKey(apiKey) + if err != nil { + if err == sql.ErrNoRows { + return nil, status.Error(codes.Unauthenticated, "invalid router api key") + } + return nil, status.Errorf(codes.Internal, "resolve router api key: %v", err) + } + + return withResolvedNodeAddr(ctx, router.Addr), nil +} diff --git a/internal/grpcserver/server.go b/internal/grpcserver/server.go index 89c81d5..f66d86a 100644 --- a/internal/grpcserver/server.go +++ b/internal/grpcserver/server.go @@ -1,6 +1,7 @@ package grpcserver import ( + "context" "fmt" "log" "net" @@ -29,6 +30,20 @@ func New(service *UIService) *Server { s := grpc.NewServer( grpc.KeepaliveParams(kasp), grpc.KeepaliveEnforcementPolicy(kaep), + grpc.UnaryInterceptor(func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + resolvedCtx, err := service.resolveRouterNodeContext(ctx) + if err != nil { + return nil, err + } + return handler(resolvedCtx, req) + }), + grpc.StreamInterceptor(func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + resolvedCtx, err := service.resolveRouterNodeContext(stream.Context()) + if err != nil { + return err + } + return handler(srv, &wrappedServerStream{ServerStream: stream, ctx: resolvedCtx}) + }), ) pb.RegisterUIServer(s, service) diff --git a/internal/grpcserver/service.go b/internal/grpcserver/service.go index 0a9bbc4..4328a1a 100644 --- a/internal/grpcserver/service.go +++ b/internal/grpcserver/service.go @@ -49,6 +49,13 @@ func peerAddrFromCtx(ctx context.Context) string { return p.Addr.String() } +func nodeAddrFromCtx(ctx context.Context) string { + if resolved := resolvedNodeAddrFromContext(ctx); resolved != "" { + return resolved + } + return peerAddrFromCtx(ctx) +} + func normalizeEventTime(value string, unixnano int64) string { if unixnano > 0 { return ruleutil.FormatStoredTime(time.Unix(0, unixnano)) @@ -174,13 +181,14 @@ func formatAlertBody(alert *pb.Alert) string { // Subscribe is called when a daemon first connects. func (s *UIService) Subscribe(ctx context.Context, config *pb.ClientConfig) (*pb.ClientConfig, error) { peerAddr := peerAddrFromCtx(ctx) + nodeAddr := nodeAddrFromCtx(ctx) log.Printf("[grpc] Subscribe from %s (name: %s, version: %s, rules: %d)", peerAddr, config.GetName(), config.GetVersion(), len(config.GetRules())) - s.nodes.AddNode(peerAddr, config) + s.nodes.AddNode(nodeAddr, config) s.db.UpsertNode(&db.Node{ - Addr: peerAddr, + Addr: nodeAddr, Hostname: config.GetName(), DaemonVersion: config.GetVersion(), Status: db.NodeStatusOnline, @@ -192,29 +200,29 @@ func (s *UIService) Subscribe(ctx context.Context, config *pb.ClientConfig) (*pb snapshotRules := make([]*db.DBRule, 0, len(config.GetRules())) observedAt := time.Now() for _, r := range config.GetRules() { - dbRule, err := ruleutil.ProtoToDBRule(peerAddr, observedAt, r) + dbRule, err := ruleutil.ProtoToDBRule(nodeAddr, observedAt, r) if err != nil { - log.Printf("[grpc] Failed to convert rule %q from %s: %v", r.GetName(), peerAddr, err) + log.Printf("[grpc] Failed to convert rule %q from %s: %v", r.GetName(), nodeAddr, err) continue } if s.templateSync != nil { if err := s.templateSync.DecorateStoredRule(dbRule); err != nil { - log.Printf("[grpc] Failed to decorate stored rule %q from %s: %v", r.GetName(), peerAddr, err) + log.Printf("[grpc] Failed to decorate stored rule %q from %s: %v", r.GetName(), nodeAddr, err) } } snapshotRules = append(snapshotRules, dbRule) } - if err := s.db.ReplaceNodeRulesSnapshot(peerAddr, snapshotRules); err != nil { + if err := s.db.ReplaceNodeRulesSnapshot(nodeAddr, snapshotRules); err != nil { return nil, err } if s.templateSync != nil { - if err := s.templateSync.ReconcileNode(peerAddr); err != nil { - log.Printf("[grpc] Failed to reconcile templates for %s: %v", peerAddr, err) + if err := s.templateSync.ReconcileNode(nodeAddr); err != nil { + log.Printf("[grpc] Failed to reconcile templates for %s: %v", nodeAddr, err) } } s.hub.BroadcastEvent(ws.EventNodeConnected, map[string]interface{}{ - "addr": peerAddr, + "addr": nodeAddr, "hostname": config.GetName(), "version": config.GetVersion(), }) @@ -224,9 +232,9 @@ func (s *UIService) Subscribe(ctx context.Context, config *pb.ClientConfig) (*pb // Ping is the heartbeat — daemon sends stats every ~1s func (s *UIService) Ping(ctx context.Context, req *pb.PingRequest) (*pb.PingReply, error) { - peerAddr := peerAddrFromCtx(ctx) + nodeAddr := nodeAddrFromCtx(ctx) - node := s.nodes.GetNode(peerAddr) + node := s.nodes.GetNode(nodeAddr) if node == nil { return &pb.PingReply{Id: req.Id}, nil } @@ -236,7 +244,7 @@ func (s *UIService) Ping(ctx context.Context, req *pb.PingRequest) (*pb.PingRepl stats := req.GetStats() if stats != nil { s.db.UpsertNode(&db.Node{ - Addr: peerAddr, + Addr: nodeAddr, Hostname: node.Hostname, DaemonVersion: stats.DaemonVersion, DaemonUptime: int64(stats.Uptime), @@ -252,29 +260,29 @@ func (s *UIService) Ping(ctx context.Context, req *pb.PingRequest) (*pb.PingRepl if evt.Connection == nil { continue } - s.persistConnection(peerAddr, evt.Connection, evt.Rule, normalizeEventTime(evt.Time, evt.Unixnano)) + s.persistConnection(nodeAddr, evt.Connection, evt.Rule, normalizeEventTime(evt.Time, evt.Unixnano)) } // Update stats tables for k, v := range stats.ByHost { - s.db.UpsertStat("hosts", k, peerAddr, int64(v)) + s.db.UpsertStat("hosts", k, nodeAddr, int64(v)) } for k, v := range stats.ByExecutable { - s.db.UpsertStat("procs", k, peerAddr, int64(v)) + s.db.UpsertStat("procs", k, nodeAddr, int64(v)) } for k, v := range stats.ByAddress { - s.db.UpsertStat("addrs", k, peerAddr, int64(v)) + s.db.UpsertStat("addrs", k, nodeAddr, int64(v)) } for k, v := range stats.ByPort { - s.db.UpsertStat("ports", k, peerAddr, int64(v)) + s.db.UpsertStat("ports", k, nodeAddr, int64(v)) } for k, v := range stats.ByUid { - s.db.UpsertStat("users", k, peerAddr, int64(v)) + s.db.UpsertStat("users", k, nodeAddr, int64(v)) } // Broadcast stats to browsers s.hub.BroadcastEvent(ws.EventStatsUpdate, map[string]interface{}{ - "node": peerAddr, + "node": nodeAddr, "daemon_version": stats.DaemonVersion, "uptime": stats.Uptime, "rules": stats.Rules, @@ -301,9 +309,10 @@ func (s *UIService) Ping(ctx context.Context, req *pb.PingRequest) (*pb.PingRepl // Pipeline: blocklist check → trust list check → node mode check → prompt user. func (s *UIService) AskRule(ctx context.Context, conn *pb.Connection) (*pb.Rule, error) { peerAddr := peerAddrFromCtx(ctx) + nodeAddr := nodeAddrFromCtx(ctx) log.Printf("[grpc] AskRule from %s: %s -> %s:%d (%s)", peerAddr, conn.ProcessPath, conn.DstHost, conn.DstPort, conn.Protocol) - seenFlowKey, learningKey, trackSeenFlow := buildSeenFlowKey(peerAddr, conn) + seenFlowKey, learningKey, trackSeenFlow := buildSeenFlowKey(nodeAddr, conn) // 1. Check blocklist — auto-deny blocked domains (even in silent_allow mode) if conn.DstHost != "" && s.db.IsDomainBlocked(conn.DstHost) { @@ -318,12 +327,12 @@ func (s *UIService) AskRule(ctx context.Context, conn *pb.Connection) (*pb.Rule, Data: conn.DstHost, }, } - s.persistConnection(peerAddr, conn, rule, "") + s.persistConnection(nodeAddr, conn, rule, "") return rule, nil } // 2. Check process trust list - trustLevel := s.db.GetProcessTrustLevel(peerAddr, conn.ProcessPath) + trustLevel := s.db.GetProcessTrustLevel(nodeAddr, conn.ProcessPath) switch trustLevel { case db.TrustLevelTrusted: log.Printf("[grpc] AskRule: process %s trusted, auto-allow", conn.ProcessPath) @@ -337,24 +346,24 @@ func (s *UIService) AskRule(ctx context.Context, conn *pb.Connection) (*pb.Rule, Data: conn.ProcessPath, }, } - s.persistConnection(peerAddr, conn, rule, "") + s.persistConnection(nodeAddr, conn, rule, "") return rule, nil case db.TrustLevelUntrusted: log.Printf("[grpc] AskRule: process %s untrusted, forcing prompt", conn.ProcessPath) - result, err := s.prompter.AskUser(peerAddr, conn) + result, err := s.prompter.AskUser(nodeAddr, conn) if err != nil { return nil, err } s.persistPromptDecision(seenFlowKey, result, trackSeenFlow) - s.persistConnection(peerAddr, conn, result.Rule, "") + s.persistConnection(nodeAddr, conn, result.Rule, "") return result.Rule, nil } // 3. Check node mode — auto-allow or auto-deny without prompting - mode, _ := s.db.GetNodeMode(peerAddr) + mode, _ := s.db.GetNodeMode(nodeAddr) switch mode { case db.ModeSilentAllow: - log.Printf("[grpc] AskRule: silent_allow for node %s", peerAddr) + log.Printf("[grpc] AskRule: silent_allow for node %s", nodeAddr) rule := &pb.Rule{ Name: "silent-allow", Action: "allow", @@ -365,10 +374,10 @@ func (s *UIService) AskRule(ctx context.Context, conn *pb.Connection) (*pb.Rule, Data: conn.DstHost, }, } - s.persistConnection(peerAddr, conn, rule, "") + s.persistConnection(nodeAddr, conn, rule, "") return rule, nil case db.ModeSilentDeny: - log.Printf("[grpc] AskRule: silent_deny for node %s", peerAddr) + log.Printf("[grpc] AskRule: silent_deny for node %s", nodeAddr) rule := &pb.Rule{ Name: "silent-deny", Action: "deny", @@ -379,7 +388,7 @@ func (s *UIService) AskRule(ctx context.Context, conn *pb.Connection) (*pb.Rule, Data: conn.DstHost, }, } - s.persistConnection(peerAddr, conn, rule, "") + s.persistConnection(nodeAddr, conn, rule, "") return rule, nil } @@ -388,30 +397,30 @@ func (s *UIService) AskRule(ctx context.Context, conn *pb.Connection) (*pb.Rule, now := time.Now() flow, err := s.db.GetSeenFlow(seenFlowKey) if err != nil { - log.Printf("[grpc] AskRule: seen flow lookup failed for %s: %v", peerAddr, err) + log.Printf("[grpc] AskRule: seen flow lookup failed for %s: %v", nodeAddr, err) } else if flow != nil { if flow.IsExpired(now) { if err := s.db.DeleteSeenFlow(seenFlowKey); err != nil { - log.Printf("[grpc] AskRule: failed to delete expired seen flow for %s: %v", peerAddr, err) + log.Printf("[grpc] AskRule: failed to delete expired seen flow for %s: %v", nodeAddr, err) } } else { expiresAt, _ := flow.ExpiryTime() log.Printf("[grpc] AskRule: reusing remembered %s decision for %s -> %s:%d (%s)", flow.Action, conn.ProcessPath, flow.Destination, flow.DstPort, flow.Protocol) if err := s.db.UpsertSeenFlow(seenFlowKey, flow.Action, flow.SourceRuleName, now, expiresAt); err != nil { - log.Printf("[grpc] AskRule: failed to refresh seen flow for %s: %v", peerAddr, err) + log.Printf("[grpc] AskRule: failed to refresh seen flow for %s: %v", nodeAddr, err) } return ruleutil.BuildSeenFlowRule(learningKey, flow.Action), nil } } } - result, err := s.prompter.AskUser(peerAddr, conn) + result, err := s.prompter.AskUser(nodeAddr, conn) if err != nil { return nil, err } s.persistPromptDecision(seenFlowKey, result, trackSeenFlow) - s.persistConnection(peerAddr, conn, result.Rule, "") + s.persistConnection(nodeAddr, conn, result.Rule, "") return result.Rule, nil } @@ -471,17 +480,14 @@ func seenFlowRetention(rule *pb.Rule, now time.Time) (time.Time, bool) { // Notifications is the bidirectional streaming RPC. func (s *UIService) Notifications(stream pb.UI_NotificationsServer) error { - peerAddr := "" - if p, ok := peer.FromContext(stream.Context()); ok { - peerAddr = p.Addr.String() - } + nodeAddr := nodeAddrFromCtx(stream.Context()) - node := s.nodes.GetNode(peerAddr) + node := s.nodes.GetNode(nodeAddr) if node == nil { - return fmt.Errorf("node %s not registered", peerAddr) + return fmt.Errorf("node %s not registered", nodeAddr) } - log.Printf("[grpc] Notifications stream opened for %s", peerAddr) + log.Printf("[grpc] Notifications stream opened for %s", nodeAddr) // Read replies from daemon in a goroutine errChan := make(chan error, 1) @@ -496,7 +502,7 @@ func (s *UIService) Notifications(stream pb.UI_NotificationsServer) error { errChan <- err return } - log.Printf("[grpc] NotificationReply from %s: id=%d code=%v", peerAddr, reply.Id, reply.Code) + log.Printf("[grpc] NotificationReply from %s: id=%d code=%v", nodeAddr, reply.Id, reply.Code) } }() @@ -505,23 +511,23 @@ func (s *UIService) Notifications(stream pb.UI_NotificationsServer) error { select { case notif := <-node.NotifyChan: if notif == nil || notif.Type == -1 { - log.Printf("[grpc] Notifications stream closing for %s (sentinel)", peerAddr) + log.Printf("[grpc] Notifications stream closing for %s (sentinel)", nodeAddr) return nil } if err := stream.Send(notif); err != nil { - log.Printf("[grpc] Error sending notification to %s: %v", peerAddr, err) + log.Printf("[grpc] Error sending notification to %s: %v", nodeAddr, err) return err } - log.Printf("[grpc] Sent notification to %s: type=%v", peerAddr, notif.Type) + log.Printf("[grpc] Sent notification to %s: type=%v", nodeAddr, notif.Type) case err := <-errChan: - log.Printf("[grpc] Notifications stream ended for %s: %v", peerAddr, err) - s.handleNodeDisconnect(peerAddr) + log.Printf("[grpc] Notifications stream ended for %s: %v", nodeAddr, err) + s.handleNodeDisconnect(nodeAddr) return err case <-stream.Context().Done(): - log.Printf("[grpc] Notifications context done for %s", peerAddr) - s.handleNodeDisconnect(peerAddr) + log.Printf("[grpc] Notifications context done for %s", nodeAddr) + s.handleNodeDisconnect(nodeAddr) return stream.Context().Err() } } @@ -538,13 +544,14 @@ func (s *UIService) handleNodeDisconnect(addr string) { // PostAlert is called when the daemon sends an alert func (s *UIService) PostAlert(ctx context.Context, alert *pb.Alert) (*pb.MsgResponse, error) { peerAddr := peerAddrFromCtx(ctx) + nodeAddr := nodeAddrFromCtx(ctx) log.Printf("[grpc] PostAlert from %s: type=%v priority=%v what=%v", peerAddr, alert.Type, alert.Priority, alert.What) body := formatAlertBody(alert) s.db.InsertAlert(&db.DBAlert{ Time: time.Now().Format("2006-01-02 15:04:05"), - Node: peerAddr, + Node: nodeAddr, Type: int(alert.Type), Action: int(alert.Action), Priority: int(alert.Priority), @@ -554,7 +561,7 @@ func (s *UIService) PostAlert(ctx context.Context, alert *pb.Alert) (*pb.MsgResp }) s.hub.BroadcastEvent(ws.EventNewAlert, map[string]interface{}{ - "node": peerAddr, + "node": nodeAddr, "type": alert.Type, "priority": alert.Priority, "what": alert.What, diff --git a/internal/grpcserver/service_test.go b/internal/grpcserver/service_test.go index b99a35d..4b9d2ff 100644 --- a/internal/grpcserver/service_test.go +++ b/internal/grpcserver/service_test.go @@ -8,7 +8,10 @@ import ( "testing" "time" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" "github.com/bilalbayram/opensnitch-web/internal/db" "github.com/bilalbayram/opensnitch-web/internal/nodemanager" @@ -76,6 +79,11 @@ func askRuleContext(addr string) context.Context { return peer.NewContext(context.Background(), &peer.Peer{Addr: serviceTestAddr(addr)}) } +func routerMetadataContext(addr, apiKey string) context.Context { + ctx := askRuleContext(addr) + return metadata.NewIncomingContext(ctx, metadata.Pairs(routerAPIKeyHeader, apiKey)) +} + func testConnection() *pb.Connection { return &pb.Connection{ ProcessPath: testProcessPath, @@ -115,6 +123,78 @@ func replyRuleWithDuration(action, duration string) *pb.Rule { } } +func TestSubscribeUsesRouterAPIKeyIdentity(t *testing.T) { + env := newServiceTestEnv(t, 1) + + if err := env.database.InsertRouter(&db.Router{ + Name: "router-a", + Addr: "router-a", + SSHPort: 22, + SSHUser: "root", + APIKey: "router-secret", + LANSubnet: "192.168.1.0/24", + DaemonMode: db.RouterDaemonModeRouterDaemon, + LinkedNodeAddr: "router-a", + Status: "active", + }); err != nil { + t.Fatalf("insert router: %v", err) + } + + ctx, err := env.service.resolveRouterNodeContext(routerMetadataContext("10.0.0.20:50051", "router-secret")) + if err != nil { + t.Fatalf("resolve router identity: %v", err) + } + + if _, err := env.service.Subscribe(ctx, &pb.ClientConfig{ + Name: "openwrt-a", + Version: "opensnitchd-router/test", + }); err != nil { + t.Fatalf("subscribe router-daemon: %v", err) + } + + if env.service.nodes.GetNode("router-a") == nil { + t.Fatal("expected router identity to resolve to router-a") + } + if env.service.nodes.GetNode("10.0.0.20:50051") != nil { + t.Fatal("expected peer address not to be used as node identity") + } + + node, err := env.database.GetNode("router-a") + if err != nil { + t.Fatalf("load resolved node: %v", err) + } + if node.Hostname != "openwrt-a" { + t.Fatalf("expected resolved node hostname to be updated, got %+v", node) + } +} + +func TestSubscribeRejectsInvalidRouterAPIKey(t *testing.T) { + env := newServiceTestEnv(t, 1) + + _, err := env.service.resolveRouterNodeContext(routerMetadataContext("10.0.0.20:50051", "bad-key")) + if err == nil { + t.Fatal("expected invalid router api key to be rejected") + } + if status.Code(err) != codes.Unauthenticated { + t.Fatalf("expected unauthenticated, got %v", status.Code(err)) + } +} + +func TestSubscribeWithoutRouterMetadataKeepsLegacyPeerIdentity(t *testing.T) { + env := newServiceTestEnv(t, 1) + + if _, err := env.service.Subscribe(askRuleContext("desktop-a:50051"), &pb.ClientConfig{ + Name: "desktop-a", + Version: "1.0.0", + }); err != nil { + t.Fatalf("subscribe legacy node: %v", err) + } + + if env.service.nodes.GetNode("desktop-a:50051") == nil { + t.Fatal("expected peer identity for legacy daemon") + } +} + func TestAskRulePromptsForNewFlowAndPersistsExplicitDecision(t *testing.T) { env := newServiceTestEnv(t, 1) nodeAddr := "node-a" diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index 6a36702..7b60c76 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -141,3 +141,9 @@ func (p *Prompter) GetPending() []*PendingPrompt { } return prompts } + +func (p *Prompter) GetPendingPrompt(id string) *PendingPrompt { + p.mu.RLock() + defer p.mu.RUnlock() + return p.pending[id] +} diff --git a/internal/router/capabilities.go b/internal/router/capabilities.go new file mode 100644 index 0000000..bc98a91 --- /dev/null +++ b/internal/router/capabilities.go @@ -0,0 +1,143 @@ +package router + +import ( + "fmt" + "strconv" + "strings" +) + +const ( + minKernelVersionMajor = 5 + minKernelVersionMinor = 4 + minRAMMB = 128 + minOverlayFreeMB = 20 +) + +type RouterCapabilities struct { + OpenWrt bool `json:"openwrt"` + Arch string `json:"arch"` + KernelVersion string `json:"kernel_version"` + KernelSupported bool `json:"kernel_supported"` + RAMMB int `json:"ram_mb"` + RAMSupported bool `json:"ram_supported"` + OverlayFreeMB int `json:"overlay_free_mb"` + OverlaySupported bool `json:"overlay_supported"` + HasNFT bool `json:"has_nft"` + HasConntrack bool `json:"has_conntrack"` + HasOpkg bool `json:"has_opkg"` + Eligible bool `json:"eligible"` + IneligibleReason string `json:"ineligible_reason"` +} + +type CapabilityCheckResult struct { + Capabilities *RouterCapabilities `json:"capabilities"` + Steps []ProvisionStep `json:"steps"` +} + +func CheckCapabilities(client remoteClient) (*RouterCapabilities, error) { + caps := &RouterCapabilities{} + + openwrtOut, err := client.Run("cat /etc/openwrt_release 2>/dev/null") + if err != nil { + return nil, fmt.Errorf("check openwrt release: %w", err) + } + caps.OpenWrt = strings.Contains(openwrtOut, "DISTRIB") + if !caps.OpenWrt { + caps.IneligibleReason = "router-daemon v1 only supports OpenWrt" + return finalizeCapabilities(caps), nil + } + + if caps.Arch, err = runTrimmed(client, "uname -m"); err != nil { + return nil, fmt.Errorf("check architecture: %w", err) + } + if caps.KernelVersion, err = runTrimmed(client, "uname -r"); err != nil { + return nil, fmt.Errorf("check kernel version: %w", err) + } + caps.KernelSupported = kernelAtLeast(caps.KernelVersion, minKernelVersionMajor, minKernelVersionMinor) + + memKB, err := runTrimmed(client, "awk '/MemTotal/ {print $2}' /proc/meminfo") + if err != nil { + return nil, fmt.Errorf("check memory: %w", err) + } + if parsed, err := strconv.Atoi(memKB); err == nil { + caps.RAMMB = parsed / 1024 + } + caps.RAMSupported = caps.RAMMB >= minRAMMB + + overlayKB, err := runTrimmed(client, "df -k /overlay 2>/dev/null | awk 'NR==2 {print $4}'") + if err != nil { + return nil, fmt.Errorf("check overlay free space: %w", err) + } + if parsed, err := strconv.Atoi(overlayKB); err == nil { + caps.OverlayFreeMB = parsed / 1024 + } + caps.OverlaySupported = caps.OverlayFreeMB >= minOverlayFreeMB + + caps.HasNFT = commandExists(client, "nft") + caps.HasConntrack = commandExists(client, "conntrack") + caps.HasOpkg = commandExists(client, "opkg") + + return finalizeCapabilities(caps), nil +} + +func finalizeCapabilities(caps *RouterCapabilities) *RouterCapabilities { + switch { + case !caps.OpenWrt: + caps.IneligibleReason = "router-daemon v1 only supports OpenWrt" + case caps.Arch != "aarch64": + caps.IneligibleReason = "router-daemon v1 only supports aarch64 OpenWrt targets" + case !caps.KernelSupported: + caps.IneligibleReason = "router-daemon v1 requires kernel 5.4 or newer" + case !caps.RAMSupported: + caps.IneligibleReason = "router-daemon v1 requires at least 128MB RAM" + case !caps.OverlaySupported: + caps.IneligibleReason = "router-daemon v1 requires at least 20MB free overlay space" + case !caps.HasNFT: + caps.IneligibleReason = "router-daemon v1 requires nft" + case !caps.HasConntrack: + caps.IneligibleReason = "router-daemon v1 requires conntrack" + case !caps.HasOpkg: + caps.IneligibleReason = "router-daemon v1 requires opkg" + default: + caps.Eligible = true + caps.IneligibleReason = "" + } + + return caps +} + +func kernelAtLeast(value string, wantMajor, wantMinor int) bool { + fields := strings.FieldsFunc(strings.TrimSpace(value), func(r rune) bool { + return r == '.' || r == '-' + }) + if len(fields) < 2 { + return false + } + + major, err := strconv.Atoi(fields[0]) + if err != nil { + return false + } + minor, err := strconv.Atoi(fields[1]) + if err != nil { + return false + } + + if major != wantMajor { + return major > wantMajor + } + return minor >= wantMinor +} + +func commandExists(client remoteClient, name string) bool { + out, err := client.Run("command -v " + name + " >/dev/null 2>&1 && echo YES || echo NO") + if err != nil { + return false + } + return strings.TrimSpace(out) == "YES" +} + +func runTrimmed(client remoteClient, cmd string) (string, error) { + out, err := client.Run(cmd) + return strings.TrimSpace(out), err +} diff --git a/internal/router/daemon_initd.sh b/internal/router/daemon_initd.sh new file mode 100644 index 0000000..768698a --- /dev/null +++ b/internal/router/daemon_initd.sh @@ -0,0 +1,11 @@ +#!/bin/sh /etc/rc.common + +START=95 +USE_PROCD=1 + +start_service() { + procd_open_instance + procd_set_param command /usr/bin/opensnitchd-router -config /etc/opensnitchd-router/config.json + procd_set_param respawn + procd_close_instance +} diff --git a/internal/router/daemon_templates.go b/internal/router/daemon_templates.go new file mode 100644 index 0000000..e2bb733 --- /dev/null +++ b/internal/router/daemon_templates.go @@ -0,0 +1,30 @@ +package router + +import ( + _ "embed" + "encoding/json" +) + +//go:embed daemon_initd.sh +var daemonInitdScript string + +type DaemonConfig struct { + GRPCAddr string `json:"grpc_addr"` + APIKey string `json:"api_key"` + NodeName string `json:"node_name"` + DefaultAction string `json:"default_action"` + PollIntervalMS int `json:"poll_interval_ms"` + FirewallBackend string `json:"firewall_backend"` +} + +func RenderDaemonConfig(cfg DaemonConfig) (string, error) { + payload, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return "", err + } + return string(payload) + "\n", nil +} + +func DaemonInitdScript() string { + return daemonInitdScript +} diff --git a/internal/router/provisioner.go b/internal/router/provisioner.go index eab6ab4..bd90fca 100644 --- a/internal/router/provisioner.go +++ b/internal/router/provisioner.go @@ -6,10 +6,15 @@ import ( "fmt" "log" "net" + "net/url" + "os" + "path/filepath" "strings" "time" + "github.com/bilalbayram/opensnitch-web/internal/config" "github.com/bilalbayram/opensnitch-web/internal/db" + "github.com/bilalbayram/opensnitch-web/internal/nodemanager" "golang.org/x/crypto/ssh" ) @@ -17,6 +22,7 @@ type remoteClient interface { Close() error Run(cmd string) (string, error) WriteFile(path, content string) error + WriteBinary(path string, content []byte, mode os.FileMode) error } type sshRemoteClient struct { @@ -35,10 +41,17 @@ func (c *sshRemoteClient) WriteFile(path, content string) error { return writeSSHRemoteFile(c.client, path, content) } +func (c *sshRemoteClient) WriteBinary(path string, content []byte, mode os.FileMode) error { + return writeSSHRemoteBinary(c.client, path, content, mode) +} + type Provisioner struct { - db *db.Database - dial func(addr string, port int, user, pass, key string) (remoteClient, error) - sleep func(time.Duration) + db *db.Database + cfg *config.Config + nodes *nodemanager.Manager + dial func(addr string, port int, user, pass, key string) (remoteClient, error) + sleep func(time.Duration) + readFile func(string) ([]byte, error) } type ConnectRequest struct { @@ -53,10 +66,23 @@ type ConnectRequest struct { } type ProvisionResult struct { - Router *db.Router `json:"router"` - Steps []ProvisionStep `json:"steps"` - ServerURL string `json:"server_url"` - ServerURLSource string `json:"server_url_source"` + Router *db.Router `json:"router"` + Capabilities *RouterCapabilities `json:"capabilities,omitempty"` + Steps []ProvisionStep `json:"steps"` + ServerURL string `json:"server_url"` + ServerURLSource string `json:"server_url_source"` +} + +type DaemonRequest struct { + Addr string `json:"-"` + SSHPort int `json:"ssh_port"` + SSHUser string `json:"ssh_user"` + SSHPass string `json:"ssh_pass"` + SSHKey string `json:"ssh_key,omitempty"` + NodeName string `json:"node_name"` + DefaultAction string `json:"default_action"` + PollIntervalMS int `json:"poll_interval_ms"` + FirewallBackend string `json:"firewall_backend"` } type ProvisionStep struct { @@ -65,14 +91,29 @@ type ProvisionStep struct { Message string `json:"message"` } +const ( + routerDaemonBinaryLocalPath = "bin/opensnitchd-router-linux-arm64" + routerDaemonBinaryRemotePath = "/usr/bin/opensnitchd-router" + routerDaemonConfigDir = "/etc/opensnitchd-router" + routerDaemonConfigPath = "/etc/opensnitchd-router/config.json" + routerDaemonInitdPath = "/etc/init.d/opensnitchd-router" +) + func NewProvisioner(database *db.Database) *Provisioner { return &Provisioner{ - db: database, - dial: sshDial, - sleep: time.Sleep, + db: database, + dial: sshDial, + sleep: time.Sleep, + readFile: os.ReadFile, } } +func (p *Provisioner) WithRuntime(nodes *nodemanager.Manager, cfg *config.Config) *Provisioner { + p.nodes = nodes + p.cfg = cfg + return p +} + func (p *Provisioner) Provision(ctx context.Context, req ConnectRequest) (*ProvisionResult, error) { var steps []ProvisionStep @@ -200,13 +241,15 @@ func (p *Provisioner) Provision(ctx context.Context, req ConnectRequest) (*Provi // 8. Store router record router := &db.Router{ - Name: req.Name, - Addr: req.Addr, - SSHPort: req.SSHPort, - SSHUser: req.SSHUser, - APIKey: apiKey, - LANSubnet: lanPrefix, - Status: "active", + Name: req.Name, + Addr: req.Addr, + SSHPort: req.SSHPort, + SSHUser: req.SSHUser, + APIKey: apiKey, + LANSubnet: lanPrefix, + DaemonMode: db.RouterDaemonModeConntrackAgent, + LinkedNodeAddr: "", + Status: "active", } if err := p.db.UpsertRouter(router); err != nil { addStep("register", "error", fmt.Sprintf("Failed to save router: %v", err)) @@ -259,6 +302,221 @@ func (p *Provisioner) Deprovision(ctx context.Context, addr string, sshPort int, return steps, nil } +func (p *Provisioner) CheckCapabilities(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) (*CapabilityCheckResult, error) { + _ = ctx + + steps := []ProvisionStep{} + addStep := func(step, status, message string) { + steps = append(steps, ProvisionStep{Step: step, Status: status, Message: message}) + log.Printf("[router] %s: %s — %s", step, status, message) + } + + client, err := p.dial(addr, sshPort, sshUser, sshPass, sshKey) + if err != nil { + addStep("connect", "error", fmt.Sprintf("SSH connection failed: %v", err)) + return &CapabilityCheckResult{Steps: steps}, err + } + defer client.Close() + addStep("connect", "done", fmt.Sprintf("Connected to %s:%d", addr, sshPort)) + + caps, err := CheckCapabilities(client) + if err != nil { + addStep("capabilities", "error", err.Error()) + return &CapabilityCheckResult{Capabilities: caps, Steps: steps}, err + } + + status := "done" + message := "Router is eligible for router-daemon v1" + if !caps.Eligible { + status = "error" + message = caps.IneligibleReason + } + addStep("capabilities", status, message) + + return &CapabilityCheckResult{Capabilities: caps, Steps: steps}, nil +} + +func (p *Provisioner) ProvisionDaemon(ctx context.Context, req DaemonRequest) (*ProvisionResult, error) { + p.applyDaemonDefaults(&req) + + steps := []ProvisionStep{} + addStep := func(step, status, message string) { + steps = append(steps, ProvisionStep{Step: step, Status: status, Message: message}) + log.Printf("[router] %s: %s — %s", step, status, message) + } + + existing, err := p.db.GetRouterByAddr(req.Addr) + if err != nil { + addStep("lookup", "error", "Router must be connected before upgrading to router-daemon") + return &ProvisionResult{Steps: steps}, err + } + + client, err := p.dial(req.Addr, req.SSHPort, req.SSHUser, req.SSHPass, req.SSHKey) + if err != nil { + addStep("connect", "error", fmt.Sprintf("SSH connection failed: %v", err)) + return &ProvisionResult{Steps: steps}, fmt.Errorf("ssh dial: %w", err) + } + defer client.Close() + addStep("connect", "done", fmt.Sprintf("Connected to %s:%d", req.Addr, req.SSHPort)) + + caps, err := CheckCapabilities(client) + if err != nil { + addStep("capabilities", "error", err.Error()) + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + if !caps.Eligible { + addStep("capabilities", "error", caps.IneligibleReason) + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, fmt.Errorf(caps.IneligibleReason) + } + addStep("capabilities", "done", "Router is eligible for router-daemon v1") + + grpcAddr, grpcSource, err := p.resolveGRPCPublicAddr(req.Addr) + if err != nil { + addStep("config", "error", err.Error()) + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + addStep("config", "done", fmt.Sprintf("Using public gRPC address %s (%s)", grpcAddr, grpcSource)) + + binary, err := p.readFile(filepath.Clean(routerDaemonBinaryLocalPath)) + if err != nil { + addStep("deploy", "error", fmt.Sprintf("Missing local router-daemon binary at %s", routerDaemonBinaryLocalPath)) + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + + daemonConfig, err := RenderDaemonConfig(DaemonConfig{ + GRPCAddr: grpcAddr, + APIKey: existing.APIKey, + NodeName: req.NodeName, + DefaultAction: req.DefaultAction, + PollIntervalMS: req.PollIntervalMS, + FirewallBackend: req.FirewallBackend, + }) + if err != nil { + addStep("config", "error", fmt.Sprintf("Render daemon config: %v", err)) + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + + stoppedLegacy := false + if out, err := client.Run("if [ -x /etc/init.d/conntrack-agent ]; then /etc/init.d/conntrack-agent stop && /etc/init.d/conntrack-agent disable; else echo ABSENT; fi"); err != nil { + addStep("stop_legacy", "error", strings.TrimSpace(out)) + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, fmt.Errorf("stop conntrack-agent: %w", err) + } else { + if strings.TrimSpace(out) == "ABSENT" { + addStep("stop_legacy", "done", "Legacy conntrack-agent already absent") + } else { + stoppedLegacy = true + addStep("stop_legacy", "done", "Stopped and disabled legacy conntrack-agent") + } + } + + rollbackLegacy := func(reason string) { + if !stoppedLegacy { + return + } + if _, rollbackErr := client.Run("if [ -x /etc/init.d/conntrack-agent ]; then /etc/init.d/conntrack-agent enable && /etc/init.d/conntrack-agent start; fi"); rollbackErr != nil { + addStep("rollback", "warning", fmt.Sprintf("Failed to restore conntrack-agent after %s: %v", reason, rollbackErr)) + return + } + addStep("rollback", "done", fmt.Sprintf("Restored conntrack-agent after %s", reason)) + } + + if _, err := client.Run("mkdir -p " + shellQuote(routerDaemonConfigDir)); err != nil { + addStep("deploy", "error", fmt.Sprintf("Create daemon config dir: %v", err)) + rollbackLegacy("directory creation failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + if err := client.WriteBinary(routerDaemonBinaryRemotePath, binary, 0755); err != nil { + addStep("deploy", "error", fmt.Sprintf("Upload router-daemon binary: %v", err)) + rollbackLegacy("binary upload failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + if err := client.WriteFile(routerDaemonConfigPath, daemonConfig); err != nil { + addStep("deploy", "error", fmt.Sprintf("Write daemon config: %v", err)) + rollbackLegacy("config upload failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + if err := client.WriteFile(routerDaemonInitdPath, DaemonInitdScript()); err != nil { + addStep("deploy", "error", fmt.Sprintf("Write daemon init script: %v", err)) + rollbackLegacy("init script upload failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + if _, err := client.Run("chmod +x " + shellQuote(routerDaemonInitdPath)); err != nil { + addStep("deploy", "error", fmt.Sprintf("chmod daemon init script: %v", err)) + rollbackLegacy("init script chmod failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + addStep("deploy", "done", "opensnitchd-router files deployed") + + if _, err := client.Run(shellQuote(routerDaemonInitdPath) + " enable"); err != nil { + addStep("start", "error", fmt.Sprintf("Enable router-daemon service: %v", err)) + rollbackLegacy("enable failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + if _, err := client.Run(shellQuote(routerDaemonInitdPath) + " start"); err != nil { + addStep("start", "error", fmt.Sprintf("Start router-daemon service: %v", err)) + rollbackLegacy("start failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + addStep("start", "done", "router-daemon service started") + + if err := p.waitForManagedNode(ctx, req.Addr, 20*time.Second); err != nil { + addStep("verify", "error", fmt.Sprintf("router-daemon did not subscribe as %s: %v", req.Addr, err)) + if cleanupErr := p.removeDaemonArtifacts(client); cleanupErr != nil { + addStep("rollback", "warning", fmt.Sprintf("Failed to remove router-daemon after subscribe failure: %v", cleanupErr)) + } else { + addStep("rollback", "done", "Removed router-daemon after failed verification") + } + rollbackLegacy("verification failure") + _ = p.db.UpdateRouterRuntime(existing.Addr, db.RouterDaemonModeConntrackAgent, "") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + addStep("verify", "done", fmt.Sprintf("router-daemon subscribed as %s", req.Addr)) + + if err := p.db.UpdateRouterRuntime(existing.Addr, db.RouterDaemonModeRouterDaemon, existing.Addr); err != nil { + addStep("register", "error", fmt.Sprintf("Update router runtime: %v", err)) + if cleanupErr := p.removeDaemonArtifacts(client); cleanupErr != nil { + addStep("rollback", "warning", fmt.Sprintf("Failed to remove router-daemon after DB update failure: %v", cleanupErr)) + } + rollbackLegacy("runtime update failure") + return &ProvisionResult{Router: existing, Capabilities: caps, Steps: steps}, err + } + if err := p.db.UpdateRouterStatus(existing.Addr, "active"); err != nil { + addStep("register", "warning", fmt.Sprintf("Router-daemon connected but status update failed: %v", err)) + } else { + addStep("register", "done", "Router runtime updated to router-daemon") + } + + updated, err := p.db.GetRouterByAddr(existing.Addr) + if err != nil { + updated = existing + } + return &ProvisionResult{Router: updated, Capabilities: caps, Steps: steps}, nil +} + +func (p *Provisioner) DeprovisionDaemon(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) ([]ProvisionStep, error) { + _ = ctx + + steps := []ProvisionStep{} + client, err := p.dial(addr, sshPort, sshUser, sshPass, sshKey) + if err != nil { + steps = append(steps, ProvisionStep{"connect", "error", fmt.Sprintf("SSH failed: %v", err)}) + return steps, err + } + defer client.Close() + steps = append(steps, ProvisionStep{"connect", "done", "Connected"}) + + if err := p.removeDaemonArtifacts(client); err != nil { + steps = append(steps, ProvisionStep{"remove", "error", err.Error()}) + return steps, err + } + + steps = append(steps, + ProvisionStep{"stop", "done", "router-daemon stopped"}, + ProvisionStep{"remove", "done", "router-daemon removed and nft state flushed"}, + ) + return steps, nil +} + // --- helpers --- func sshDial(addr string, port int, user, pass, key string) (remoteClient, error) { @@ -314,6 +572,124 @@ func writeSSHRemoteFile(client *ssh.Client, path, content string) error { return nil } +func writeSSHRemoteBinary(client *ssh.Client, path string, content []byte, mode os.FileMode) error { + session, err := client.NewSession() + if err != nil { + return fmt.Errorf("new session: %w", err) + } + defer session.Close() + + var stdout, stderr bytes.Buffer + session.Stdout = &stdout + session.Stderr = &stderr + session.Stdin = bytes.NewReader(content) + + cmd := fmt.Sprintf("cat > %s && chmod %o %s", shellQuote(path), mode.Perm(), shellQuote(path)) + if err := session.Run(cmd); err != nil { + output := stdout.String() + stderr.String() + return fmt.Errorf("%w (output: %s)", err, strings.TrimSpace(output)) + } + return nil +} + +func (p *Provisioner) applyDaemonDefaults(req *DaemonRequest) { + if req.SSHPort == 0 { + req.SSHPort = 22 + } + if req.SSHUser == "" { + req.SSHUser = "root" + } + if req.NodeName == "" { + req.NodeName = req.Addr + } + if req.DefaultAction == "" { + req.DefaultAction = "deny" + if p.cfg != nil && strings.TrimSpace(p.cfg.UI.DefaultAction) != "" { + req.DefaultAction = p.cfg.UI.DefaultAction + } + } + if req.PollIntervalMS <= 0 { + req.PollIntervalMS = 1000 + } + if req.FirewallBackend == "" { + req.FirewallBackend = "nft" + } +} + +func (p *Provisioner) waitForManagedNode(ctx context.Context, addr string, timeout time.Duration) error { + if p.nodes == nil { + return fmt.Errorf("router-daemon verification requires node manager runtime") + } + + deadline := time.NewTimer(timeout) + defer deadline.Stop() + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + if p.nodes.GetNode(addr) != nil { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-deadline.C: + return fmt.Errorf("timed out waiting for authenticated subscribe") + case <-ticker.C: + } + } +} + +func (p *Provisioner) removeDaemonArtifacts(client remoteClient) error { + stopCmd := "if [ -x " + shellQuote(routerDaemonInitdPath) + " ]; then " + shellQuote(routerDaemonInitdPath) + " stop; " + shellQuote(routerDaemonInitdPath) + " disable; fi" + if _, err := client.Run(stopCmd); err != nil { + return fmt.Errorf("stop router-daemon: %w", err) + } + if _, err := client.Run("nft delete table inet opensnitch-router >/dev/null 2>&1 || true"); err != nil { + return fmt.Errorf("flush router-daemon nft table: %w", err) + } + if _, err := client.Run("rm -rf " + shellQuote(routerDaemonConfigDir) + " " + shellQuote(routerDaemonInitdPath) + " " + shellQuote(routerDaemonBinaryRemotePath)); err != nil { + return fmt.Errorf("remove router-daemon files: %w", err) + } + return nil +} + +func (p *Provisioner) resolveGRPCPublicAddr(routerAddr string) (string, string, error) { + if p.cfg == nil { + return "", "", fmt.Errorf("set server.grpc_public_addr to upgrade routers to router-daemon") + } + + if public := strings.TrimSpace(p.cfg.Server.GRPCPublicAddr); public != "" { + return public, "config_override", nil + } + + host, port, err := net.SplitHostPort(strings.TrimSpace(p.cfg.Server.GRPCAddr)) + if err != nil { + return "", "", fmt.Errorf("invalid server.grpc_addr %q", p.cfg.Server.GRPCAddr) + } + if host != "" && host != "0.0.0.0" && host != "::" && host != "[::]" { + return net.JoinHostPort(host, port), "listen_addr", nil + } + + lanURL, source := ResolveServerURL(routerAddr, p.cfg.Server.HTTPAddr) + if lanURL == "" { + return "", "", fmt.Errorf("could not infer public gRPC address from server config; set server.grpc_public_addr") + } + + parsed, err := url.Parse(lanURL) + if err != nil || parsed.Hostname() == "" { + return "", "", fmt.Errorf("invalid inferred LAN URL %q; set server.grpc_public_addr", lanURL) + } + + return net.JoinHostPort(parsed.Hostname(), port), source, nil +} + +func shellQuote(value string) string { + return "'" + strings.ReplaceAll(value, "'", `'\''`) + "'" +} + func deriveLANPrefix(routerAddr, userSubnet string) string { if userSubnet != "" { // If user provided CIDR like 192.168.1.0/24, derive the prefix diff --git a/internal/router/provisioner_test.go b/internal/router/provisioner_test.go index dde0269..8136539 100644 --- a/internal/router/provisioner_test.go +++ b/internal/router/provisioner_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os" "path/filepath" "strings" "testing" @@ -25,8 +26,8 @@ type fakeRunResult struct { func newFakeRemoteClient() *fakeRemoteClient { return &fakeRemoteClient{ outputs: map[string]fakeRunResult{ - "cat /etc/openwrt_release 2>/dev/null": {output: "DISTRIB_ID='OpenWrt'\n"}, - "{ which conntrack && wget --version 2>&1 | grep -q GNU; } >/dev/null 2>&1 && echo INSTALLED || echo MISSING": {output: "INSTALLED\n"}, + "cat /etc/openwrt_release 2>/dev/null": {output: "DISTRIB_ID='OpenWrt'\n"}, + "{ which conntrack && wget --version 2>&1 | grep -q GNU; } >/dev/null 2>&1 && echo INSTALLED || echo MISSING": {output: "INSTALLED\n"}, "mkdir -p /etc/conntrack-agent": {}, "chmod +x /etc/conntrack-agent/agent.sh /etc/init.d/conntrack-agent": {}, "/etc/init.d/conntrack-agent enable": {}, @@ -58,6 +59,11 @@ func (c *fakeRemoteClient) WriteFile(path, content string) error { return nil } +func (c *fakeRemoteClient) WriteBinary(path string, content []byte, mode os.FileMode) error { + c.writes[path] = fmt.Sprintf("binary:%d:%o", len(content), mode.Perm()) + return nil +} + func newTestProvisioner(t *testing.T, client remoteClient) (*Provisioner, *db.Database) { t.Helper() diff --git a/internal/rules/router_managed.go b/internal/rules/router_managed.go new file mode 100644 index 0000000..9d2acd7 --- /dev/null +++ b/internal/rules/router_managed.go @@ -0,0 +1,93 @@ +package rules + +import ( + "errors" + "fmt" + "strings" + + pb "github.com/bilalbayram/opensnitch-web/proto" +) + +type RouterManagedRuleError struct { + Message string +} + +func (e *RouterManagedRuleError) Error() string { + return e.Message +} + +var supportedRouterManagedOperands = map[string]struct{}{ + "process.path": {}, + "dest.ip": {}, + "dest.port": {}, + "protocol": {}, + "user.id": {}, +} + +func IsRouterManagedRuleError(err error) bool { + var target *RouterManagedRuleError + ok := errors.As(err, &target) + return ok +} + +func ValidateRouterManagedRule(rule *pb.Rule) error { + if rule == nil { + return &RouterManagedRuleError{Message: "router-daemon rules require an operator"} + } + return ValidateRouterManagedOperator(rule.GetOperator()) +} + +func ValidateRouterManagedOperator(operator *pb.Operator) error { + if operator == nil { + return &RouterManagedRuleError{Message: "router-daemon rules require an operator"} + } + + operatorType := strings.ToLower(strings.TrimSpace(operator.GetType())) + if len(operator.GetList()) > 0 && operatorType == "" { + operatorType = compoundOperatorType + } + switch operatorType { + case "", simpleOperatorType: + return validateRouterManagedOperand(operator) + case compoundOperatorType: + if len(operator.GetList()) == 0 { + return &RouterManagedRuleError{Message: "router-daemon compound rules require at least one operand"} + } + seen := make(map[string]struct{}, len(operator.GetList())) + for _, item := range operator.GetList() { + if item == nil { + return &RouterManagedRuleError{Message: "router-daemon compound rules cannot contain empty operands"} + } + itemType := strings.ToLower(strings.TrimSpace(item.GetType())) + if itemType != "" && itemType != simpleOperatorType { + return &RouterManagedRuleError{Message: "router-daemon rules only support flat compound lists"} + } + operand := strings.TrimSpace(item.GetOperand()) + if _, exists := seen[operand]; exists { + return &RouterManagedRuleError{Message: fmt.Sprintf("router-daemon rules only support one %s operand", operand)} + } + seen[operand] = struct{}{} + if err := validateRouterManagedOperand(item); err != nil { + return err + } + } + return nil + default: + return &RouterManagedRuleError{Message: "router-daemon rules only support simple operators or flat compound lists"} + } +} + +func validateRouterManagedOperand(operator *pb.Operator) error { + operand := strings.TrimSpace(operator.GetOperand()) + if _, ok := supportedRouterManagedOperands[operand]; ok { + return nil + } + + if operand == "" { + return &RouterManagedRuleError{Message: "router-daemon rules require an operand"} + } + + return &RouterManagedRuleError{ + Message: fmt.Sprintf("router-daemon rules do not support %s", operand), + } +} diff --git a/internal/templatesync/service.go b/internal/templatesync/service.go index 13527b8..363e8c3 100644 --- a/internal/templatesync/service.go +++ b/internal/templatesync/service.go @@ -176,6 +176,16 @@ func (s *Service) ReconcileNode(node string) error { } func (s *Service) ResolveManagedRules(node string) ([]*db.DBRule, []*pb.Rule, error) { + routerManaged := false + routerRecord, err := s.db.GetRouterByLinkedNodeAddr(node) + switch { + case err == nil: + routerManaged = routerRecord.DaemonMode == db.RouterDaemonModeRouterDaemon + case err == sql.ErrNoRows: + default: + return nil, nil, err + } + nodeTags, err := s.db.GetNodeTags(node) if err != nil { return nil, nil, err @@ -219,6 +229,11 @@ func (s *Service) ResolveManagedRules(node string) ([]*db.DBRule, []*pb.Rule, er if err != nil { return nil, nil, err } + if routerManaged { + if err := ruleutil.ValidateRouterManagedRule(protoRule); err != nil { + return nil, nil, err + } + } canonicalOperator, err := ruleutil.CanonicalOperatorJSONFromRule(dbRule) if err != nil { diff --git a/web/src/components/prompt/dialog.tsx b/web/src/components/prompt/dialog.tsx index cef1d89..6449e15 100644 --- a/web/src/components/prompt/dialog.tsx +++ b/web/src/components/prompt/dialog.tsx @@ -70,6 +70,7 @@ export function PromptOverlay() { const progressPercent = (countdown / 120) * 100; const urgent = countdown <= 30; + const routerManaged = Boolean(prompt.router_managed); return (
@@ -133,11 +134,16 @@ export function PromptOverlay() { className="w-full bg-muted border border-border rounded-lg px-3 py-2.5 text-sm" > - {prompt.dst_host && } + {!routerManaged && prompt.dst_host && } + {routerManaged && ( +
+ Router-managed prompts only support process path, destination IP, destination port, and user ID operands. +
+ )}
{/* Duration selector — larger touch targets on mobile */} diff --git a/web/src/components/rule-editor-sheet.tsx b/web/src/components/rule-editor-sheet.tsx index a64e5c1..2593be6 100644 --- a/web/src/components/rule-editor-sheet.tsx +++ b/web/src/components/rule-editor-sheet.tsx @@ -41,6 +41,24 @@ export const operandLabels: Record = { protocol: "Protocol", }; +const standardOperandOptions = [ + { value: "process.path", label: "Process Path" }, + { value: "process.command", label: "Process Command" }, + { value: "dest.host", label: "Dest Host" }, + { value: "dest.ip", label: "Dest IP" }, + { value: "dest.port", label: "Dest Port" }, + { value: "user.id", label: "User ID" }, + { value: "protocol", label: "Protocol" }, +]; + +const routerManagedOperandOptions = [ + { value: "process.path", label: "Process Path" }, + { value: "dest.ip", label: "Dest IP" }, + { value: "dest.port", label: "Dest Port" }, + { value: "user.id", label: "User ID" }, + { value: "protocol", label: "Protocol" }, +]; + interface RuleEditorSheetProps { open: boolean; onClose: () => void; @@ -48,6 +66,7 @@ interface RuleEditorSheetProps { editing?: boolean; onSave: (form: RuleForm) => Promise; title?: string; + routerManaged?: boolean; } export function RuleEditorSheet({ @@ -57,14 +76,26 @@ export function RuleEditorSheet({ editing = false, onSave, title, + routerManaged = false, }: RuleEditorSheetProps) { const [form, setForm] = useState({ ...defaultForm }); + const operandOptions = routerManaged + ? routerManagedOperandOptions + : standardOperandOptions; useEffect(() => { if (open) { - setForm({ ...defaultForm, ...initialValues }); + const next = { ...defaultForm, ...initialValues }; + if ( + routerManaged && + !operandOptions.some((option) => option.value === next.operator_operand) + ) { + next.operator_operand = "dest.ip"; + next.operator_data = next.operator_operand === "dest.ip" ? next.operator_data : ""; + } + setForm(next); } - }, [open, initialValues]); + }, [open, initialValues, operandOptions, routerManaged]); const handleSave = async () => { await onSave(form); @@ -96,6 +127,13 @@ export function RuleEditorSheet({
Target node: {form.node || "All nodes"}
+ {routerManaged && ( +
+ Router-managed nodes only support process path, destination IP, + destination port, protocol, and user ID operands. Forwarded traffic + in v1 is device and network rule based, with no live prompts. +
+ )}
@@ -144,13 +182,11 @@ export function RuleEditorSheet({ } className="w-full bg-muted border border-border rounded-lg px-3 py-2 text-sm mt-1" > - - - - - - - + {operandOptions.map((option) => ( + + ))}
@@ -162,7 +198,11 @@ export function RuleEditorSheet({ onChange={(e) => setForm({ ...form, operator_data: e.target.value }) } - placeholder="e.g. /usr/bin/curl or google.com" + placeholder={ + routerManaged + ? "e.g. /usr/bin/curl, 1.1.1.1, 443" + : "e.g. /usr/bin/curl or google.com" + } className="w-full bg-muted border border-border rounded-lg px-3 py-2 text-sm mt-1" /> diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 0fd026e..1d5f255 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -13,11 +13,24 @@ export interface NodeRecord { online: boolean; mode: string; source_type: string; + router_managed: boolean; + linked_router_addr: string; tags: string[]; template_sync_pending: boolean; template_sync_error: string; } +export interface RouterLinkedNodeSummary { + addr: string; + online: boolean; + mode: string; + daemon_version: string; + daemon_rules: number; + cons: number; + cons_dropped: number; + last_connection: string; +} + export interface RouterRecord { id: number; name: string; @@ -25,8 +38,11 @@ export interface RouterRecord { ssh_port: number; ssh_user: string; lan_subnet: string; + daemon_mode: string; + linked_node_addr: string; status: string; online: boolean; + linked_node: RouterLinkedNodeSummary | null; created_at: string; updated_at: string; } @@ -48,13 +64,50 @@ export interface ProvisionStep { message: string; } +export interface RouterCapabilities { + openwrt: boolean; + arch: string; + kernel_version: string; + kernel_supported: boolean; + ram_mb: number; + ram_supported: boolean; + overlay_free_mb: number; + overlay_supported: boolean; + has_nft: boolean; + has_conntrack: boolean; + has_opkg: boolean; + eligible: boolean; + ineligible_reason: string; +} + export interface ConnectRouterResponse { router: RouterRecord; + capabilities?: RouterCapabilities; steps: ProvisionStep[]; server_url?: string; server_url_source?: string; } +export interface RouterCapabilitiesResponse { + capabilities: RouterCapabilities; + steps: ProvisionStep[]; +} + +export interface RouterUpgradeRequest { + ssh_pass: string; + ssh_key?: string; + node_name?: string; + default_action?: string; + poll_interval_ms?: number; + firewall_backend?: string; +} + +export interface RouterDowngradeRequest { + ssh_pass: string; + ssh_key?: string; + server_url?: string; +} + export interface SuggestServerURLResponse { server_url: string; source: string; @@ -676,6 +729,27 @@ export const api = { body: JSON.stringify({ ssh_pass: sshPass }), }, ), + getRouterCapabilities: ( + addr: string, + payload: { ssh_pass: string; ssh_key?: string }, + ) => + request( + `/routers/${encodeURIComponent(addr)}/capabilities`, + { + method: "POST", + body: JSON.stringify(payload), + }, + ), + upgradeRouter: (addr: string, payload: RouterUpgradeRequest) => + request(`/routers/${encodeURIComponent(addr)}/upgrade`, { + method: "POST", + body: JSON.stringify(payload), + }), + downgradeRouter: (addr: string, payload: RouterDowngradeRequest) => + request(`/routers/${encodeURIComponent(addr)}/downgrade`, { + method: "POST", + body: JSON.stringify(payload), + }), // Version & Updates getVersion: () => request("/version"), diff --git a/web/src/lib/rule-helpers.ts b/web/src/lib/rule-helpers.ts index ae2f043..061f669 100644 --- a/web/src/lib/rule-helpers.ts +++ b/web/src/lib/rule-helpers.ts @@ -10,6 +10,18 @@ export interface ConnectionLike { process: string; protocol: string; uid: number; + router_managed?: boolean; +} + +export function isDeviceSource(process: string): boolean { + return process.startsWith("device:"); +} + +export function formatProcessLabel(process: string): string { + if (!isDeviceSource(process)) { + return process; + } + return `Source device ${process.slice("device:".length)}`; } /** Pick the best operand and data from a connection */ @@ -17,6 +29,9 @@ export function smartOperandFromConnection(conn: ConnectionLike): { operand: string; data: string; } { + if (conn.router_managed) { + return { operand: "dest.ip", data: conn.dst_ip }; + } if (conn.dst_host) { return { operand: "dest.host", data: conn.dst_host }; } diff --git a/web/src/pages/connections.tsx b/web/src/pages/connections.tsx index 4ac68c6..b4273f2 100644 --- a/web/src/pages/connections.tsx +++ b/web/src/pages/connections.tsx @@ -6,10 +6,12 @@ import { ResponsiveDataView } from '@/components/ui/responsive-data-view'; import { QuickRulePopover } from '@/components/quick-rule-popover'; import { RuleEditorSheet } from '@/components/rule-editor-sheet'; import type { RuleForm } from '@/components/rule-editor-sheet'; +import { formatProcessLabel } from '@/lib/rule-helpers'; interface NodeInfo { addr: string; hostname: string; + router_managed: boolean; } export default function ConnectionsPage() { @@ -24,6 +26,7 @@ export default function ConnectionsPage() { const [editorOpen, setEditorOpen] = useState(false); const [editorPrefill, setEditorPrefill] = useState | undefined>(); const limit = 50; + const nodeByAddr = Object.fromEntries(nodes.map((node) => [node.addr, node])); useEffect(() => { api.getNodes().then(setNodes).catch(console.error); @@ -158,11 +161,13 @@ export default function ConnectionsPage() { {c.protocol} {c.src_ip}:{c.src_port} {c.dst_host || c.dst_ip}:{c.dst_port} - {c.process} + + {formatProcessLabel(c.process)} + {c.rule} @@ -183,14 +188,14 @@ export default function ConnectionsPage() {
{c.time}
- {truncateMiddle(c.process || '', 60)} + {truncateMiddle(formatProcessLabel(c.process || ''), 60)}
→ {c.dst_host || c.dst_ip}:{c.dst_port} @@ -232,6 +237,7 @@ export default function ConnectionsPage() { open={editorOpen} onClose={() => { setEditorOpen(false); setEditorPrefill(undefined); }} initialValues={editorPrefill} + routerManaged={Boolean(editorPrefill?.node && nodeByAddr[editorPrefill.node]?.router_managed)} onSave={handleEditorSave} title="Create Rule from Connection" /> diff --git a/web/src/pages/nodes.tsx b/web/src/pages/nodes.tsx index 7261983..2e060bd 100644 --- a/web/src/pages/nodes.tsx +++ b/web/src/pages/nodes.tsx @@ -1,5 +1,11 @@ -import { useEffect, useRef, useState } from "react"; -import type { NodeRecord, ProvisionStep, DiscoveredRouter } from "@/lib/api"; +import { useEffect, useMemo, useRef, useState } from "react"; +import type { + NodeRecord, + ProvisionStep, + DiscoveredRouter, + RouterRecord, + RouterCapabilities, +} from "@/lib/api"; import { api } from "@/lib/api"; import { formatUptime } from "@/lib/utils"; import { @@ -52,6 +58,7 @@ interface TrustEntry { export default function NodesPage() { const [nodes, setNodes] = useState([]); + const [routers, setRouters] = useState([]); const [status, setStatus] = useState>({}); const pendingRef = useRef(0); const [trustExpanded, setTrustExpanded] = useState>( @@ -89,19 +96,26 @@ export default function NodesPage() { const [scanResults, setScanResults] = useState(null); const [scanSubnet, setScanSubnet] = useState(""); - // Router disconnect state - const [disconnecting, setDisconnecting] = useState(null); - const [disconnectPass, setDisconnectPass] = useState(""); + const [routerPasswords, setRouterPasswords] = useState>( + {}, + ); + const [routerCapabilities, setRouterCapabilities] = useState< + Record + >({}); + const [routerSteps, setRouterSteps] = useState< + Record + >({}); + const [routerBusy, setRouterBusy] = useState>({}); - const fetchNodes = (force?: boolean) => { - api - .getNodes() - .then((data) => { + const fetchPageData = (force?: boolean) => { + Promise.all([api.getNodes(), api.getRouters()]) + .then(([nodeData, routerData]) => { if (force || pendingRef.current === 0) { - setNodes(data); + setNodes(nodeData); + setRouters(routerData); setTagDrafts((prev) => { const next = { ...prev }; - for (const node of data) { + for (const node of nodeData) { if (!(node.addr in next)) { next[node.addr] = node.tags.join(", "); } @@ -114,8 +128,8 @@ export default function NodesPage() { }; useEffect(() => { - fetchNodes(); - const interval = setInterval(fetchNodes, 5000); + fetchPageData(); + const interval = setInterval(fetchPageData, 5000); return () => clearInterval(interval); }, []); @@ -155,7 +169,7 @@ export default function NodesPage() { showStatus(addr, "Failed"); } finally { pendingRef.current--; - fetchNodes(true); + fetchPageData(true); } }; @@ -172,7 +186,7 @@ export default function NodesPage() { showStatus(addr, "Mode change failed"); } finally { pendingRef.current--; - fetchNodes(true); + fetchPageData(true); } }; @@ -248,7 +262,7 @@ export default function NodesPage() { showStatus(addr, "Tag update failed"); } finally { pendingRef.current--; - fetchNodes(true); + fetchPageData(true); } }; @@ -275,7 +289,7 @@ export default function NodesPage() { }); setShowAdvanced(false); setServerUrlSource(""); - fetchNodes(true); + fetchPageData(true); }, 2000); } catch (e: unknown) { // Try to parse steps from the error response body @@ -293,15 +307,25 @@ export default function NodesPage() { }; const handleDisconnectRouter = async (addr: string) => { - if (!disconnectPass) return; + const sshPass = routerPasswords[addr]?.trim(); + if (!sshPass) { + showStatus(addr, "SSH password required"); + return; + } try { - await api.disconnectRouter(addr, disconnectPass); - setDisconnecting(null); - setDisconnectPass(""); + setRouterBusy((prev) => ({ ...prev, [addr]: "disconnecting" })); + const res = await api.disconnectRouter(addr, sshPass); + setRouterSteps((prev) => ({ ...prev, [addr]: res.steps })); showStatus(addr, "Router disconnected"); - fetchNodes(true); + fetchPageData(true); } catch (e: unknown) { showStatus(addr, e instanceof Error ? e.message : "Disconnect failed"); + } finally { + setRouterBusy((prev) => { + const next = { ...prev }; + delete next[addr]; + return next; + }); } }; @@ -344,6 +368,465 @@ export default function NodesPage() { default: "bg-primary/10 text-primary border-primary/30", }; + const handleRouterPasswordChange = (addr: string, value: string) => { + setRouterPasswords((prev) => ({ ...prev, [addr]: value })); + }; + + const withRouterBusy = async (addr: string, mode: string, fn: () => Promise) => { + setRouterBusy((prev) => ({ ...prev, [addr]: mode })); + try { + await fn(); + } finally { + setRouterBusy((prev) => { + const next = { ...prev }; + delete next[addr]; + return next; + }); + } + }; + + const handleCheckCapabilities = async (router: RouterRecord) => { + const sshPass = routerPasswords[router.addr]?.trim(); + if (!sshPass) { + showStatus(router.addr, "SSH password required"); + return; + } + + await withRouterBusy(router.addr, "capabilities", async () => { + try { + const res = await api.getRouterCapabilities(router.addr, { ssh_pass: sshPass }); + setRouterCapabilities((prev) => ({ ...prev, [router.addr]: res.capabilities })); + setRouterSteps((prev) => ({ ...prev, [router.addr]: res.steps })); + showStatus(router.addr, res.capabilities.eligible ? "Eligible" : "Not eligible"); + } catch (e) { + const err = e as Error & Record; + if (err.capabilities) { + setRouterCapabilities((prev) => ({ + ...prev, + [router.addr]: err.capabilities as RouterCapabilities, + })); + } + if (err.steps) { + setRouterSteps((prev) => ({ + ...prev, + [router.addr]: err.steps as ProvisionStep[], + })); + } + showStatus(router.addr, err.message || "Capability check failed"); + } + }); + }; + + const handleUpgradeRouter = async (router: RouterRecord) => { + const sshPass = routerPasswords[router.addr]?.trim(); + if (!sshPass) { + showStatus(router.addr, "SSH password required"); + return; + } + + await withRouterBusy(router.addr, "upgrade", async () => { + try { + const res = await api.upgradeRouter(router.addr, { + ssh_pass: sshPass, + node_name: router.name || router.addr, + default_action: "deny", + poll_interval_ms: 1000, + firewall_backend: "nft", + }); + setRouterSteps((prev) => ({ ...prev, [router.addr]: res.steps })); + if (res.capabilities) { + setRouterCapabilities((prev) => ({ ...prev, [router.addr]: res.capabilities! })); + } + showStatus(router.addr, "Router upgraded"); + fetchPageData(true); + } catch (e) { + const err = e as Error & Record; + if (err.steps) { + setRouterSteps((prev) => ({ ...prev, [router.addr]: err.steps as ProvisionStep[] })); + } + if (err.capabilities) { + setRouterCapabilities((prev) => ({ + ...prev, + [router.addr]: err.capabilities as RouterCapabilities, + })); + } + showStatus(router.addr, err.message || "Upgrade failed"); + } + }); + }; + + const handleDowngradeRouter = async (router: RouterRecord) => { + const sshPass = routerPasswords[router.addr]?.trim(); + if (!sshPass) { + showStatus(router.addr, "SSH password required"); + return; + } + + await withRouterBusy(router.addr, "downgrade", async () => { + try { + const res = await api.downgradeRouter(router.addr, { ssh_pass: sshPass }); + setRouterSteps((prev) => ({ ...prev, [router.addr]: res.steps })); + showStatus(router.addr, "Router downgraded"); + fetchPageData(true); + } catch (e) { + const err = e as Error & Record; + if (err.steps) { + setRouterSteps((prev) => ({ ...prev, [router.addr]: err.steps as ProvisionStep[] })); + } + showStatus(router.addr, err.message || "Downgrade failed"); + } + }); + }; + + const routerRuntimeNodes = useMemo(() => { + return routers.map((router) => { + const runtimeAddr = + router.daemon_mode === "router-daemon" && router.linked_node_addr + ? router.linked_node_addr + : router.addr; + return { + router, + runtimeNode: nodes.find((node) => node.addr === runtimeAddr), + }; + }); + }, [nodes, routers]); + + const routerNodeAddrs = useMemo(() => { + const result = new Set(); + for (const router of routers) { + result.add(router.addr); + if (router.linked_node_addr) { + result.add(router.linked_node_addr); + } + } + return result; + }, [routers]); + + const genericNodes = useMemo( + () => nodes.filter((node) => !routerNodeAddrs.has(node.addr)), + [nodes, routerNodeAddrs], + ); + + const renderModeControls = (node: NodeRecord) => ( +
+ Mode: +
+ {modeOptions.map((opt) => ( + + ))} +
+
+ ); + + const renderTagsSection = (node: NodeRecord) => ( +
+
+ Current tags: + {node.tags.length > 0 ? ( + node.tags.map((tag) => ( + + {tag} + + )) + ) : ( + No tags + )} + {node.template_sync_pending && ( + + Sync pending + + )} + {!node.template_sync_pending && + !node.template_sync_error && + node.tags.length > 0 && ( + + Synced + + )} +
+ {node.template_sync_error && ( +
{node.template_sync_error}
+ )} +
+ + setTagDrafts((prev) => ({ + ...prev, + [node.addr]: e.target.value, + })) + } + onKeyDown={(e) => e.key === "Enter" && handleSaveTags(node.addr)} + placeholder="server, desktop, iot" + className="flex-1 rounded-lg border border-border bg-card px-3 py-2 text-sm focus:outline-none focus:border-primary" + /> + +
+
+ Tags are normalized to lowercase slugs and trigger template reconciliation. +
+
+ ); + + const renderTrustSection = (node: NodeRecord, title: string, description?: string) => ( +
+ + {description &&
{description}
} + + {trustExpanded[node.addr] && ( +
+
+ + setNewTrustPath((prev) => ({ + ...prev, + [node.addr]: e.target.value, + })) + } + onKeyDown={(e) => e.key === "Enter" && handleAddTrust(node.addr)} + className="flex-1 text-xs px-3 py-2 rounded-lg bg-muted border border-border focus:outline-none focus:border-primary" + /> +
+ + +
+
+ + {trustData[node.addr]?.length > 0 && ( + + Process Path + Scope + Trust Level + + + } + renderRow={(entry: TrustEntry) => ( + + {entry.process_path} + + + {entry.node === "*" ? "Global" : "This node"} + + + +
+ {trustLevelOptions.map((lvl) => ( + + ))} +
+ + + + + + )} + renderCard={(entry: TrustEntry) => ( +
+
+
{entry.process_path}
+ +
+
+ + {entry.node === "*" ? "Global" : "This node"} + +
+
+ {trustLevelOptions.map((lvl) => ( + + ))} +
+
+ )} + /> + )} +
+ )} +
+ ); + + const renderStatusPill = (key: string) => + status[key] ? ( +
+ + {status[key]} + +
+ ) : null; + + const renderProvisionSteps = (addr: string) => { + const steps = routerSteps[addr]; + if (!steps?.length) return null; + return ( +
+
Last router action
+ {steps.map((step, index) => ( +
+ {step.status === "done" ? ( + + ) : step.status === "warning" ? ( + + ) : ( + + )} +
+
{step.step}
+
{step.message}
+
+
+ ))} +
+ ); + }; + + const renderCapabilitySummary = (addr: string) => { + const capability = routerCapabilities[addr]; + if (!capability) return null; + return ( +
+
+
+
Capability check
+
+ {capability.eligible + ? "Router is eligible for router-daemon v1." + : capability.ineligible_reason || "Router is not eligible."} +
+
+ + {capability.eligible ? "Eligible" : "Ineligible"} + +
+
+
Arch: {capability.arch || "-"}
+
Kernel: {capability.kernel_version || "-"}
+
RAM: {capability.ram_mb || 0} MB
+
Overlay: {capability.overlay_free_mb || 0} MB
+
+
+ ); + }; + return (
@@ -357,40 +840,191 @@ export default function NodesPage() {
- {nodes.map((node) => ( + {routerRuntimeNodes.map(({ router, runtimeNode }) => { + const busy = routerBusy[router.addr]; + const managed = router.daemon_mode === "router-daemon"; + const linkedNode = router.linked_node; + const online = linkedNode?.online ?? runtimeNode?.online ?? router.online; + const nodeControls = managed && runtimeNode && runtimeNode.online; + + return ( +
+
+
+
+ +
+
+
{router.name || router.addr}
+
{router.addr}
+
+
+
+ + Router + + + {router.daemon_mode} + + + {online ? "Online" : "Offline"} + +
+
+ +
+
+
Runtime
+
{linkedNode?.daemon_version || runtimeNode?.daemon_version || "-"}
+
+
+
Mode
+
{linkedNode?.mode || runtimeNode?.mode || "-"}
+
+
+
Connections
+
{linkedNode?.cons ?? runtimeNode?.cons ?? 0}
+
+
+
Rules
+
{linkedNode?.daemon_rules ?? runtimeNode?.daemon_rules ?? 0}
+
+
+ + {managed ? ( +
+ v1 only prompts for router-local processes. Forwarded device flows are observed and enforced only by explicit device-scoped rules; unknown forwarded traffic is allowed until a rule exists. +
+ ) : ( +
+ Legacy conntrack-agent mode reports forwarded traffic over HTTP ingest. Upgrade to router-daemon for router-local prompts and inline runtime controls. +
+ )} + + {runtimeNode && renderTagsSection(runtimeNode)} + +
+
Router actions
+ handleRouterPasswordChange(router.addr, e.target.value)} + placeholder="SSH password for upgrade, downgrade, and disconnect" + className="w-full rounded-lg border border-border bg-card px-3 py-2 text-sm focus:outline-none focus:border-primary" + /> +
+ + {managed ? ( + + ) : ( + + )} + +
+
+ + {nodeControls && renderModeControls(runtimeNode)} + + {nodeControls && ( +
+ + + + +
+ )} + + {managed && runtimeNode && renderTrustSection( + runtimeNode, + "Router-local trust list", + "Applies only to router-local processes. Forwarded device traffic never enters the prompt flow in v1.", + )} + + {renderStatusPill(router.addr)} + {renderCapabilitySummary(router.addr)} + {renderProvisionSteps(router.addr)} +
+ ); + })} + + {genericNodes.map((node) => (
-
+
-
- {node.hostname || node.addr} -
-
- {node.addr} -
+
{node.hostname || node.addr}
+
{node.addr}
- {node.source_type === "router" && ( - - Router - - )} {node.online ? "Online" : "Offline"} @@ -405,9 +1039,7 @@ export default function NodesPage() {
Uptime
-
- {node.daemon_uptime ? formatUptime(node.daemon_uptime) : "-"} -
+
{node.daemon_uptime ? formatUptime(node.daemon_uptime) : "-"}
Connections
@@ -419,164 +1051,23 @@ export default function NodesPage() {
- {/* Mode selector — only for OpenSnitch nodes */} - {node.source_type !== "router" && ( -
- Mode: -
- {modeOptions.map((opt) => ( - - ))} -
-
- )} + {renderModeControls(node)} + {renderStatusPill(node.addr)} + {renderTagsSection(node)} - {/* Router disconnect */} - {node.source_type === "router" && ( -
- {disconnecting === node.addr ? ( -
- setDisconnectPass(e.target.value)} - onKeyDown={(e) => e.key === "Enter" && handleDisconnectRouter(node.addr)} - className="flex-1 text-xs px-3 py-2 rounded-lg bg-muted border border-border focus:outline-none focus:border-destructive" - /> - - -
- ) : ( - - )} -
- )} - - {status[node.addr] && ( -
- - {status[node.addr]} - -
- )} - -
-
- Current tags: - {node.tags.length > 0 ? ( - node.tags.map((tag) => ( - - {tag} - - )) - ) : ( - No tags - )} - {node.template_sync_pending && ( - - Sync pending - - )} - {!node.template_sync_pending && - !node.template_sync_error && - node.tags.length > 0 && ( - - Synced - - )} -
- {node.template_sync_error && ( -
- {node.template_sync_error} -
- )} -
- - setTagDrafts((prev) => ({ - ...prev, - [node.addr]: e.target.value, - })) - } - onKeyDown={(e) => - e.key === "Enter" && handleSaveTags(node.addr) - } - placeholder="server, desktop, iot" - className="flex-1 rounded-lg border border-border bg-card px-3 py-2 text-sm focus:outline-none focus:border-primary" - /> - -
-
- Tags are normalized to lowercase slugs and trigger template - reconciliation. -
-
- - {node.online && node.source_type !== "router" && ( + {node.online && (
)} - {/* Trust List — only for OpenSnitch nodes */} - {node.source_type !== "router" &&
- - - {trustExpanded[node.addr] && ( -
- {/* Add new entry — stack on mobile */} -
- - setNewTrustPath((prev) => ({ - ...prev, - [node.addr]: e.target.value, - })) - } - onKeyDown={(e) => - e.key === "Enter" && handleAddTrust(node.addr) - } - className="flex-1 text-xs px-3 py-2 rounded-lg bg-muted border border-border focus:outline-none focus:border-primary" - /> -
- - -
-
- - {/* Trust entries */} - {trustData[node.addr]?.length > 0 && ( - - - Process Path - - - Scope - - - Trust Level - - - - } - renderRow={(entry: TrustEntry) => ( - - - {entry.process_path} - - - - {entry.node === "*" ? "Global" : "This node"} - - - -
- {trustLevelOptions.map((lvl) => ( - - ))} -
- - - - - - )} - renderCard={(entry: TrustEntry) => ( -
-
-
- {entry.process_path} -
- -
-
- - {entry.node === "*" ? "Global" : "This node"} - -
-
- {trustLevelOptions.map((lvl) => ( - - ))} -
-
- )} - /> - )} -
- )} -
} + {renderTrustSection(node, "Trust List")}
))} - {nodes.length === 0 && ( + + {genericNodes.length === 0 && routerRuntimeNodes.length === 0 && (
No nodes found. Configure an OpenSnitch daemon to connect to this server, or connect a router. diff --git a/web/src/pages/rules.tsx b/web/src/pages/rules.tsx index 968b677..073b412 100644 --- a/web/src/pages/rules.tsx +++ b/web/src/pages/rules.tsx @@ -7,6 +7,7 @@ import { ResponsiveDataView } from "@/components/ui/responsive-data-view"; import { BottomSheet } from "@/components/ui/bottom-sheet"; import { RuleEditorSheet, defaultForm, operandLabels } from "@/components/rule-editor-sheet"; import type { RuleForm } from "@/components/rule-editor-sheet"; +import { formatProcessLabel, isDeviceSource } from "@/lib/rule-helpers"; interface GeneratedRulePreview { fingerprint: string; @@ -69,6 +70,13 @@ function flattenOperators(operator?: RuleOperator): RuleOperator[] { } function formatOperator(operator: RuleOperator) { + if ( + operator.operand === "process.path" && + operator.data && + isDeviceSource(operator.data) + ) { + return formatProcessLabel(operator.data); + } const label = operandLabels[operator.operand || ""] || operator.operand || "Match"; const value = @@ -127,6 +135,12 @@ export default function RulesPage() { () => nodes.find((node) => node.addr === selectedNode), [nodes, selectedNode], ); + const routerManagedScope = useMemo(() => { + if (selectedNode) { + return Boolean(selectedNodeInfo?.router_managed); + } + return nodes.some((node) => node.router_managed); + }, [nodes, selectedNode, selectedNodeInfo]); const fetchNodes = () => { api.getNodes().then(setNodes).catch(console.error); @@ -372,6 +386,11 @@ export default function RulesPage() { {modeLabels[selectedNodeInfo.mode] || selectedNodeInfo.mode} {selectedNodeInfo.hostname || selectedNodeInfo.addr} + {selectedNodeInfo.router_managed && ( + + Router-managed + + )}
)}
@@ -619,6 +638,7 @@ export default function RulesPage() { onClose={() => setShowEditor(false)} initialValues={form} editing={editing} + routerManaged={routerManagedScope} onSave={handleSave} /> diff --git a/web/src/stores/app-store.ts b/web/src/stores/app-store.ts index 58739e4..931f5c4 100644 --- a/web/src/stores/app-store.ts +++ b/web/src/stores/app-store.ts @@ -24,6 +24,7 @@ export interface Prompt { id: string; node_addr: string; created_at: string; + router_managed: boolean; process: string; dst_host: string; dst_ip: string; From 6a86a4da727fb9bed73031d0f11a3677fdb9dd9e Mon Sep 17 00:00:00 2001 From: bilalbayram Date: Mon, 23 Mar 2026 01:03:52 +0300 Subject: [PATCH 2/4] fix: improve router onboarding auth and mode selection --- internal/api/handlers_routers.go | 49 ++++++++-- internal/api/handlers_routers_test.go | 136 +++++++++++++++++++++++++- internal/router/provisioner.go | 25 +++++ internal/router/provisioner_test.go | 78 +++++++++++++++ web/src/lib/api.ts | 2 + web/src/pages/nodes.tsx | 127 +++++++++++++++++++----- 6 files changed, 384 insertions(+), 33 deletions(-) diff --git a/internal/api/handlers_routers.go b/internal/api/handlers_routers.go index 9d829ea..20aef05 100644 --- a/internal/api/handlers_routers.go +++ b/internal/api/handlers_routers.go @@ -57,6 +57,13 @@ func (a *API) handleConnectRouter(w http.ResponseWriter, r *http.Request) { req.Name = req.Addr } + connectMode := router.NormalizeConnectMode(req.Mode) + if connectMode == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "mode must be monitor or manage"}) + return + } + req.Mode = connectMode + // Resolve server URL: user override > LAN auto-detect > Host header fallback var serverURLSource string if req.ServerURL != "" { @@ -94,15 +101,41 @@ func (a *API) handleConnectRouter(w http.ResponseWriter, r *http.Request) { result.ServerURL = req.ServerURL result.ServerURLSource = serverURLSource - // Register as a node (UpsertRouterNode preserves existing cons counter) - a.db.UpsertRouterNode(result.Router.Addr, result.Router.Name, "conntrack-agent", db.NodeStatusOnline, time.Now().Format("2006-01-02 15:04:05")) + if connectMode == router.ConnectModeManage { + daemonResult, daemonErr := a.routerProv.ProvisionDaemon(r.Context(), router.DaemonRequest{ + Addr: req.Addr, + SSHPort: req.SSHPort, + SSHUser: req.SSHUser, + SSHPass: req.SSHPass, + SSHKey: req.SSHKey, + NodeName: req.Name, + }) + if daemonResult != nil { + result.Steps = append(result.Steps, daemonResult.Steps...) + if daemonResult.Router != nil { + result.Router = daemonResult.Router + } + if daemonResult.Capabilities != nil { + result.Capabilities = daemonResult.Capabilities + } + } + if daemonErr != nil { + result.Warning = fmt.Sprintf("Router connected in monitor mode. Manage setup failed: %v", daemonErr) + } + } - a.hub.BroadcastEvent(ws.EventNodeConnected, map[string]any{ - "addr": result.Router.Addr, - "hostname": result.Router.Name, - "version": "conntrack-agent", - "source_type": "router", - }) + legacyConnected := result.Router != nil && result.Router.DaemonMode != db.RouterDaemonModeRouterDaemon + if legacyConnected { + // Register as a node (UpsertRouterNode preserves existing cons counter) + a.db.UpsertRouterNode(result.Router.Addr, result.Router.Name, "conntrack-agent", db.NodeStatusOnline, time.Now().Format("2006-01-02 15:04:05")) + + a.hub.BroadcastEvent(ws.EventNodeConnected, map[string]any{ + "addr": result.Router.Addr, + "hostname": result.Router.Name, + "version": "conntrack-agent", + "source_type": "router", + }) + } writeJSON(w, http.StatusOK, result) } diff --git a/internal/api/handlers_routers_test.go b/internal/api/handlers_routers_test.go index 35e0b4c..d644127 100644 --- a/internal/api/handlers_routers_test.go +++ b/internal/api/handlers_routers_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "database/sql" "encoding/json" "errors" "net/http" @@ -17,12 +18,16 @@ import ( ) type stubRouterProvisioner struct { - steps []routerpkg.ProvisionStep - err error + provisionResult *routerpkg.ProvisionResult + provisionErr error + daemonResult *routerpkg.ProvisionResult + daemonErr error + steps []routerpkg.ProvisionStep + err error } func (s *stubRouterProvisioner) Provision(ctx context.Context, req routerpkg.ConnectRequest) (*routerpkg.ProvisionResult, error) { - return nil, errors.New("not implemented") + return s.provisionResult, s.provisionErr } func (s *stubRouterProvisioner) Deprovision(ctx context.Context, addr string, sshPort int, sshUser, sshPass, sshKey string) ([]routerpkg.ProvisionStep, error) { @@ -34,6 +39,9 @@ func (s *stubRouterProvisioner) CheckCapabilities(ctx context.Context, addr stri } func (s *stubRouterProvisioner) ProvisionDaemon(ctx context.Context, req routerpkg.DaemonRequest) (*routerpkg.ProvisionResult, error) { + if s.daemonResult != nil || s.daemonErr != nil { + return s.daemonResult, s.daemonErr + } return &routerpkg.ProvisionResult{Steps: s.steps}, s.err } @@ -180,6 +188,128 @@ func TestHandleDisconnectRouterFailurePreservesState(t *testing.T) { } } +func TestHandleConnectRouterRejectsInvalidMode(t *testing.T) { + env := newAPITestEnv(t) + + rec := performJSONRequest(t, env.api.handleConnectRouter, http.MethodPost, "/api/v1/routers/connect", map[string]string{ + "addr": "192.168.1.1", + "ssh_pass": "secret", + "mode": "invalid", + }) + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestHandleConnectRouterManageFailureFallsBackToMonitor(t *testing.T) { + env := newAPITestEnv(t) + env.api.routerProv = &stubRouterProvisioner{ + provisionResult: &routerpkg.ProvisionResult{ + Router: &db.Router{ + Name: "router-a", + Addr: "192.168.1.1", + SSHPort: 22, + SSHUser: "root", + DaemonMode: db.RouterDaemonModeConntrackAgent, + }, + Steps: []routerpkg.ProvisionStep{{Step: "connect", Status: "done", Message: "Connected"}}, + }, + daemonResult: &routerpkg.ProvisionResult{ + Router: &db.Router{ + Name: "router-a", + Addr: "192.168.1.1", + SSHPort: 22, + SSHUser: "root", + DaemonMode: db.RouterDaemonModeConntrackAgent, + }, + Capabilities: &routerpkg.RouterCapabilities{ + RAMMB: 64, + RAMSupported: false, + IneligibleReason: "router-daemon v1 requires at least 128MB RAM", + }, + Steps: []routerpkg.ProvisionStep{{Step: "capabilities", Status: "error", Message: "router-daemon v1 requires at least 128MB RAM"}}, + }, + daemonErr: errors.New("router-daemon v1 requires at least 128MB RAM"), + } + + rec := performJSONRequest(t, env.api.handleConnectRouter, http.MethodPost, "/api/v1/routers/connect", map[string]string{ + "addr": "192.168.1.1", + "ssh_pass": "secret", + "mode": "manage", + }) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + response := decodeJSON[struct { + Warning string `json:"warning"` + Capabilities *routerpkg.RouterCapabilities `json:"capabilities"` + Steps []routerpkg.ProvisionStep `json:"steps"` + Router struct { + DaemonMode string `json:"daemon_mode"` + } `json:"router"` + }](t, rec) + if response.Router.DaemonMode != db.RouterDaemonModeConntrackAgent { + t.Fatalf("expected router to remain in monitor mode, got %+v", response.Router) + } + if response.Warning == "" { + t.Fatalf("expected warning in response") + } + if len(response.Steps) != 2 { + t.Fatalf("expected combined steps, got %+v", response.Steps) + } + if response.Capabilities == nil || response.Capabilities.RAMSupported { + t.Fatalf("expected capabilities to be returned on manage fallback, got %+v", response.Capabilities) + } + + node, err := env.database.GetNode("192.168.1.1") + if err != nil { + t.Fatalf("expected monitor node record to be created, got %v", err) + } + if node.DaemonVersion != "conntrack-agent" { + t.Fatalf("expected legacy node version to remain conntrack-agent, got %+v", node) + } +} + +func TestHandleConnectRouterManageSuccessSkipsLegacyNodeUpsert(t *testing.T) { + env := newAPITestEnv(t) + env.api.routerProv = &stubRouterProvisioner{ + provisionResult: &routerpkg.ProvisionResult{ + Router: &db.Router{ + Name: "router-a", + Addr: "192.168.1.1", + SSHPort: 22, + SSHUser: "root", + DaemonMode: db.RouterDaemonModeConntrackAgent, + }, + Steps: []routerpkg.ProvisionStep{{Step: "connect", Status: "done", Message: "Connected"}}, + }, + daemonResult: &routerpkg.ProvisionResult{ + Router: &db.Router{ + Name: "router-a", + Addr: "192.168.1.1", + SSHPort: 22, + SSHUser: "root", + DaemonMode: db.RouterDaemonModeRouterDaemon, + }, + Steps: []routerpkg.ProvisionStep{{Step: "verify", Status: "done", Message: "router-daemon subscribed"}}, + }, + } + + rec := performJSONRequest(t, env.api.handleConnectRouter, http.MethodPost, "/api/v1/routers/connect", map[string]string{ + "addr": "192.168.1.1", + "ssh_pass": "secret", + "mode": "manage", + }) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if _, err := env.database.GetNode("192.168.1.1"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected no legacy node upsert after manage success, got %v", err) + } +} + func TestHandleDisconnectRouterSuccessRemovesState(t *testing.T) { env := newAPITestEnv(t) env.api.routerProv = &stubRouterProvisioner{ diff --git a/internal/router/provisioner.go b/internal/router/provisioner.go index bd90fca..146eda7 100644 --- a/internal/router/provisioner.go +++ b/internal/router/provisioner.go @@ -63,6 +63,7 @@ type ConnectRequest struct { Name string `json:"name"` LANSubnet string `json:"lan_subnet"` ServerURL string `json:"server_url,omitempty"` + Mode string `json:"mode,omitempty"` } type ProvisionResult struct { @@ -71,6 +72,7 @@ type ProvisionResult struct { Steps []ProvisionStep `json:"steps"` ServerURL string `json:"server_url"` ServerURLSource string `json:"server_url_source"` + Warning string `json:"warning,omitempty"` } type DaemonRequest struct { @@ -97,6 +99,9 @@ const ( routerDaemonConfigDir = "/etc/opensnitchd-router" routerDaemonConfigPath = "/etc/opensnitchd-router/config.json" routerDaemonInitdPath = "/etc/init.d/opensnitchd-router" + + ConnectModeMonitor = "monitor" + ConnectModeManage = "manage" ) func NewProvisioner(database *db.Database) *Provisioner { @@ -108,6 +113,17 @@ func NewProvisioner(database *db.Database) *Provisioner { } } +func NormalizeConnectMode(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "", ConnectModeMonitor: + return ConnectModeMonitor + case ConnectModeManage: + return ConnectModeManage + default: + return "" + } +} + func (p *Provisioner) WithRuntime(nodes *nodemanager.Manager, cfg *config.Config) *Provisioner { p.nodes = nodes p.cfg = cfg @@ -530,6 +546,15 @@ func sshDial(addr string, port int, user, pass, key string) (remoteClient, error } if pass != "" { authMethods = append(authMethods, ssh.Password(pass)) + authMethods = append(authMethods, ssh.KeyboardInteractive( + func(_ string, _ string, questions []string, _ []bool) ([]string, error) { + answers := make([]string, len(questions)) + for i := range questions { + answers[i] = pass + } + return answers, nil + }, + )) } config := &ssh.ClientConfig{ User: user, diff --git a/internal/router/provisioner_test.go b/internal/router/provisioner_test.go index 8136539..5c9769f 100644 --- a/internal/router/provisioner_test.go +++ b/internal/router/provisioner_test.go @@ -2,8 +2,11 @@ package router import ( "context" + "crypto/rand" + "crypto/rsa" "errors" "fmt" + "net" "os" "path/filepath" "strings" @@ -11,6 +14,7 @@ import ( "time" "github.com/bilalbayram/opensnitch-web/internal/db" + "golang.org/x/crypto/ssh" ) type fakeRemoteClient struct { @@ -271,3 +275,77 @@ func (c *fakeRemoteClientWithConnCheck) Run(cmd string) (string, error) { } return c.fakeRemoteClient.Run(cmd) } + +func TestSSHDialSupportsKeyboardInteractivePasswordAuth(t *testing.T) { + signer, err := testSigner() + if err != nil { + t.Fatalf("create host key: %v", err) + } + + serverConfig := &ssh.ServerConfig{ + KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { + answers, err := challenge(conn.User(), "keyboard-interactive", []string{"Password: "}, []bool{false}) + if err != nil { + return nil, err + } + if len(answers) != 1 || answers[0] != "secret" { + return nil, errors.New("bad password") + } + return nil, nil + }, + } + serverConfig.AddHostKey(signer) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer listener.Close() + + done := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + done <- err + return + } + defer conn.Close() + + _, chans, reqs, err := ssh.NewServerConn(conn, serverConfig) + if err != nil { + done <- err + return + } + go ssh.DiscardRequests(reqs) + for ch := range chans { + ch.Reject(ssh.UnknownChannelType, "unsupported") + } + done <- nil + }() + + addr := listener.Addr().(*net.TCPAddr) + client, err := sshDial("127.0.0.1", addr.Port, "root", "secret", "") + if err != nil { + t.Fatalf("ssh dial: %v", err) + } + if err := client.Close(); err != nil { + t.Fatalf("close client: %v", err) + } + + select { + case err := <-done: + if err != nil { + t.Fatalf("server handshake: %v", err) + } + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for keyboard-interactive handshake") + } +} + +func testSigner() (ssh.Signer, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + return ssh.NewSignerFromKey(privateKey) +} diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 1d5f255..021f75b 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -56,6 +56,7 @@ export interface ConnectRouterRequest { name?: string; lan_subnet?: string; server_url?: string; + mode?: "monitor" | "manage"; } export interface ProvisionStep { @@ -86,6 +87,7 @@ export interface ConnectRouterResponse { steps: ProvisionStep[]; server_url?: string; server_url_source?: string; + warning?: string; } export interface RouterCapabilitiesResponse { diff --git a/web/src/pages/nodes.tsx b/web/src/pages/nodes.tsx index 2e060bd..5385680 100644 --- a/web/src/pages/nodes.tsx +++ b/web/src/pages/nodes.tsx @@ -49,6 +49,45 @@ const modeOptions = [ }, ]; +type RouterConnectMode = "monitor" | "manage"; + +interface RouterFormState { + addr: string; + ssh_port: number; + ssh_user: string; + ssh_pass: string; + ssh_key: string; + name: string; + lan_subnet: string; + server_url: string; + mode: RouterConnectMode; +} + +const defaultRouterForm: RouterFormState = { + addr: "", + ssh_port: 22, + ssh_user: "root", + ssh_pass: "", + ssh_key: "", + name: "", + lan_subnet: "", + server_url: "", + mode: "monitor", +}; + +const routerConnectModeOptions = [ + { + value: "monitor" as const, + label: "Monitor", + description: "Best compatibility. Tracks forwarded traffic with the legacy router agent.", + }, + { + value: "manage" as const, + label: "Manage", + description: "Deploys router-daemon for router-local prompts and runtime controls when supported.", + }, +]; + interface TrustEntry { id: number; node: string; @@ -73,16 +112,7 @@ export default function NodesPage() { // Router connection state const [showConnectRouter, setShowConnectRouter] = useState(false); - const [routerForm, setRouterForm] = useState({ - addr: "", - ssh_port: 22, - ssh_user: "root", - ssh_pass: "", - ssh_key: "", - name: "", - lan_subnet: "", - server_url: "", - }); + const [routerForm, setRouterForm] = useState({ ...defaultRouterForm }); const [showAdvanced, setShowAdvanced] = useState(false); const [serverUrlSource, setServerUrlSource] = useState(""); const [connecting, setConnecting] = useState(false); @@ -90,6 +120,7 @@ export default function NodesPage() { null, ); const [connectError, setConnectError] = useState(""); + const [connectWarning, setConnectWarning] = useState(""); // Network scan state const [scanning, setScanning] = useState(false); @@ -269,27 +300,24 @@ export default function NodesPage() { const handleConnectRouter = async () => { setConnecting(true); setConnectError(""); + setConnectWarning(""); setConnectSteps(null); try { const res = await api.connectRouter(routerForm); setConnectSteps(res.steps); + fetchPageData(true); setConnecting(false); + if (res.warning) { + setConnectWarning(res.warning); + return; + } setTimeout(() => { setShowConnectRouter(false); setConnectSteps(null); - setRouterForm({ - addr: "", - ssh_port: 22, - ssh_user: "root", - ssh_pass: "", - ssh_key: "", - name: "", - lan_subnet: "", - server_url: "", - }); + setRouterForm({ ...defaultRouterForm }); setShowAdvanced(false); setServerUrlSource(""); - fetchPageData(true); + setConnectWarning(""); }, 2000); } catch (e: unknown) { // Try to parse steps from the error response body @@ -301,6 +329,7 @@ export default function NodesPage() { if (errorSteps) { setConnectSteps(errorSteps); } + setConnectWarning(""); setConnectError(err.message || "Connection failed"); setConnecting(false); } @@ -827,6 +856,10 @@ export default function NodesPage() { ); }; + const canConnectRouter = Boolean( + routerForm.addr.trim() && (routerForm.ssh_pass || routerForm.ssh_key.trim()), + ); + return (
@@ -1103,6 +1136,7 @@ export default function NodesPage() { if (!connecting) { setShowConnectRouter(false); setConnectError(""); + setConnectWarning(""); setConnectSteps(null); setScanResults(null); } @@ -1112,7 +1146,7 @@ export default function NodesPage() { !connectSteps && (
)} + {connectWarning && ( +
+ {connectWarning} +
+ )} + {connectError && ( +
+ {connectError} +
+ )}
) : ( // Show form @@ -1312,6 +1356,40 @@ export default function NodesPage() { />
+
+ +
+ {routerConnectModeOptions.map((option) => { + const selected = routerForm.mode === option.value; + return ( + + ); + })} +
+

+ Manage may fail on smaller routers. It needs enough RAM, storage, and a supported OpenWrt target. If that happens, the router stays connected in monitor mode. +

+
+
)} + {connectWarning && ( +
+ {connectWarning} +
+ )} )}
From fc3cb9e87882e83079edc36c457185a02b56c17e Mon Sep 17 00:00:00 2001 From: bilalbayram Date: Mon, 23 Mar 2026 01:25:30 +0300 Subject: [PATCH 3/4] fix: clarify offline router errors and add node deletion --- internal/api/handlers_nodes.go | 35 ++++++ internal/api/handlers_nodes_test.go | 184 ++++++++++++++++++++++++++++ internal/api/router.go | 1 + internal/db/nodes.go | 57 +++++++++ internal/router/provisioner.go | 90 +++++++++++++- internal/router/provisioner_test.go | 37 ++++++ web/src/lib/api.ts | 4 + web/src/pages/nodes.tsx | 57 ++++++++- 8 files changed, 454 insertions(+), 11 deletions(-) diff --git a/internal/api/handlers_nodes.go b/internal/api/handlers_nodes.go index 584050d..fb86cfe 100644 --- a/internal/api/handlers_nodes.go +++ b/internal/api/handlers_nodes.go @@ -276,6 +276,41 @@ func (a *API) handleSetNodeMode(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) } +func (a *API) handleDeleteNode(w http.ResponseWriter, r *http.Request) { + addr := chi.URLParam(r, "addr") + + node, err := a.db.GetNode(addr) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "node not found"}) + return + } + + if node.SourceType == "router" { + writeJSON(w, http.StatusConflict, map[string]string{"error": "router nodes must be removed from the router controls"}) + return + } + + if linkedRouter, err := a.db.GetRouterByLinkedNodeAddr(addr); err == nil && linkedRouter != nil { + writeJSON(w, http.StatusConflict, map[string]string{"error": "router-managed nodes must be removed from the router controls"}) + return + } else if err != nil && err != sql.ErrNoRows { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + if a.nodes.GetNode(addr) != nil { + writeJSON(w, http.StatusConflict, map[string]string{"error": "disconnect the node before deleting it"}) + return + } + + if err := a.db.DeleteNode(addr); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + func (a *API) handleNodeAction(enable bool, isFirewall bool) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { addr := chi.URLParam(r, "addr") diff --git a/internal/api/handlers_nodes_test.go b/internal/api/handlers_nodes_test.go index bf25a84..198df6e 100644 --- a/internal/api/handlers_nodes_test.go +++ b/internal/api/handlers_nodes_test.go @@ -1,6 +1,8 @@ package api import ( + "database/sql" + "errors" "net/http" "testing" "time" @@ -106,3 +108,185 @@ func TestHandleGetNodesMarksStaleRouterOffline(t *testing.T) { t.Fatalf("expected stale router status offline, got %+v", response[0]) } } + +func TestHandleDeleteNodeRemovesStoredNodeData(t *testing.T) { + env := newAPITestEnv(t) + env.seedNode(t, "node-a", false) + + template, err := env.database.CreateRuleTemplate(&db.RuleTemplate{Name: "Template", Description: "test"}) + if err != nil { + t.Fatalf("create template: %v", err) + } + if _, err := env.database.CreateTemplateAttachment(&db.TemplateAttachment{ + TemplateID: template.ID, + TargetType: "node", + TargetRef: "node-a", + Priority: 1, + }); err != nil { + t.Fatalf("create template attachment: %v", err) + } + if err := env.database.UpsertRule(&db.DBRule{ + Time: "2026-03-23 10:00:00", + Node: "node-a", + Name: "allow-example", + DisplayName: "allow-example", + SourceKind: db.RuleSourceManual, + Enabled: true, + Action: "allow", + Duration: "always", + OperatorType: "simple", + OperatorOperand: "dest.host", + OperatorData: "example.com", + Created: "2026-03-23 10:00:00", + }); err != nil { + t.Fatalf("upsert rule: %v", err) + } + if _, err := env.database.ReplaceNodeTags("node-a", []string{"server"}); err != nil { + t.Fatalf("replace tags: %v", err) + } + if _, err := env.database.AddProcessTrust("node-a", "/opt/test-app", db.TrustLevelTrusted); err != nil { + t.Fatalf("add process trust: %v", err) + } + if err := env.database.InsertConnection(&db.Connection{ + Time: "2026-03-23 10:00:00", + Node: "node-a", + Action: "allow", + Protocol: "tcp", + DstIP: "93.184.216.34", + DstHost: "example.com", + DstPort: 443, + Process: "/usr/bin/curl", + }); err != nil { + t.Fatalf("insert connection: %v", err) + } + if err := env.database.InsertAlert(&db.DBAlert{ + Time: "2026-03-23 10:00:00", + Node: "node-a", + Body: "test alert", + Status: "open", + }); err != nil { + t.Fatalf("insert alert: %v", err) + } + if err := env.database.UpsertSeenFlow(db.SeenFlowKey{ + Node: "node-a", + Process: "/usr/bin/curl", + Protocol: "tcp", + DstPort: 443, + DestinationOperand: "dest.host", + Destination: "example.com", + }, "allow", "allow-example", time.Now(), time.Time{}); err != nil { + t.Fatalf("upsert seen flow: %v", err) + } + if err := env.database.UpsertDNSDomain("node-a", "example.com", "93.184.216.34", "2026-03-23 10:00:00"); err != nil { + t.Fatalf("upsert dns domain: %v", err) + } + for table, what := range map[string]string{ + "hosts": "example.com", + "procs": "/opt/test-app", + "addrs": "93.184.216.34", + "ports": "443", + "users": "root", + } { + if err := env.database.UpsertStat(table, what, "node-a", 3); err != nil { + t.Fatalf("upsert %s stat: %v", table, err) + } + } + + rec := performJSONRequestWithAddr(t, env.api.handleDeleteNode, http.MethodDelete, "/api/v1/nodes/node-a", "node-a", nil) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if _, err := env.database.GetNode("node-a"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected node to be removed, got %v", err) + } + if rules, err := env.database.GetRules("node-a"); err != nil || len(rules) != 0 { + t.Fatalf("expected rules removed, got %v %+v", err, rules) + } + if tags, err := env.database.GetNodeTags("node-a"); err != nil || len(tags) != 0 { + t.Fatalf("expected tags removed, got %v %+v", err, tags) + } + if trust := env.database.GetProcessTrustLevel("node-a", "/opt/test-app"); trust != db.TrustLevelDefault { + t.Fatalf("expected node-specific trust removed, got %q", trust) + } + if _, total, err := env.database.GetConnections(&db.ConnectionFilter{Node: "node-a"}); err != nil || total != 0 { + t.Fatalf("expected connections removed, got total=%d err=%v", total, err) + } + if _, total, err := env.database.GetDNSDomains(&db.DNSDomainFilter{Node: "node-a"}); err != nil || total != 0 { + t.Fatalf("expected dns domains removed, got total=%d err=%v", total, err) + } + if _, total, err := env.database.GetSeenFlows(&db.SeenFlowFilter{Node: "node-a"}); err != nil || total != 0 { + t.Fatalf("expected seen flows removed, got total=%d err=%v", total, err) + } + if attachments, err := env.database.GetAllTemplateAttachments(); err != nil || len(attachments) != 0 { + t.Fatalf("expected template attachments removed, got %v %+v", err, attachments) + } + if alerts, total, err := env.database.GetAlerts(50, 0); err != nil || total != 0 || len(alerts) != 0 { + t.Fatalf("expected alerts removed, got total=%d len=%d err=%v", total, len(alerts), err) + } + for _, table := range []string{"hosts", "procs", "addrs", "ports", "users"} { + entries, err := env.database.GetStats(table, 10) + if err != nil { + t.Fatalf("get %s stats: %v", table, err) + } + for _, entry := range entries { + if entry.Node == "node-a" { + t.Fatalf("expected %s stats removed, got %+v", table, entry) + } + } + } +} + +func TestHandleDeleteNodeRejectsOnlineNode(t *testing.T) { + env := newAPITestEnv(t) + env.seedNode(t, "node-a", true) + + rec := performJSONRequestWithAddr(t, env.api.handleDeleteNode, http.MethodDelete, "/api/v1/nodes/node-a", "node-a", nil) + if rec.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestHandleDeleteNodeRejectsRouterBackedNode(t *testing.T) { + env := newAPITestEnv(t) + + if err := env.database.UpsertNode(&db.Node{ + Addr: "router-node", + Hostname: "router-node", + DaemonVersion: "conntrack-agent", + Status: db.NodeStatusOffline, + LastConn: "2026-03-23 10:00:00", + SourceType: "router", + }); err != nil { + t.Fatalf("seed router node: %v", err) + } + + rec := performJSONRequestWithAddr(t, env.api.handleDeleteNode, http.MethodDelete, "/api/v1/nodes/router-node", "router-node", nil) + if rec.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestHandleDeleteNodeRejectsRouterManagedNode(t *testing.T) { + env := newAPITestEnv(t) + env.seedNode(t, "managed-node", false) + + if err := env.database.InsertRouter(&db.Router{ + Name: "router-a", + Addr: "192.168.1.1", + SSHPort: 22, + SSHUser: "root", + APIKey: "api-key", + LANSubnet: "192.168.1.0/24", + DaemonMode: db.RouterDaemonModeRouterDaemon, + LinkedNodeAddr: "managed-node", + Status: "active", + }); err != nil { + t.Fatalf("seed router: %v", err) + } + + rec := performJSONRequestWithAddr(t, env.api.handleDeleteNode, http.MethodDelete, "/api/v1/nodes/managed-node", "managed-node", nil) + if rec.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", rec.Code, rec.Body.String()) + } +} diff --git a/internal/api/router.go b/internal/api/router.go index 0f22bb4..ae66c9e 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -90,6 +90,7 @@ func NewRouter(cfg *config.Config, database *db.Database, nodes *nodemanager.Man // Nodes r.Get("/api/v1/nodes", api.handleGetNodes) r.Get("/api/v1/nodes/{addr}", api.handleGetNode) + r.Delete("/api/v1/nodes/{addr}", api.handleDeleteNode) r.Put("/api/v1/nodes/{addr}/config", api.handleUpdateNodeConfig) r.Put("/api/v1/nodes/{addr}/tags", api.handleReplaceNodeTags) r.Post("/api/v1/nodes/{addr}/interception/enable", api.handleNodeAction(true, false)) diff --git a/internal/db/nodes.go b/internal/db/nodes.go index 3afa26f..dc62417 100644 --- a/internal/db/nodes.go +++ b/internal/db/nodes.go @@ -1,5 +1,7 @@ package db +import "fmt" + type Node struct { Addr string `json:"addr"` Hostname string `json:"hostname"` @@ -139,3 +141,58 @@ func (d *Database) GetNodeMode(addr string) (string, error) { } return mode, nil } + +func (d *Database) DeleteNode(addr string) error { + d.mu.Lock() + defer d.mu.Unlock() + + tx, err := d.db.Begin() + if err != nil { + return err + } + + statements := []struct { + query string + args []any + }{ + {query: "DELETE FROM template_attachments WHERE target_type = 'node' AND target_ref = ?", args: []any{addr}}, + {query: "DELETE FROM node_tags WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM node_template_sync WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM process_trust WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM rules WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM alerts WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM seen_flows WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM dns_domains WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM connections WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM hosts WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM procs WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM addrs WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM ports WHERE node = ?", args: []any{addr}}, + {query: "DELETE FROM users WHERE node = ?", args: []any{addr}}, + } + + for _, stmt := range statements { + if _, err := tx.Exec(stmt.query, stmt.args...); err != nil { + _ = tx.Rollback() + return err + } + } + + result, err := tx.Exec("DELETE FROM nodes WHERE addr = ?", addr) + if err != nil { + _ = tx.Rollback() + return err + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + _ = tx.Rollback() + return err + } + if rowsAffected == 0 { + _ = tx.Rollback() + return fmt.Errorf("node %s not found", addr) + } + + return tx.Commit() +} diff --git a/internal/router/provisioner.go b/internal/router/provisioner.go index 146eda7..5f2d43b 100644 --- a/internal/router/provisioner.go +++ b/internal/router/provisioner.go @@ -3,13 +3,16 @@ package router import ( "bytes" "context" + "errors" "fmt" "log" "net" "net/url" "os" "path/filepath" + "strconv" "strings" + "syscall" "time" "github.com/bilalbayram/opensnitch-web/internal/config" @@ -29,6 +32,19 @@ type sshRemoteClient struct { client *ssh.Client } +type sshDialError struct { + msg string + err error +} + +func (e *sshDialError) Error() string { + return e.msg +} + +func (e *sshDialError) Unwrap() error { + return e.err +} + func (c *sshRemoteClient) Close() error { return c.client.Close() } @@ -536,11 +552,15 @@ func (p *Provisioner) DeprovisionDaemon(ctx context.Context, addr string, sshPor // --- helpers --- func sshDial(addr string, port int, user, pass, key string) (remoteClient, error) { + target := net.JoinHostPort(addr, strconv.Itoa(port)) var authMethods []ssh.AuthMethod if key != "" { signer, err := ssh.ParsePrivateKey([]byte(key)) if err != nil { - return nil, fmt.Errorf("parse SSH key: %w", err) + return nil, &sshDialError{ + msg: "SSH private key is invalid or unsupported", + err: fmt.Errorf("parse SSH key: %w", err), + } } authMethods = append(authMethods, ssh.PublicKeys(signer)) } @@ -560,13 +580,73 @@ func sshDial(addr string, port int, user, pass, key string) (remoteClient, error User: user, Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Timeout: 10 * time.Second, } - client, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", addr, port), config) + + const timeout = 10 * time.Second + conn, err := net.DialTimeout("tcp", target, timeout) + if err != nil { + return nil, wrapSSHDialError(addr, port, user, err) + } + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("set SSH deadline: %w", err) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, target, config) if err != nil { - return nil, err + _ = conn.Close() + return nil, wrapSSHDialError(addr, port, user, err) + } + if err := conn.SetDeadline(time.Time{}); err != nil { + _ = sshConn.Close() + return nil, fmt.Errorf("clear SSH deadline: %w", err) } - return &sshRemoteClient{client: client}, nil + + return &sshRemoteClient{client: ssh.NewClient(sshConn, chans, reqs)}, nil +} + +func wrapSSHDialError(addr string, port int, user string, err error) error { + if err == nil { + return nil + } + return &sshDialError{ + msg: explainSSHDialError(addr, port, user, err), + err: err, + } +} + +func explainSSHDialError(addr string, port int, user string, err error) string { + target := net.JoinHostPort(addr, strconv.Itoa(port)) + message := strings.TrimSpace(err.Error()) + lower := strings.ToLower(message) + + switch { + case isSSHTimeoutOrUnreachable(err, lower): + return fmt.Sprintf("router is offline or unreachable at %s", target) + case errors.Is(err, syscall.ECONNREFUSED) || strings.Contains(lower, "connection refused"): + return fmt.Sprintf("router is reachable at %s, but SSH is not accepting connections on that port", target) + case strings.Contains(lower, "unable to authenticate"): + return fmt.Sprintf("SSH authentication failed for %s@%s. Verify the SSH username and password or key", user, target) + case strings.Contains(lower, "connection reset by peer"), strings.Contains(lower, "broken pipe"), strings.Contains(lower, "unexpected packet"), strings.Contains(lower, "eof"): + return fmt.Sprintf("router responded at %s, but the SSH handshake did not complete", target) + default: + return fmt.Sprintf("SSH connection to %s failed: %s", target, message) + } +} + +func isSSHTimeoutOrUnreachable(err error, lower string) bool { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + return errors.Is(err, syscall.ETIMEDOUT) || + errors.Is(err, syscall.EHOSTUNREACH) || + errors.Is(err, syscall.ENETUNREACH) || + strings.Contains(lower, "i/o timeout") || + strings.Contains(lower, "no route to host") || + strings.Contains(lower, "host is down") || + strings.Contains(lower, "network is unreachable") } func runSSHCommand(client *ssh.Client, cmd string) (string, error) { diff --git a/internal/router/provisioner_test.go b/internal/router/provisioner_test.go index 5c9769f..7223e45 100644 --- a/internal/router/provisioner_test.go +++ b/internal/router/provisioner_test.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strings" + "syscall" "testing" "time" @@ -17,6 +18,12 @@ import ( "golang.org/x/crypto/ssh" ) +type timeoutNetError struct{} + +func (timeoutNetError) Error() string { return "i/o timeout" } +func (timeoutNetError) Timeout() bool { return true } +func (timeoutNetError) Temporary() bool { return true } + type fakeRemoteClient struct { outputs map[string]fakeRunResult writes map[string]string @@ -342,6 +349,36 @@ func TestSSHDialSupportsKeyboardInteractivePasswordAuth(t *testing.T) { } } +func TestExplainSSHDialErrorPasswordOnlyAuth(t *testing.T) { + err := errors.New("ssh: handshake failed: ssh: unable to authenticate, attempted methods [none password], no supported methods remain") + + message := explainSSHDialError("192.168.1.1", 22, "root", err) + + if !strings.Contains(message, "SSH authentication failed for root@192.168.1.1:22") { + t.Fatalf("expected auth failure explanation, got %q", message) + } +} + +func TestExplainSSHDialErrorOfflineTimeout(t *testing.T) { + err := &net.OpError{Op: "dial", Net: "tcp", Err: timeoutNetError{}} + + message := explainSSHDialError("192.168.1.1", 22, "root", err) + + if message != "router is offline or unreachable at 192.168.1.1:22" { + t.Fatalf("unexpected message: %q", message) + } +} + +func TestExplainSSHDialErrorConnectionRefused(t *testing.T) { + err := &net.OpError{Op: "dial", Net: "tcp", Err: syscall.ECONNREFUSED} + + message := explainSSHDialError("192.168.1.1", 22, "root", err) + + if message != "router is reachable at 192.168.1.1:22, but SSH is not accepting connections on that port" { + t.Fatalf("unexpected message: %q", message) + } +} + func testSigner() (ssh.Signer, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 021f75b..f81851d 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -397,6 +397,10 @@ export const api = { getNodes: () => request("/nodes"), getNode: (addr: string) => request(`/nodes/${encodeURIComponent(addr)}`), + deleteNode: (addr: string) => + request<{ status: string }>(`/nodes/${encodeURIComponent(addr)}`, { + method: "DELETE", + }), updateNodeConfig: (addr: string, config: object) => request(`/nodes/${encodeURIComponent(addr)}/config`, { method: "PUT", diff --git a/web/src/pages/nodes.tsx b/web/src/pages/nodes.tsx index 5385680..8c8b5eb 100644 --- a/web/src/pages/nodes.tsx +++ b/web/src/pages/nodes.tsx @@ -137,6 +137,7 @@ export default function NodesPage() { Record >({}); const [routerBusy, setRouterBusy] = useState>({}); + const [deletingNodes, setDeletingNodes] = useState>({}); const fetchPageData = (force?: boolean) => { Promise.all([api.getNodes(), api.getRouters()]) @@ -297,6 +298,29 @@ export default function NodesPage() { } }; + const handleDeleteNode = async (node: NodeRecord) => { + const label = node.hostname || node.addr; + if (!window.confirm(`Delete ${label} and its stored data? This cannot be undone.`)) { + return; + } + + setDeletingNodes((prev) => ({ ...prev, [node.addr]: true })); + try { + await api.deleteNode(node.addr); + showStatus(node.addr, "Node deleted"); + fetchPageData(true); + } catch (e) { + console.error("Node delete failed:", e); + showStatus(node.addr, e instanceof Error ? e.message : "Node delete failed"); + } finally { + setDeletingNodes((prev) => { + const next = { ...prev }; + delete next[node.addr]; + return next; + }); + } + }; + const handleConnectRouter = async () => { setConnecting(true); setConnectError(""); @@ -1037,11 +1061,16 @@ export default function NodesPage() { ); })} - {genericNodes.map((node) => ( -
+ {genericNodes.map((node) => { + const deleting = Boolean(deletingNodes[node.addr]); + const canDelete = + !node.online && node.source_type !== "router" && !node.router_managed; + + return ( +
@@ -1055,6 +1084,21 @@ export default function NodesPage() {
+ {canDelete && ( + + )} - ))} + ); + })} {genericNodes.length === 0 && routerRuntimeNodes.length === 0 && (
From ada0915e40549e2eaa0c84d3cad55c6f42e3f605 Mon Sep 17 00:00:00 2001 From: bilalbayram Date: Mon, 23 Mar 2026 02:14:12 +0300 Subject: [PATCH 4/4] fix: add offline router deletion from nodes --- internal/api/handlers_routers.go | 42 +++++++++++++++++ internal/api/handlers_routers_test.go | 67 +++++++++++++++++++++++++++ internal/api/router.go | 1 + web/src/lib/api.ts | 4 ++ web/src/pages/nodes.tsx | 38 +++++++++++++++ 5 files changed, 152 insertions(+) diff --git a/internal/api/handlers_routers.go b/internal/api/handlers_routers.go index 20aef05..7a313fd 100644 --- a/internal/api/handlers_routers.go +++ b/internal/api/handlers_routers.go @@ -275,6 +275,48 @@ func (a *API) handleDisconnectRouter(w http.ResponseWriter, r *http.Request) { }) } +func (a *API) handleDeleteRouter(w http.ResponseWriter, r *http.Request) { + addr := routerAddrParam(r) + + rt, err := a.db.GetRouterByAddr(addr) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "router not found"}) + return + } + + nodeAddrs := []string{rt.Addr} + if strings.TrimSpace(rt.LinkedNodeAddr) != "" && rt.LinkedNodeAddr != rt.Addr { + nodeAddrs = append(nodeAddrs, rt.LinkedNodeAddr) + } + + for _, nodeAddr := range nodeAddrs { + node, err := a.db.GetNode(nodeAddr) + if err != nil { + continue + } + if routerOnlineFromLastConn(node.LastConn) { + writeJSON(w, http.StatusConflict, map[string]string{"error": "disconnect the router before deleting it"}) + return + } + } + + if err := a.db.DeleteRouter(rt.Addr); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + for _, nodeAddr := range nodeAddrs { + if _, err := a.db.GetNode(nodeAddr); err == nil { + if err := a.db.DeleteNode(nodeAddr); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + } + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + func (a *API) handleSuggestServerURL(w http.ResponseWriter, r *http.Request) { var req struct { RouterIP string `json:"router_ip"` diff --git a/internal/api/handlers_routers_test.go b/internal/api/handlers_routers_test.go index d644127..5ae5a92 100644 --- a/internal/api/handlers_routers_test.go +++ b/internal/api/handlers_routers_test.go @@ -336,3 +336,70 @@ func TestHandleDisconnectRouterSuccessRemovesState(t *testing.T) { t.Fatalf("expected node status offline, got %q", node.Status) } } + +func TestHandleDeleteRouterRemovesOfflineRouterAndNode(t *testing.T) { + env := newAPITestEnv(t) + seedRouterRecord(t, env, "router-a", time.Now().Add(-2*time.Minute)) + + rec := performJSONRequestWithAddr(t, env.api.handleDeleteRouter, http.MethodDelete, "/api/v1/routers/router-a", "router-a", nil) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if _, err := env.database.GetRouterByAddr("router-a"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected router removed, got %v", err) + } + if _, err := env.database.GetNode("router-a"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected router node removed, got %v", err) + } +} + +func TestHandleDeleteRouterRejectsOnlineRouter(t *testing.T) { + env := newAPITestEnv(t) + seedRouterRecord(t, env, "router-a", time.Now().Add(-30*time.Second)) + + rec := performJSONRequestWithAddr(t, env.api.handleDeleteRouter, http.MethodDelete, "/api/v1/routers/router-a", "router-a", nil) + if rec.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", rec.Code, rec.Body.String()) + } +} + +func TestHandleDeleteRouterRemovesManagedRuntimeNode(t *testing.T) { + env := newAPITestEnv(t) + + if err := env.database.InsertRouter(&db.Router{ + Name: "router-a", + Addr: "192.168.1.1", + SSHPort: 22, + SSHUser: "root", + APIKey: "api-key", + LANSubnet: "192.168.1.0/24", + DaemonMode: db.RouterDaemonModeRouterDaemon, + LinkedNodeAddr: "managed-node", + Status: "active", + }); err != nil { + t.Fatalf("seed router: %v", err) + } + + if err := env.database.UpsertNode(&db.Node{ + Addr: "managed-node", + Hostname: "managed-node", + DaemonVersion: "opensnitchd-router", + Status: "offline", + LastConn: time.Now().Add(-2 * time.Minute).Format("2006-01-02 15:04:05"), + }); err != nil { + t.Fatalf("seed managed node: %v", err) + } + + rec := performJSONRequestWithAddr(t, env.api.handleDeleteRouter, http.MethodDelete, "/api/v1/routers/192.168.1.1", "192.168.1.1", nil) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + if _, err := env.database.GetRouterByAddr("192.168.1.1"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected router removed, got %v", err) + } + if _, err := env.database.GetNode("managed-node"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected managed node removed, got %v", err) + } +} diff --git a/internal/api/router.go b/internal/api/router.go index ae66c9e..546faf4 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -170,6 +170,7 @@ func NewRouter(cfg *config.Config, database *db.Database, nodes *nodemanager.Man r.Post("/api/v1/routers/suggest-url", api.handleSuggestServerURL) r.Post("/api/v1/routers/connect", api.handleConnectRouter) r.Get("/api/v1/routers", api.handleGetRouters) + r.Delete("/api/v1/routers/{addr}", api.handleDeleteRouter) r.Post("/api/v1/routers/{addr}/capabilities", api.handleRouterCapabilities) r.Post("/api/v1/routers/{addr}/upgrade", api.handleUpgradeRouter) r.Post("/api/v1/routers/{addr}/downgrade", api.handleDowngradeRouter) diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index f81851d..ade62b4 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -727,6 +727,10 @@ export const api = { body: JSON.stringify(params), }), getRouters: () => request("/routers"), + deleteRouter: (addr: string) => + request<{ status: string }>(`/routers/${encodeURIComponent(addr)}`, { + method: "DELETE", + }), disconnectRouter: (addr: string, sshPass: string) => request<{ status: string; steps: ProvisionStep[] }>( `/routers/${encodeURIComponent(addr)}/disconnect`, diff --git a/web/src/pages/nodes.tsx b/web/src/pages/nodes.tsx index 8c8b5eb..45214a2 100644 --- a/web/src/pages/nodes.tsx +++ b/web/src/pages/nodes.tsx @@ -382,6 +382,28 @@ export default function NodesPage() { } }; + const handleDeleteRouter = async (router: RouterRecord) => { + const label = router.name || router.addr; + if ( + !window.confirm( + `Delete offline router ${label} and its stored data? This removes the router locally without connecting over SSH.`, + ) + ) { + return; + } + + await withRouterBusy(router.addr, "deleting", async () => { + try { + await api.deleteRouter(router.addr); + showStatus(router.addr, "Router deleted"); + fetchPageData(true); + } catch (e) { + console.error("Router delete failed:", e); + showStatus(router.addr, e instanceof Error ? e.message : "Delete failed"); + } + }); + }; + const autoDetectSubnet = (ip: string) => { const parts = ip.split("."); if (parts.length === 4 && parts.every((p) => /^\d+$/.test(p))) { @@ -1014,7 +1036,23 @@ export default function NodesPage() { {busy === "disconnecting" ? : } Disconnect + {!online && ( + + )}
+ {!online && ( +
+ Delete removes the offline router entry locally when SSH disconnect is not possible. +
+ )}
{nodeControls && renderModeControls(runtimeNode)}