diff --git a/pkg/agent/state/state.go b/pkg/agent/state/state.go
index c3e8057e3..e1b3d2017 100644
--- a/pkg/agent/state/state.go
+++ b/pkg/agent/state/state.go
@@ -56,6 +56,35 @@ func Phases() []Phase {
 	return out
 }
 
+// Ordinal returns the forward-progress ordering of a phase.
+// Higher values represent later lifecycle stages.
+// Returns 0 for terminal or special phases (stopped, error, suspended, stopping)
+// where regression checks do not apply.
+func (p Phase) Ordinal() int {
+	switch p {
+	case PhaseCreated:
+		return 1
+	case PhaseProvisioning:
+		return 2
+	case PhaseCloning:
+		return 3
+	case PhaseStarting:
+		return 4
+	case PhaseRunning:
+		return 5
+	default:
+		return 0
+	}
+}
+
+// IsActivePhase reports whether this phase is part of the forward-progress
+// lifecycle (created through running). Regression guards apply only between
+// active phases — terminal phases (stopped, error) and special phases
+// (suspended, stopping) are excluded.
+func (p Phase) IsActivePhase() bool {
+	return p.Ordinal() > 0
+}
+
 // String implements fmt.Stringer.
 func (p Phase) String() string { return string(p) }
 
@@ -162,6 +191,18 @@ func (a Activity) IsTerminal() bool {
 	return false
 }
 
+// ImpliesRunning reports whether this activity implies the agent must be in
+// PhaseRunning. Used for auto-correcting a stale pre-running phase when the
+// agent is clearly active.
+func (a Activity) ImpliesRunning() bool {
+	switch a {
+	case ActivityWorking, ActivityThinking, ActivityExecuting,
+		ActivityWaitingForInput, ActivityBlocked, ActivityCompleted:
+		return true
+	}
+	return false
+}
+
 // IsPlatformSet reports whether this activity is set by the platform (scheduler)
 // rather than by the agent itself.
 func (a Activity) IsPlatformSet() bool {
diff --git a/pkg/agent/state/state_test.go b/pkg/agent/state/state_test.go
index 30b6fb00f..4fa581b57 100644
--- a/pkg/agent/state/state_test.go
+++ b/pkg/agent/state/state_test.go
@@ -481,6 +481,70 @@ func TestPhasesEnumeration(t *testing.T) {
 	}
 }
 
