diff --git a/pkg/sciontool/metadata/server.go b/pkg/sciontool/metadata/server.go index 06f6acf53..b9f231985 100644 --- a/pkg/sciontool/metadata/server.go +++ b/pkg/sciontool/metadata/server.go @@ -20,6 +20,8 @@ package metadata import ( "bytes" "context" + "crypto/rand" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -27,6 +29,7 @@ import ( "net" "net/http" "os" + "path/filepath" "strings" "sync" "syscall" @@ -142,6 +145,9 @@ type Server struct { healthMu sync.Mutex restartCount int abandoned bool + + shutdownToken string + shutdownTokenPath string } // authToken returns the current auth token, preferring the dynamic TokenFunc @@ -200,10 +206,6 @@ func (s *Server) Start(ctx context.Context) error { s.cancel = cancel addr := fmt.Sprintf("127.0.0.1:%d", s.config.Port) - s.srv = &http.Server{ - Addr: addr, - Handler: s.buildMux(), - } ln, err := net.Listen("tcp", addr) if err != nil && errors.Is(err, syscall.EADDRINUSE) { @@ -241,6 +243,16 @@ func (s *Server) Start(ctx context.Context) error { return fmt.Errorf("metadata server listen: %w", err) } + if err := s.ensureShutdownToken(); err != nil { + cancel() + ln.Close() + return fmt.Errorf("metadata server shutdown token: %w", err) + } + s.srv = &http.Server{ + Addr: addr, + Handler: s.buildMux(), + } + // Track this server so a future Start() can forcefully close it. activeServerMu.Lock() activeServer = s @@ -352,6 +364,11 @@ func (s *Server) Stop() { s.srv.Shutdown(shutdownCtx) shutdownCancel() } + if s.shutdownTokenPath != "" { + if err := os.Remove(s.shutdownTokenPath); err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debug("Failed to remove metadata shutdown token file %s: %v", s.shutdownTokenPath, err) + } + } if s.cancel != nil { s.cancel() @@ -369,6 +386,12 @@ func (s *Server) shutdownExisting() { return } req.Header.Set("Metadata-Flavor", "Google") + token, err := os.ReadFile(shutdownTokenPath(s.config.Port)) + if err != nil { + log.Debug("Could not read metadata shutdown token for port %d: %v", s.config.Port, err) + return + } + req.Header.Set("X-Scion-Shutdown-Token", strings.TrimSpace(string(token))) resp, err := client.Do(req) if err != nil { log.Debug("Could not reach existing metadata server for shutdown: %v", err) @@ -385,6 +408,10 @@ func (s *Server) handleShutdown(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } + if s.shutdownToken == "" || r.Header.Get("X-Scion-Shutdown-Token") != s.shutdownToken { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } log.Info("Shutdown requested via /_scion/shutdown, stopping metadata server") w.WriteHeader(http.StatusOK) fmt.Fprint(w, "shutting down") @@ -394,6 +421,39 @@ func (s *Server) handleShutdown(w http.ResponseWriter, r *http.Request) { }() } +func shutdownTokenPath(port int) string { + return filepath.Join(os.TempDir(), fmt.Sprintf("scion-metadata-shutdown-%d.token", port)) +} + +func (s *Server) ensureShutdownToken() error { + if s.shutdownToken != "" { + return nil + } + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return err + } + s.shutdownToken = hex.EncodeToString(tokenBytes) + s.shutdownTokenPath = shutdownTokenPath(s.config.Port) + return writeShutdownToken(s.shutdownTokenPath, s.shutdownToken) +} + +func writeShutdownToken(path, token string) error { + if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL|syscall.O_NOFOLLOW, 0600) + if err != nil { + return err + } + defer f.Close() + if _, err := f.WriteString(token + "\n"); err != nil { + _ = os.Remove(path) + return err + } + return nil +} + func (s *Server) probeHealth() bool { client := &http.Client{Timeout: healthCheckTimeout} resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/", s.config.Port)) diff --git a/pkg/sciontool/metadata/server_test.go b/pkg/sciontool/metadata/server_test.go index a0deecf53..9318cec5b 100644 --- a/pkg/sciontool/metadata/server_test.go +++ b/pkg/sciontool/metadata/server_test.go @@ -22,6 +22,8 @@ import ( "net" "net/http" "net/http/httptest" + "os" + "strings" "sync" "sync/atomic" "testing" @@ -762,13 +764,31 @@ func TestMetadataServer_ShutdownEndpoint(t *testing.T) { t.Fatalf("expected 403 without Metadata-Flavor, got %d", resp.StatusCode) } - // POST with Metadata-Flavor should succeed and shut down + // POST with Metadata-Flavor but no shutdown token should be rejected req, _ = http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/_scion/shutdown", port), nil) req.Header.Set("Metadata-Flavor", "Google") resp, err = http.DefaultClient.Do(req) if err != nil { t.Fatal(err) } + resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("expected 403 without shutdown token, got %d", resp.StatusCode) + } + + token, err := os.ReadFile(shutdownTokenPath(port)) + if err != nil { + t.Fatal(err) + } + + // POST with Metadata-Flavor and shutdown token should succeed and shut down + req, _ = http.NewRequest(http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/_scion/shutdown", port), nil) + req.Header.Set("Metadata-Flavor", "Google") + req.Header.Set("X-Scion-Shutdown-Token", strings.TrimSpace(string(token))) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } body, _ := io.ReadAll(resp.Body) resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -833,3 +853,47 @@ func TestMetadataServer_StartReclaimsPort(t *testing.T) { t.Fatalf("expected new-project from replacement server, got %q", body) } } + +func TestMetadataServer_StartReclaimsPortViaShutdownEndpoint(t *testing.T) { + port := freePort(t) + + srv1 := New(Config{ + Mode: "block", + Port: port, + ProjectID: "old-project", + }) + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + + if err := srv1.Start(ctx1); err != nil { + t.Fatal(err) + } + defer srv1.Stop() + time.Sleep(50 * time.Millisecond) + + activeServerMu.Lock() + activeServer = nil + activeServerMu.Unlock() + + srv2 := New(Config{ + Mode: "block", + Port: port, + ProjectID: "new-project", + }) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + if err := srv2.Start(ctx2); err != nil { + t.Fatalf("second Start() should reclaim port via shutdown endpoint: %v", err) + } + defer srv2.Stop() + time.Sleep(50 * time.Millisecond) + + resp, body := metadataGet(t, port, "/computeMetadata/v1/project/project-id") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + if body != "new-project" { + t.Fatalf("expected new-project from replacement server, got %q", body) + } +}