diff --git a/common/config/config.go b/common/config/config.go
index 904a2f7eb8e..a9e510d8ad0 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -658,8 +658,9 @@ type (
 	}
 
 	LeaderProcess struct {
-		Period       time.Duration `yaml:"period"`
-		HeartbeatTTL time.Duration `yaml:"heartbeatTTL"`
+		Period        time.Duration `yaml:"period"`
+		HeartbeatTTL  time.Duration `yaml:"heartbeatTTL"`
+		ShardStatsTTL time.Duration `yaml:"shardStatsTTL"`
 	}
 )
diff --git a/common/log/tag/tags.go b/common/log/tag/tags.go
index b0b5f90d38e..6042d20e930 100644
--- a/common/log/tag/tags.go
+++ b/common/log/tag/tags.go
@@ -1185,9 +1185,18 @@ func ShardKey(shardKey string) Tag {
 	return newStringTag("shard-key", shardKey)
 }
 
+func AggregateLoad(load float64) Tag {
+	return newFloat64Tag("aggregated-load", load)
+}
+
+func AssignedCount(count int) Tag {
+	return newInt("assigned-count", count)
+}
+
 func ShardStatus(status string) Tag {
 	return newStringTag("shard-status", status)
 }
+
 func ShardLoad(load string) Tag {
 	return newStringTag("shard-load", load)
 }
diff --git a/config/development.yaml b/config/development.yaml
index d485c7ac4ba..1f34cb3d2a0 100644
--- a/config/development.yaml
+++ b/config/development.yaml
@@ -186,3 +186,4 @@ shardDistribution:
   process:
     period: 1s
     heartbeatTTL: 2s
+    shardStatsTTL: 60s
diff --git a/service/sharddistributor/config/config.go b/service/sharddistributor/config/config.go
index 79106ffc635..4ae80e489fc 100644
--- a/service/sharddistributor/config/config.go
+++ b/service/sharddistributor/config/config.go
@@ -79,8 +79,9 @@ type (
 	}
 
 	LeaderProcess struct {
-		Period       time.Duration `yaml:"period"`
-		HeartbeatTTL time.Duration `yaml:"heartbeatTTL"`
+		Period        time.Duration `yaml:"period"`
+		HeartbeatTTL  time.Duration `yaml:"heartbeatTTL"`
+		ShardStatsTTL time.Duration `yaml:"shardStatsTTL"`
 	}
 )
 
@@ -97,6 +98,10 @@ const (
 	MigrationModeONBOARDED = "onboarded"
 )
 
+const (
+	DefaultShardStatsTTL = time.Minute
+)
+
 // ConfigMode maps string migration mode values to types.MigrationMode
 var ConfigMode = map[string]types.MigrationMode{
 	MigrationModeINVALID: types.MigrationModeINVALID,
diff --git a/service/sharddistributor/handler/handler.go b/service/sharddistributor/handler/handler.go
index 08fe6b1680c..ab053160463 100644
--- a/service/sharddistributor/handler/handler.go
+++ b/service/sharddistributor/handler/handler.go
@@ -31,6 +31,7 @@ import (
 	"sync"
 
 	"github.com/uber/cadence/common/log"
+	"github.com/uber/cadence/common/log/tag"
 	"github.com/uber/cadence/common/types"
 	"github.com/uber/cadence/service/sharddistributor/config"
 	"github.com/uber/cadence/service/sharddistributor/store"
@@ -113,34 +114,99 @@ func (h *handlerImpl) GetShardOwner(ctx context.Context, request *types.GetShard
 
 func (h *handlerImpl) assignEphemeralShard(ctx context.Context, namespace string, shardID string) (*types.GetShardOwnerResponse, error) {
-	// Get the current state of the namespace and find the executor with the least assigned shards
+	// Get the current state of the namespace and evaluate executor load to choose a placement target.
 	state, err := h.storage.GetState(ctx, namespace)
 	if err != nil {
 		return nil, fmt.Errorf("get state: %w", err)
 	}
 
-	var executor string
-	minAssignedShards := math.MaxInt
-
-	for assignedExecutor, assignment := range state.ShardAssignments {
-		if len(assignment.AssignedShards) < minAssignedShards {
-			minAssignedShards = len(assignment.AssignedShards)
-			executor = assignedExecutor
-		}
+	executorID, aggregatedLoad, assignedCount, err := pickLeastLoadedExecutor(state)
+	if err != nil {
+		h.logger.Error(
+			"no eligible executor found for ephemeral assignment",
+			tag.ShardNamespace(namespace),
+			tag.ShardKey(shardID),
+			tag.Error(err),
+		)
+		return nil, err
 	}
 
+	h.logger.Info(
+		"selected executor for ephemeral shard assignment",
+		tag.AggregateLoad(aggregatedLoad),
+		tag.AssignedCount(assignedCount),
+		tag.ShardNamespace(namespace),
+		tag.ShardKey(shardID),
+		tag.ShardExecutor(executorID),
+	)
+
-	// Assign the shard to the executor with the least assigned shards
-	err = h.storage.AssignShard(ctx, namespace, shardID, executor)
+	// Assign the shard to the selected least-loaded executor
+	err = h.storage.AssignShard(ctx, namespace, shardID, executorID)
 	if err != nil {
+		h.logger.Error(
+			"failed to assign ephemeral shard",
+			tag.ShardNamespace(namespace),
+			tag.ShardKey(shardID),
+			tag.ShardExecutor(executorID),
+			tag.Error(err),
+		)
 		return nil, fmt.Errorf("assign ephemeral shard: %w", err)
 	}
 
 	return &types.GetShardOwnerResponse{
-		Owner:     executor,
+		Owner:     executorID,
 		Namespace: namespace,
 	}, nil
 }
 
+// pickLeastLoadedExecutor returns the ACTIVE executor with the minimal aggregated smoothed load.
+// Ties are broken by fewer assigned shards.
+func pickLeastLoadedExecutor(state *store.NamespaceState) (executorID string, aggregatedLoad float64, assignedCount int, err error) {
+	if state == nil {
+		return "", 0, 0, fmt.Errorf("namespace state is nil")
+	}
+	if len(state.ShardAssignments) == 0 {
+		return "", 0, 0, fmt.Errorf("namespace state has no executors")
+	}
+
+	var chosenID string
+	var chosenAggregatedLoad float64
+	var chosenAssignedCount int
+	minAggregatedLoad := math.MaxFloat64
+	minAssignedShards := math.MaxInt
+
+	for candidate, assignment := range state.ShardAssignments {
+		executorState, ok := state.Executors[candidate]
+		if !ok || executorState.Status != types.ExecutorStatusACTIVE {
+			continue
+		}
+
+		aggregated := 0.0
+		for shard := range assignment.AssignedShards {
+			if stats, ok := state.ShardStats[shard]; ok {
+				if !math.IsNaN(stats.SmoothedLoad) && !math.IsInf(stats.SmoothedLoad, 0) {
+					aggregated += stats.SmoothedLoad
+				}
+			}
+		}
+
+		count := len(assignment.AssignedShards)
+		if aggregated < minAggregatedLoad || (aggregated == minAggregatedLoad && count < minAssignedShards) {
+			minAggregatedLoad = aggregated
+			minAssignedShards = count
+			chosenID = candidate
+			chosenAggregatedLoad = aggregated
+			chosenAssignedCount = count
+		}
+	}
+
+	if chosenID == "" {
+		return "", 0, 0, fmt.Errorf("no active executors available")
+	}
+
+	return chosenID, chosenAggregatedLoad, chosenAssignedCount, nil
+}
+
 func (h *handlerImpl) WatchNamespaceState(request *types.WatchNamespaceStateRequest, server WatchNamespaceStateServer) error {
 	h.startWG.Wait()
diff --git a/service/sharddistributor/handler/handler_test.go b/service/sharddistributor/handler/handler_test.go
index a683168a8be..a5ca98c6969 100644
--- a/service/sharddistributor/handler/handler_test.go
+++ b/service/sharddistributor/handler/handler_test.go
@@ -138,6 +138,10 @@ func TestGetShardOwner(t *testing.T) {
 			setupMocks: func(mockStore *store.MockStore) {
 				mockStore.EXPECT().GetShardOwner(gomock.Any(), _testNamespaceEphemeral, "NON-EXISTING-SHARD").Return(nil, store.ErrShardNotFound)
 				mockStore.EXPECT().GetState(gomock.Any(), _testNamespaceEphemeral).Return(&store.NamespaceState{
+					Executors: map[string]store.HeartbeatState{
+						"owner1": {Status: types.ExecutorStatusACTIVE},
+						"owner2": {Status: types.ExecutorStatusACTIVE},
+					},
 					ShardAssignments: map[string]store.AssignedState{
 						"owner1": {
 							AssignedShards: map[string]*types.ShardAssignment{
@@ -181,12 +185,82 @@ func TestGetShardOwner(t *testing.T) {
 			setupMocks: func(mockStore *store.MockStore) {
 				mockStore.EXPECT().GetShardOwner(gomock.Any(), _testNamespaceEphemeral, "NON-EXISTING-SHARD").Return(nil, store.ErrShardNotFound)
 				mockStore.EXPECT().GetState(gomock.Any(), _testNamespaceEphemeral).Return(&store.NamespaceState{
-					ShardAssignments: map[string]store.AssignedState{"owner1": {AssignedShards: map[string]*types.ShardAssignment{}}}}, nil)
+					Executors: map[string]store.HeartbeatState{
+						"owner1": {Status: types.ExecutorStatusACTIVE},
+					},
+					ShardAssignments: map[string]store.AssignedState{
+						"owner1": {AssignedShards: map[string]*types.ShardAssignment{}},
+					},
+				}, nil)
 				mockStore.EXPECT().AssignShard(gomock.Any(), _testNamespaceEphemeral, "NON-EXISTING-SHARD", "owner1").Return(errors.New("assign shard failure"))
 			},
 			expectedError:  true,
 			expectedErrMsg: "assign shard failure",
 		},
+		{
+			name: "ShardNotFound_Ephemeral_LoadBased",
+			request: &types.GetShardOwnerRequest{
+				Namespace: _testNamespaceEphemeral,
+				ShardKey:  "new-shard",
+			},
+			setupMocks: func(mockStore *store.MockStore) {
+				mockStore.EXPECT().GetShardOwner(gomock.Any(), _testNamespaceEphemeral, "new-shard").Return(nil, store.ErrShardNotFound)
+				mockStore.EXPECT().GetState(gomock.Any(), _testNamespaceEphemeral).Return(&store.NamespaceState{
+					Executors: map[string]store.HeartbeatState{
+						"owner1": {Status: types.ExecutorStatusACTIVE},
+						"owner2": {Status: types.ExecutorStatusACTIVE},
+					},
+					ShardAssignments: map[string]store.AssignedState{
+						"owner1": {
+							AssignedShards: map[string]*types.ShardAssignment{
+								"shard1": {Status: types.AssignmentStatusREADY},
+								"shard2": {Status: types.AssignmentStatusREADY},
+							},
+						},
+						"owner2": {
+							AssignedShards: map[string]*types.ShardAssignment{
+								"shard3": {Status: types.AssignmentStatusREADY},
+							},
+						},
+					},
+					ShardStats: map[string]store.ShardStatistics{
+						"shard1": {SmoothedLoad: 2.5},
+						"shard2": {SmoothedLoad: 1.0},
+						"shard3": {SmoothedLoad: 0.5},
+					},
+				}, nil)
+				// owner1 total load: 2.5 + 1.0 = 3.5
+				// owner2 total load: 0.5
+				// Should pick owner2 (least loaded)
+				mockStore.EXPECT().AssignShard(gomock.Any(), _testNamespaceEphemeral, "new-shard", "owner2").Return(nil)
+			},
+			expectedOwner: "owner2",
+			expectedError: false,
+		},
+		{
+			name: "ShardNotFound_Ephemeral_AllExecutorsDraining",
+			request: &types.GetShardOwnerRequest{
+				Namespace: _testNamespaceEphemeral,
+				ShardKey:  "new-shard",
+			},
+			setupMocks: func(mockStore *store.MockStore) {
+				mockStore.EXPECT().GetShardOwner(gomock.Any(), _testNamespaceEphemeral, "new-shard").Return(nil, store.ErrShardNotFound)
+				mockStore.EXPECT().GetState(gomock.Any(), _testNamespaceEphemeral).Return(&store.NamespaceState{
+					Executors: map[string]store.HeartbeatState{
+						"owner1": {Status: types.ExecutorStatusDRAINING},
+					},
+					ShardAssignments: map[string]store.AssignedState{
+						"owner1": {
+							AssignedShards: map[string]*types.ShardAssignment{
+								"shard1": {Status: types.AssignmentStatusREADY},
+							},
+						},
+					},
+				}, nil)
+			},
+			expectedError:  true,
+			expectedErrMsg: "no active executors",
+		},
 	}
 
 	for _, tt := range tests {
@@ -218,6 +292,128 @@ func TestGetShardOwner(t *testing.T) {
 	}
 }
 
+func TestPickLeastLoadedExecutor(t *testing.T) {
+	tests := []struct {
+		name          string
+		state         *store.NamespaceState
+		expectedOwner string
+		expectedLoad  float64
+		expectedCount int
+		expectedError bool
+	}{
+		{
+			name: "SelectsLeastLoaded",
+			state: &store.NamespaceState{
+				Executors: map[string]store.HeartbeatState{
+					"exec1": {Status: types.ExecutorStatusACTIVE},
+					"exec2": {Status: types.ExecutorStatusACTIVE},
+				},
+				ShardAssignments: map[string]store.AssignedState{
+					"exec1": {
+						AssignedShards: map[string]*types.ShardAssignment{
+							"shard1": {},
+							"shard2": {},
+							"shard3": {},
+							"shard4": {},
+						},
+					},
+					"exec2": {
+						AssignedShards: map[string]*types.ShardAssignment{
+							"shard4": {},
+							"shard5": {},
+						},
+					},
+				},
+				ShardStats: map[string]store.ShardStatistics{
+					"shard1": {SmoothedLoad: 1.0},
+					"shard2": {SmoothedLoad: 2.0},
+					"shard3": {SmoothedLoad: 0.5},
+					"shard4": {SmoothedLoad: 0.25},
+					"shard5": {SmoothedLoad: 1.0},
+				},
+			},
+			expectedOwner: "exec2",
+			expectedLoad:  1.25,
+			expectedCount: 2,
+			expectedError: false,
+		},
+		{
+			name: "SkipsNonActiveExecutors",
+			state: &store.NamespaceState{
+				Executors: map[string]store.HeartbeatState{
+					"exec1": {Status: types.ExecutorStatusDRAINING},
+					"exec2": {Status: types.ExecutorStatusACTIVE},
+				},
+				ShardAssignments: map[string]store.AssignedState{
+					"exec1": {
+						AssignedShards: map[string]*types.ShardAssignment{
+							"shard1": {},
+						},
+					},
+					"exec2": {
+						AssignedShards: map[string]*types.ShardAssignment{
+							"shard2": {},
+							"shard3": {},
+						},
+					},
+				},
+				ShardStats: map[string]store.ShardStatistics{
+					"shard1": {SmoothedLoad: 0.1},
+					"shard2": {SmoothedLoad: 1.0},
+					"shard3": {SmoothedLoad: 2.0},
+				},
+			},
+			expectedOwner: "exec2",
+			expectedLoad:  3.0,
+			expectedCount: 2,
+			expectedError: false,
+		},
+		{
+			name:          "SelectsLeastLoaded_NoExecutors",
+			state:         &store.NamespaceState{},
+			expectedError: true,
+		},
+		{
+			name: "SelectsLeastLoaded_NoActiveExecutors",
+			state: &store.NamespaceState{
+				Executors: map[string]store.HeartbeatState{
+					"exec1": {Status: types.ExecutorStatusDRAINING},
+				},
+				ShardAssignments: map[string]store.AssignedState{
+					"exec1": {
+						AssignedShards: map[string]*types.ShardAssignment{
+							"shard1": {},
+						},
+					},
+				},
+			},
+			expectedError: true,
+		},
+		{
+			name: "SelectsLeastLoaded_NoShards",
+			state: &store.NamespaceState{
+				ShardAssignments: map[string]store.AssignedState{},
+			},
+			expectedError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			owner, load, count, err := pickLeastLoadedExecutor(tt.state)
+			if tt.expectedError {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				require.Equal(t, tt.expectedOwner, owner)
+				require.Equal(t, tt.expectedLoad, load)
+				require.Equal(t, tt.expectedCount, count)
+
+			}
+		})
+	}
+}
+
 func TestWatchNamespaceState(t *testing.T) {
 	ctrl := gomock.NewController(t)
 	logger := testlogger.New(t)
diff --git a/service/sharddistributor/leader/process/processor.go b/service/sharddistributor/leader/process/processor.go
index 901e6822ef7..9179357f07c 100644
--- a/service/sharddistributor/leader/process/processor.go
+++ b/service/sharddistributor/leader/process/processor.go
@@ -76,12 +76,15 @@ func NewProcessorFactory(
 	timeSource clock.TimeSource,
 	cfg config.ShardDistribution,
 ) Factory {
-	if cfg.Process.Period == 0 {
+	if cfg.Process.Period <= 0 {
 		cfg.Process.Period = _defaultPeriod
 	}
-	if cfg.Process.HeartbeatTTL == 0 {
+	if cfg.Process.HeartbeatTTL <= 0 {
 		cfg.Process.HeartbeatTTL = _defaultHearbeatTTL
 	}
+	if cfg.Process.ShardStatsTTL <= 0 {
+		cfg.Process.ShardStatsTTL = config.DefaultShardStatsTTL
+	}
 
 	return &processorFactory{
 		logger:     logger,
@@ -237,7 +240,7 @@ func (p *namespaceProcessor) runShardStatsCleanupLoop(ctx context.Context) {
 			continue
 		}
 		staleShardStats := p.identifyStaleShardStats(namespaceState)
-		if len(staleShardStats) > 0 {
+		if len(staleShardStats) == 0 {
 			// No stale shard stats to delete
 			continue
 		}
@@ -267,7 +270,7 @@ func (p *namespaceProcessor) identifyStaleExecutors(namespaceState *store.Namesp
 func (p *namespaceProcessor) identifyStaleShardStats(namespaceState *store.NamespaceState) []string {
 	activeShards := make(map[string]struct{})
 	now := p.timeSource.Now().Unix()
-	shardStatsTTL := int64(p.cfg.HeartbeatTTL.Seconds())
+	shardStatsTTL := int64(p.cfg.ShardStatsTTL.Seconds())
 
 	// 1. build set of active executors
diff --git a/service/sharddistributor/leader/process/processor_test.go b/service/sharddistributor/leader/process/processor_test.go
index 4027e9a496b..206c0ebf3c3 100644
--- a/service/sharddistributor/leader/process/processor_test.go
+++ b/service/sharddistributor/leader/process/processor_test.go
@@ -44,8 +44,9 @@ func setupProcessorTest(t *testing.T, namespaceType string) *testDependencies {
 		mockedClock,
 		config.ShardDistribution{
 			Process: config.LeaderProcess{
-				Period:       time.Second,
-				HeartbeatTTL: time.Second,
+				Period:        time.Second,
+				HeartbeatTTL:  time.Second,
+				ShardStatsTTL: 10 * time.Second,
 			},
 		},
 	),
@@ -259,7 +260,7 @@ func TestCleanupStaleShardStats(t *testing.T) {
 	shardStats := map[string]store.ShardStatistics{
 		"shard-1": {SmoothedLoad: 1.0, LastUpdateTime: now.Unix(), LastMoveTime: now.Unix()},
 		"shard-2": {SmoothedLoad: 2.0, LastUpdateTime: now.Unix(), LastMoveTime: now.Unix()},
-		"shard-3": {SmoothedLoad: 3.0, LastUpdateTime: now.Add(-2 * time.Second).Unix(), LastMoveTime: now.Add(-2 * time.Second).Unix()},
+		"shard-3": {SmoothedLoad: 3.0, LastUpdateTime: now.Add(-11 * time.Second).Unix(), LastMoveTime: now.Add(-11 * time.Second).Unix()},
 	}
 
 	namespaceState := &store.NamespaceState{
diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore.go b/service/sharddistributor/store/etcd/executorstore/etcdstore.go
index 9bd2b465d61..f509801ca16 100644
--- a/service/sharddistributor/store/etcd/executorstore/etcdstore.go
+++ b/service/sharddistributor/store/etcd/executorstore/etcdstore.go
@@ -8,6 +8,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"math"
 	"strconv"
 	"time"
 
@@ -35,8 +36,15 @@ type executorStoreImpl struct {
 	logger     log.Logger
 	shardCache *shardcache.ShardToExecutorCache
 	timeSource clock.TimeSource
+	// Max interval (seconds) before we force a shard-stat persist.
+	maxStatsPersistIntervalSeconds int64
 }
 
+// Constants for gating shard statistics writes to reduce etcd load.
+const (
+	shardStatsEpsilon = 0.05
+)
+
 // shardStatisticsUpdate holds the staged statistics for a shard so we can write them
 // to etcd after the main AssignShards transaction commits.
 type shardStatisticsUpdate struct {
@@ -88,12 +96,18 @@ func NewStore(p ExecutorStoreParams) (store.Store, error) {
 		timeSource = clock.NewRealTimeSource()
 	}
 
+	shardStatsTTL := p.Cfg.Process.ShardStatsTTL
+	if shardStatsTTL <= 0 {
+		shardStatsTTL = config.DefaultShardStatsTTL
+	}
+
 	store := &executorStoreImpl{
-		client:     etcdClient,
-		prefix:     etcdCfg.Prefix,
-		logger:     p.Logger,
-		shardCache: shardCache,
-		timeSource: timeSource,
+		client:                         etcdClient,
+		prefix:                         etcdCfg.Prefix,
+		logger:                         p.Logger,
+		shardCache:                     shardCache,
+		timeSource:                     timeSource,
+		maxStatsPersistIntervalSeconds: deriveStatsPersistInterval(shardStatsTTL),
 	}
 
 	p.Lifecycle.Append(fx.StartStopHook(store.Start, store.Stop))
@@ -153,9 +167,137 @@ func (s *executorStoreImpl) RecordHeartbeat(ctx context.Context, namespace, exec
 	if err != nil {
 		return fmt.Errorf("record heartbeat: %w", err)
 	}
+
+	s.recordShardStatistics(ctx, namespace, executorID, request.ReportedShards)
+
 	return nil
 }
 
+func deriveStatsPersistInterval(shardStatsTTL time.Duration) int64 {
+	ttlSeconds := int64(shardStatsTTL.Seconds())
+	if ttlSeconds <= 1 {
+		return 1
+	}
+	return ttlSeconds - 1
+}
+
+func (s *executorStoreImpl) recordShardStatistics(ctx context.Context, namespace, executorID string, reported map[string]*types.ShardStatusReport) {
+	if len(reported) == 0 {
+		return
+	}
+
+	now := s.timeSource.Now().Unix()
+
+	for shardID, report := range reported {
+		if report == nil {
+			s.logger.Warn("nil shard report; skipping EWMA update",
+				tag.ShardNamespace(namespace),
+				tag.ShardExecutor(executorID),
+				tag.ShardKey(shardID),
+			)
+			continue
+		}
+
+		load := report.ShardLoad
+		if math.IsNaN(load) || math.IsInf(load, 0) {
+			s.logger.Warn(
+				"invalid shard load reported; skipping EWMA update",
+				tag.ShardNamespace(namespace),
+				tag.ShardExecutor(executorID),
+				tag.ShardKey(shardID),
+			)
+			continue
+		}
+
+		shardStatsKey, err := etcdkeys.BuildShardKey(s.prefix, namespace, shardID, etcdkeys.ShardStatisticsKey)
+		if err != nil {
+			s.logger.Warn(
+				"failed to build shard statistics key from heartbeat",
+				tag.ShardNamespace(namespace),
+				tag.ShardExecutor(executorID),
+				tag.ShardKey(shardID),
+				tag.Error(err),
+			)
+			continue
+		}
+
+		statsResp, err := s.client.Get(ctx, shardStatsKey)
+		if err != nil {
+			s.logger.Warn(
+				"failed to read shard statistics for heartbeat update",
+				tag.ShardNamespace(namespace),
+				tag.ShardExecutor(executorID),
+				tag.ShardKey(shardID),
+				tag.Error(err),
+			)
+			continue
+		}
+
+		var stats store.ShardStatistics
+		if len(statsResp.Kvs) > 0 {
+			err := common.DecompressAndUnmarshal(statsResp.Kvs[0].Value, &stats)
+			if err != nil {
+				s.logger.Warn(
+					"failed to unmarshal shard statistics for heartbeat update",
+					tag.ShardNamespace(namespace),
+					tag.ShardExecutor(executorID),
+					tag.ShardKey(shardID),
+					tag.Error(err),
+				)
+				continue
+			}
+		}
+
+		// Update smoothed load via EWMA.
+		prevSmoothed := stats.SmoothedLoad
+		prevUpdate := stats.LastUpdateTime
+		newSmoothed := ewmaSmoothedLoad(prevSmoothed, load, prevUpdate, now)
+
+		// Decide whether to persist this update. We always persist if this is the
+		// first observation (prevUpdate == 0). Otherwise, if the change is small
+		// and the previous persist is recent, skip the write to reduce etcd load.
+		shouldPersist := true
+		if prevUpdate > 0 {
+			age := now - prevUpdate
+			delta := math.Abs(newSmoothed - prevSmoothed)
+			if delta < shardStatsEpsilon && age < s.maxStatsPersistIntervalSeconds {
+				shouldPersist = false
+			}
+		}
+
+		if !shouldPersist {
+			// Skip persisting, proceed to next shard.
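+			// Note: the freshly computed EWMA value is intentionally dropped here; the
+			// next heartbeat re-smooths from the last persisted baseline with a larger
+			// dt (hence a larger alpha), so a skipped write only delays convergence of
+			// the stored value and does not lose load information.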
+			continue
+		}
+
+		stats.SmoothedLoad = newSmoothed
+		stats.LastUpdateTime = now
+
+		payload, err := json.Marshal(stats)
+		if err != nil {
+			s.logger.Warn(
+				"failed to marshal shard statistics after heartbeat",
+				tag.ShardNamespace(namespace),
+				tag.ShardExecutor(executorID),
+				tag.ShardKey(shardID),
+				tag.Error(err),
+			)
+			continue
+		}
+
+		_, err = s.client.Put(ctx, shardStatsKey, string(payload))
+		if err != nil {
+			s.logger.Warn(
+				"failed to persist shard statistics from heartbeat",
+				tag.ShardNamespace(namespace),
+				tag.ShardExecutor(executorID),
+				tag.ShardKey(shardID),
+				tag.Error(err),
+			)
+		}
+	}
+}
+
 // GetHeartbeat retrieves the last known heartbeat state for a single executor.
 func (s *executorStoreImpl) GetHeartbeat(ctx context.Context, namespace string, executorID string) (*store.HeartbeatState, *store.AssignedState, error) {
 	// The prefix for all keys related to a single executor.
@@ -741,3 +883,16 @@ func (s *executorStoreImpl) applyShardStatisticsUpdates(ctx context.Context, nam
 		}
 	}
 }
+
+func ewmaSmoothedLoad(prev, current float64, lastUpdate, now int64) float64 {
+	const tauSeconds = 30.0 // smaller = more responsive, larger = smoother
+	if lastUpdate <= 0 || tauSeconds <= 0 {
+		return current
+	}
+	dt := max(now-lastUpdate, 0)
+	alpha := 1 - math.Exp(-float64(dt)/tauSeconds)
+	if math.IsNaN(prev) || math.IsInf(prev, 0) {
+		return current
+	}
+	return (1-alpha)*prev + alpha*current
+}
diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go b/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go
index a47565768bd..9d6d5c7f7e3 100644
--- a/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go
+++ b/service/sharddistributor/store/etcd/executorstore/etcdstore_test.go
@@ -11,6 +11,7 @@ import (
 	"github.com/stretchr/testify/require"
 	"go.uber.org/fx/fxtest"
 
+	"github.com/uber/cadence/common/clock"
 	"github.com/uber/cadence/common/log/testlogger"
 	"github.com/uber/cadence/common/types"
 	"github.com/uber/cadence/service/sharddistributor/store"
@@ -91,6 +92,154 @@ func TestRecordHeartbeat(t *testing.T) {
 	assert.Equal(t, "value-2", string(resp.Kvs[0].Value))
 }
 
+func TestRecordHeartbeatUpdatesShardStatistics(t *testing.T) {
+	tc := testhelper.SetupStoreTestCluster(t)
+	executorStore := createStore(t, tc)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	executorID := "executor-shard-stats"
+	shardID := "shard-with-load"
+
+	initialStats := store.ShardStatistics{
+		SmoothedLoad:   1.23,
+		LastUpdateTime: 10,
+		LastMoveTime:   123,
+	}
+
+	shardStatsKey, err := etcdkeys.BuildShardKey(tc.EtcdPrefix, tc.Namespace, shardID, etcdkeys.ShardStatisticsKey)
+	require.NoError(t, err)
+	payload, err := json.Marshal(initialStats)
+	require.NoError(t, err)
+	_, err = tc.Client.Put(ctx, shardStatsKey, string(payload))
+	require.NoError(t, err)
+
+	nowTS := time.Now().Unix()
+
+	req := store.HeartbeatState{
+		LastHeartbeat: nowTS,
+		Status:        types.ExecutorStatusACTIVE,
+		ReportedShards: map[string]*types.ShardStatusReport{
+			shardID: {
+				Status:    types.ShardStatusREADY,
+				ShardLoad: 45.6,
+			},
+		},
+	}
+
+	require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, req))
+
+	nsState, err := executorStore.GetState(ctx, tc.Namespace)
+	require.NoError(t, err)
+
+	require.Contains(t, nsState.ShardStats, shardID)
+	updated := nsState.ShardStats[shardID]
+
+	assert.InDelta(t, 45.6, updated.SmoothedLoad, 1e-9)
+	assert.GreaterOrEqual(t, updated.LastUpdateTime, nowTS)
+	assert.Equal(t, initialStats.LastMoveTime, updated.LastMoveTime)
+}
+
+func TestRecordHeartbeatSkipsShardStatisticsWithNilReport(t *testing.T) {
+	tc := testhelper.SetupStoreTestCluster(t)
+	executorStore := createStore(t, tc)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	executorID := "executor-missing-load"
+	validShardID := "shard-with-valid-load"
+	skippedShardID := "shard-missing-load"
+
+	nowTS := time.Now().Unix()
+
+	req := store.HeartbeatState{
+		LastHeartbeat: nowTS,
+		Status:        types.ExecutorStatusACTIVE,
+		ReportedShards: map[string]*types.ShardStatusReport{
+			validShardID: {
+				Status:    types.ShardStatusREADY,
+				ShardLoad: 3.21,
+			},
+			skippedShardID: nil,
+		},
+	}
+
+	require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, req))
+
+	nsState, err := executorStore.GetState(ctx, tc.Namespace)
+	require.NoError(t, err)
+
+	require.Contains(t, nsState.ShardStats, validShardID)
+	validStats := nsState.ShardStats[validShardID]
+	assert.InDelta(t, 3.21, validStats.SmoothedLoad, 1e-9)
+	assert.Greater(t, validStats.LastUpdateTime, int64(0))
+
+	assert.NotContains(t, nsState.ShardStats, skippedShardID)
+}
+
+func TestRecordHeartbeatShardStatisticsThrottlesWrites(t *testing.T) {
+	tc := testhelper.SetupStoreTestCluster(t)
+	tc.LeaderCfg.Process.HeartbeatTTL = 10 * time.Second
+	tc.LeaderCfg.Process.ShardStatsTTL = 10 * time.Second
+	mockTS := clock.NewMockedTimeSourceAt(time.Unix(1000, 0))
+	executorStore := createStoreWithTimeSource(t, tc, mockTS)
+	esImpl, ok := executorStore.(*executorStoreImpl)
+	require.True(t, ok, "unexpected store implementation")
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	executorID := "executor-shard-stats-throttle"
+	shardID := "shard-stats-throttle"
+
+	baseLoad := 0.40
+	smallDelta := shardStatsEpsilon / 2
+	intervalSeconds := esImpl.maxStatsPersistIntervalSeconds
+	halfIntervalSeconds := intervalSeconds / 2
+	if halfIntervalSeconds == 0 {
+		halfIntervalSeconds = 1
+	}
+
+	// First heartbeat should always persist stats.
+	require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, store.HeartbeatState{
+		LastHeartbeat: mockTS.Now().Unix(),
+		Status:        types.ExecutorStatusACTIVE,
+		ReportedShards: map[string]*types.ShardStatusReport{
+			shardID: {Status: types.ShardStatusREADY, ShardLoad: baseLoad},
+		},
+	}))
+	statsAfterFirst := getShardStats(ctx, t, executorStore, tc.Namespace, shardID)
+	require.NotNil(t, statsAfterFirst)
+
+	// Advance time by less than the persist interval and provide a small delta: should skip the write.
+	mockTS.Advance(time.Duration(halfIntervalSeconds) * time.Second)
+	require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, store.HeartbeatState{
+		LastHeartbeat: mockTS.Now().Unix(),
+		Status:        types.ExecutorStatusACTIVE,
+		ReportedShards: map[string]*types.ShardStatusReport{
+			shardID: {Status: types.ShardStatusREADY, ShardLoad: baseLoad + smallDelta},
+		},
+	}))
+	statsAfterSkip := getShardStats(ctx, t, executorStore, tc.Namespace, shardID)
+	require.NotNil(t, statsAfterSkip)
+	assert.Equal(t, statsAfterFirst.LastUpdateTime, statsAfterSkip.LastUpdateTime, "small recent deltas should not trigger a persist")
+
+	// Advance time beyond the max persist interval; even small deltas should now persist.
+	mockTS.Advance(time.Duration(intervalSeconds) * time.Second)
+	require.NoError(t, executorStore.RecordHeartbeat(ctx, tc.Namespace, executorID, store.HeartbeatState{
+		LastHeartbeat: mockTS.Now().Unix(),
+		Status:        types.ExecutorStatusACTIVE,
+		ReportedShards: map[string]*types.ShardStatusReport{
+			shardID: {Status: types.ShardStatusREADY, ShardLoad: baseLoad + smallDelta/2},
+		},
+	}))
+	statsAfterForce := getShardStats(ctx, t, executorStore, tc.Namespace, shardID)
+	require.NotNil(t, statsAfterForce)
+	assert.Greater(t, statsAfterForce.LastUpdateTime, statsAfterSkip.LastUpdateTime, "stale stats must be refreshed even if delta is small")
+}
+
 func TestGetHeartbeat(t *testing.T) {
 	tc := testhelper.SetupStoreTestCluster(t)
 	executorStore := createStore(t, tc)
@@ -608,3 +757,27 @@ func createStore(t *testing.T, tc *testhelper.StoreTestCluster) store.Store {
 	require.NoError(t, err)
 	return store
 }
+
+func createStoreWithTimeSource(t *testing.T, tc *testhelper.StoreTestCluster, ts clock.TimeSource) store.Store {
+	t.Helper()
+	store, err := NewStore(ExecutorStoreParams{
+		Client:     tc.Client,
+		Cfg:        tc.LeaderCfg,
+		Lifecycle:  fxtest.NewLifecycle(t),
+		Logger:     testlogger.New(t),
+		TimeSource: ts,
+	})
+	require.NoError(t, err)
+	return store
+}
+
+func getShardStats(ctx context.Context, t *testing.T, s store.Store, namespace, shardID string) *store.ShardStatistics {
+	t.Helper()
+	nsState, err := s.GetState(ctx, namespace)
+	require.NoError(t, err)
+	stats, ok := nsState.ShardStats[shardID]
+	if !ok {
+		return nil
+	}
+	return &stats
+}
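Reviewer note (illustrative, not part of the patch): the persist gating above combines a time-constant EWMA with an epsilon-and-age check, so a write happens only when the smoothed load moves by at least shardStatsEpsilon or the last persisted value is older than the derived interval. Below is a minimal standalone sketch of that arithmetic; tau, epsilon, and the derived interval for the default 60s TTL are copied from this diff, and the ewma helper mirrors ewmaSmoothedLoad rather than importing it.

package main

import (
	"fmt"
	"math"
)

const (
	tauSeconds        = 30.0 // smoothing time constant used by ewmaSmoothedLoad
	shardStatsEpsilon = 0.05 // minimum smoothed-load delta that forces a persist
	persistIntervalS  = 59   // deriveStatsPersistInterval(60s TTL) = TTL - 1
)

// ewma mirrors ewmaSmoothedLoad: alpha grows with the time since the last
// update, so stale baselines converge to the current reading quickly.
func ewma(prev, current float64, lastUpdate, now int64) float64 {
	if lastUpdate <= 0 {
		return current
	}
	dt := now - lastUpdate
	if dt < 0 {
		dt = 0
	}
	alpha := 1 - math.Exp(-float64(dt)/tauSeconds)
	return (1-alpha)*prev + alpha*current
}

func main() {
	persisted, persistedAt := 0.40, int64(1000)

	// A small load change shortly after the last persist: gated (no etcd write).
	now := persistedAt + 4
	next := ewma(persisted, 0.425, persistedAt, now)
	delta := math.Abs(next - persisted)
	fmt.Printf("dt=4s  smoothed=%.4f delta=%.4f persist=%v\n",
		next, delta, delta >= shardStatsEpsilon || now-persistedAt >= persistIntervalS)

	// The same small change after the persist interval elapses: forced write.
	now = persistedAt + persistIntervalS
	next = ewma(persisted, 0.425, persistedAt, now)
	delta = math.Abs(next - persisted)
	fmt.Printf("dt=59s smoothed=%.4f delta=%.4f persist=%v\n",
		next, delta, delta >= shardStatsEpsilon || now-persistedAt >= persistIntervalS)
}

Run as-is it prints persist=false for the recent small delta and persist=true once the interval has elapsed, which is the same pair of behaviors exercised by TestRecordHeartbeatShardStatisticsThrottlesWrites.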