+func TestPhaseOrdinal(t *testing.T) {
+	tests := []struct {
+		phase   Phase
+		ordinal int
+	}{
+		{PhaseCreated, 1},
+		{PhaseProvisioning, 2},
+		{PhaseCloning, 3},
+		{PhaseStarting, 4},
+		{PhaseRunning, 5},
+		{PhaseSuspended, 0},
+		{PhaseStopping, 0},
+		{PhaseStopped, 0},
+		{PhaseError, 0},
+	}
+	for _, tt := range tests {
+		if got := tt.phase.Ordinal(); got != tt.ordinal {
+			t.Errorf("Phase(%q).Ordinal() = %d, want %d", tt.phase, got, tt.ordinal)
+		}
+	}
+
+	// Verify strict ordering for forward-progress phases.
+	forward := []Phase{PhaseCreated, PhaseProvisioning, PhaseCloning, PhaseStarting, PhaseRunning}
+	for i := 1; i < len(forward); i++ {
+		if forward[i].Ordinal() <= forward[i-1].Ordinal() {
+			t.Errorf("Ordinal(%q)=%d should be > Ordinal(%q)=%d",
+				forward[i], forward[i].Ordinal(), forward[i-1], forward[i-1].Ordinal())
+		}
+	}
+}
+
+func TestPhaseIsActivePhase(t *testing.T) {
+	active := []Phase{PhaseCreated, PhaseProvisioning, PhaseCloning, PhaseStarting, PhaseRunning}
+	for _, p := range active {
+		if !p.IsActivePhase() {
+			t.Errorf("Phase(%q).IsActivePhase() = false, want true", p)
+		}
+	}
+
+	notActive := []Phase{PhaseSuspended, PhaseStopping, PhaseStopped, PhaseError}
+	for _, p := range notActive {
+		if p.IsActivePhase() {
+			t.Errorf("Phase(%q).IsActivePhase() = true, want false", p)
+		}
+	}
+}
+
+func TestActivityImpliesRunning(t *testing.T) {
+	implies := []Activity{ActivityWorking, ActivityThinking, ActivityExecuting,
+		ActivityWaitingForInput, ActivityBlocked, ActivityCompleted}
+	for _, a := range implies {
+		if !a.ImpliesRunning() {
+			t.Errorf("Activity(%q).ImpliesRunning() = false, want true", a)
+		}
+	}
+
+	doesNotImply := []Activity{ActivityLimitsExceeded, ActivityStalled, ActivityOffline, ActivityCrashed}
+	for _, a := range doesNotImply {
+		if a.ImpliesRunning() {
+			t.Errorf("Activity(%q).ImpliesRunning() = true, want false", a)
+		}
+	}
+}
+
 func TestActivitiesEnumeration(t *testing.T) {
 	activities := Activities()
 	if len(activities) != 10 {
diff --git a/pkg/hub/handlers.go b/pkg/hub/handlers.go
index 4e1b45067..b4ad2f582 100644
--- a/pkg/hub/handlers.go
+++ b/pkg/hub/handlers.go
@@ -2847,6 +2847,16 @@ func (s *Server) updateAgentStatus(w http.ResponseWriter, r *http.Request, id st
 		return
 	}
 
+	// Guard against phase regressions and auto-correct phase from activity.
+	if status.Phase != "" || status.Activity != "" {
+		agent, err := s.store.GetAgent(ctx, id)
+		if err != nil {
+			writeErrorFromErr(w, err, "")
+			return
+		}
+		guardAgentPhaseTransition(agent, &status)
+	}
+
 	if err := s.store.UpdateAgentStatus(ctx, id, status); err != nil {
 		writeErrorFromErr(w, err, "")
 		return
@@ -2862,6 +2872,37 @@ func (s *Server) updateAgentStatus(w http.ResponseWriter, r *http.Request, id st
 	w.WriteHeader(http.StatusOK)
 }
 
+// guardAgentPhaseTransition applies two guards to a status update:
+//
+//  1. Phase regression guard: rejects transitions that would move an agent
+//     backward in its forward-progress lifecycle (e.g. running → starting).
+//  2. Activity-driven phase auto-correction: when an activity that implies the
+//     agent is running arrives but the phase is pre-running, auto-promotes the
+//     phase to running.
+func guardAgentPhaseTransition(agent *store.Agent, status *store.AgentStatusUpdate) {
+	currentPhase := state.Phase(agent.Phase)
+
+	// Guard 1: reject phase regressions within the forward-progress lifecycle.
+	if status.Phase != "" {
+		newPhase := state.Phase(status.Phase)
+		if currentPhase.IsActivePhase() && newPhase.IsActivePhase() &&
+			newPhase.Ordinal() < currentPhase.Ordinal() {
+			status.Phase = ""
+		}
+	}
+
+	// Guard 2: if an activity that implies running arrives with no phase
+	// (none was sent, or Guard 1 just dropped a regression) and the current
+	// phase is pre-running, auto-correct the phase to running.
+	if status.Activity != "" && status.Phase == "" {
+		activity := state.Activity(status.Activity)
+		if activity.ImpliesRunning() && currentPhase.IsActivePhase() &&
+			currentPhase != state.PhaseRunning {
+			status.Phase = string(state.PhaseRunning)
+		}
+	}
+}
+
 func (s *Server) handleAgentLifecycle(w http.ResponseWriter, r *http.Request, id, action string) {
 	ctx := r.Context()
 
@@ -6094,8 +6135,17 @@ func (s *Server) handleBrokerHeartbeat(w http.ResponseWriter, r *http.Request, i
 					statusUpdate.Message = agentHB.Message
 				}
 			} else {
-				// Structured path: broker sent Phase/Activity directly
-				statusUpdate.Phase = agentHB.Phase
+				// Structured path: broker sent Phase/Activity directly.
+				// Guard against phase regressions: stale heartbeat data
+				// must not move a running agent back to starting/etc.
+				// On a regression statusUpdate.Phase is not assigned here,
+				// so the hub keeps its current phase.
+				hbPhase := state.Phase(agentHB.Phase)
+				curPhase := state.Phase(agent.Phase)
+				if !(curPhase.IsActivePhase() && hbPhase.IsActivePhase() &&
+					hbPhase.Ordinal() < curPhase.Ordinal()) {
+					statusUpdate.Phase = agentHB.Phase
+				}
 				// Only propagate Activity when it differs from the stored
 				// value. Heartbeats always report the current activity, but
 				// repeating the same value would refresh last_activity_event
diff --git a/pkg/hub/handlers_agent_test.go b/pkg/hub/handlers_agent_test.go
index e445ac367..26e85529e 100644
--- a/pkg/hub/handlers_agent_test.go
+++ b/pkg/hub/handlers_agent_test.go
@@ -4643,3 +4643,128 @@ func TestHandleProjectAgentExec_DispatchesToRuntimeBroker(t *testing.T) {
 	assert.Equal(t, "terminal output", resp.Output)
 	assert.Equal(t, 0, resp.ExitCode)
 }
+
+func TestAgentStatusUpdate_RejectsPhaseRegression(t *testing.T) {
+	srv, s := testServer(t)
+	ctx := context.Background()
+
+	project := &store.Project{ID: "proj-regress", Name: "Regression Project", Slug: "regress-project"}
+	require.NoError(t, s.CreateProject(ctx, project))
+
+	agent := &store.Agent{
+		ID: "agent-regress", Slug: "regress-slug", Name: "Regression Agent",
+		ProjectID: project.ID, Phase: string(state.PhaseRunning),
+		Activity: string(state.ActivityExecuting),
+	}
+	require.NoError(t, s.CreateAgent(ctx, agent))
+
+	tokenSvc := srv.GetAgentTokenService()
+	require.NotNil(t, tokenSvc)
+	token, err := tokenSvc.GenerateAgentToken(agent.ID, project.ID, []AgentTokenScope{ScopeAgentStatusUpdate}, nil)
+	require.NoError(t, err)
+
+	// Attempt to regress phase from running → starting (as a spurious session would)
+	status := store.AgentStatusUpdate{Phase: string(state.PhaseStarting)}
+	body, err := json.Marshal(status)
+	require.NoError(t, err)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/"+agent.ID+"/status", bytes.NewReader(body))
+	req.Header.Set("X-Scion-Agent-Token", token)
+	req.Header.Set("Content-Type", "application/json")
+
+	rec := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(rec, req)
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	// Phase should remain running — regression was rejected
+	updated, err := s.GetAgent(ctx, agent.ID)
+	require.NoError(t, err)
+	assert.Equal(t, string(state.PhaseRunning), updated.Phase,
+		"phase regression from running to starting should be rejected")
+	assert.Equal(t, string(state.ActivityExecuting), updated.Activity,
+		"activity should be preserved when phase regression is rejected")
+}
+
+func TestAgentStatusUpdate_ActivityAutoCorrectsPhase(t *testing.T) {
+	srv, s := testServer(t)
+	ctx := context.Background()
+
+	project := &store.Project{ID: "proj-autocorrect", Name: "AutoCorrect Project", Slug: "autocorrect-project"}
+	require.NoError(t, s.CreateProject(ctx, project))
+
+	agent := &store.Agent{
+		ID: "agent-autocorrect", Slug: "autocorrect-slug", Name: "AutoCorrect Agent",
+		ProjectID: project.ID, Phase: string(state.PhaseStarting),
+	}
+	require.NoError(t, s.CreateAgent(ctx, agent))
+
+	tokenSvc := srv.GetAgentTokenService()
+	require.NotNil(t, tokenSvc)
+	token, err := tokenSvc.GenerateAgentToken(agent.ID, project.ID, []AgentTokenScope{ScopeAgentStatusUpdate}, nil)
+	require.NoError(t, err)
+
+	// Send an activity-only update (working) while phase is starting.
+	// This should auto-correct the phase to running.
+	status := store.AgentStatusUpdate{Activity: string(state.ActivityWorking)}
+	body, err := json.Marshal(status)
+	require.NoError(t, err)
+	req := httptest.NewRequest(http.MethodPost, "/api/v1/agents/"+agent.ID+"/status", bytes.NewReader(body))
+	req.Header.Set("X-Scion-Agent-Token", token)
+	req.Header.Set("Content-Type", "application/json")
+
+	rec := httptest.NewRecorder()
+	srv.Handler().ServeHTTP(rec, req)
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	// Phase should auto-correct to running
+	updated, err := s.GetAgent(ctx, agent.ID)
+	require.NoError(t, err)
+	assert.Equal(t, string(state.PhaseRunning), updated.Phase,
+		"activity=working should auto-correct phase from starting to running")
+	assert.Equal(t, string(state.ActivityWorking), updated.Activity)
+}
+
+func TestBrokerHeartbeat_RejectsPhaseRegression(t *testing.T) {
+	srv, s := testServer(t)
+	ctx := context.Background()
+
+	project := &store.Project{ID: "proj-hb-regress", Name: "HB Regression Project", Slug: "hb-regress-project"}
+	require.NoError(t, s.CreateProject(ctx, project))
+
+	broker := &store.RuntimeBroker{
+		ID: "broker-hb-regress", Name: "HB Regression Broker", Slug: "hb-regress-broker",
+		Status: store.BrokerStatusOnline,
+	}
+	require.NoError(t, s.CreateRuntimeBroker(ctx, broker))
+
+	agent := &store.Agent{
+		ID: "agent-hb-regress", Slug: "hb-regress-slug", Name: "HB Regression Agent",
+		ProjectID: project.ID, RuntimeBrokerID: broker.ID,
+		Phase:    string(state.PhaseRunning),
+		Activity: string(state.ActivityWorking),
+	}
+	require.NoError(t, s.CreateAgent(ctx, agent))
+
+	// Send a heartbeat with stale phase=starting (as if agent-info.json was
+	// corrupted by a spurious session's pre-start hook)
+	hb := brokerHeartbeatRequest{
+		Status: "online",
+		Projects: []brokerProjectHeartbeat{{
+			ProjectID:  project.ID,
+			AgentCount: 1,
+			Agents: []brokerAgentHeartbeat{{
+				Slug:            agent.Slug,
+				Phase:           string(state.PhaseStarting),
+				Activity:        string(state.ActivityWorking),
+				ContainerStatus: "Up 10 minutes",
+			}},
+		}},
+	}
+	rec := doRequest(t, srv, http.MethodPost, "/api/v1/runtime-brokers/"+broker.ID+"/heartbeat", hb)
+	assert.Equal(t, http.StatusOK, rec.Code)
+
+	// Phase should remain running — heartbeat regression was rejected
+	updated, err := s.GetAgent(ctx, agent.ID)
+	require.NoError(t, err)
+	assert.Equal(t, string(state.PhaseRunning), updated.Phase,
+		"heartbeat should not regress phase from running to starting")
+}