diff --git a/cmd/sharddistributor-canary/main.go b/cmd/sharddistributor-canary/main.go index 109768515d9..8de73dabe31 100644 --- a/cmd/sharddistributor-canary/main.go +++ b/cmd/sharddistributor-canary/main.go @@ -110,14 +110,23 @@ func opts(fixedNamespace, ephemeralNamespace, endpoint string, canaryGRPCPort in fx.Provide(zap.NewDevelopment), fx.Provide(log.NewLogger), - // Register canary procedures with dispatcher - fx.Invoke(func(dispatcher *yarpc.Dispatcher, server sharddistributorv1.ShardDistributorExecutorCanaryAPIYARPCServer) { + // We use Decorate instead of Invoke so that fx manages the dispatcher's
+ // lifecycle at the right point in the dependency graph: the dispatcher
+ // starts before the components that depend on it and stops only after
+ // they have all stopped.
+ //
+ // This ordering is critical: the executors must be able to send a final
+ // "drain" request to the shard distributor before the application is
+ // stopped, which requires the dispatcher to still be running at that point.
+ fx.Decorate(func( + lc fx.Lifecycle, + dispatcher *yarpc.Dispatcher, + server sharddistributorv1.ShardDistributorExecutorCanaryAPIYARPCServer, + ) *yarpc.Dispatcher { + // Register canary procedures and ensure the dispatcher lifecycle is managed by fx. dispatcher.Register(sharddistributorv1.BuildShardDistributorExecutorCanaryAPIYARPCProcedures(server)) - }), - - // Start the YARPC dispatcher - fx.Invoke(func(lc fx.Lifecycle, dispatcher *yarpc.Dispatcher) { lc.Append(fx.StartStopHook(dispatcher.Start, dispatcher.Stop)) + return dispatcher }), // Include the canary module - it will set up spectator peer choosers and canary client diff --git a/service/sharddistributor/canary/module_test.go b/service/sharddistributor/canary/module_test.go index f3b3c91bdc2..7a738359424 100644 --- a/service/sharddistributor/canary/module_test.go +++ b/service/sharddistributor/canary/module_test.go @@ -8,6 +8,7 @@ import ( "github.com/uber-go/tally" "go.uber.org/fx" "go.uber.org/fx/fxtest" + ubergomock "go.uber.org/mock/gomock" "go.uber.org/yarpc" "go.uber.org/yarpc/api/peer" "go.uber.org/yarpc/api/transport/transporttest" @@ -15,24 +16,36 @@ import ( "go.uber.org/yarpc/yarpctest" "go.uber.org/zap/zaptest" + sharddistributorv1 "github.com/uber/cadence/.gen/proto/sharddistributor/v1" "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/log" "github.com/uber/cadence/service/sharddistributor/client/clientcommon" + "github.com/uber/cadence/service/sharddistributor/client/executorclient" ) func TestModule(t *testing.T) { // Create mocks ctrl := gomock.NewController(t) + uberCtrl := ubergomock.NewController(t) + mockLogger := log.NewNoop() + mockClientConfig := transporttest.NewMockClientConfig(ctrl) transport := grpc.NewTransport() outbound := transport.NewOutbound(yarpctest.NewFakePeerList()) - mockClientConfig.EXPECT().Caller().Return("test-executor").Times(2) - mockClientConfig.EXPECT().Service().Return("shard-distributor").Times(2) - mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound).Times(2) + mockClientConfig.EXPECT().Caller().Return("test-executor").AnyTimes() + mockClientConfig.EXPECT().Service().Return("shard-distributor").AnyTimes() + mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound).AnyTimes() mockClientConfigProvider := transporttest.NewMockClientConfigProvider(ctrl) - mockClientConfigProvider.EXPECT().ClientConfig("cadence-shard-distributor").Return(mockClientConfig).Times(2)
+ mockClientConfigProvider.EXPECT().ClientConfig("cadence-shard-distributor").Return(mockClientConfig).AnyTimes() + + // Create executor yarpc client mock + mockYARPCClient := executorclient.NewMockShardDistributorExecutorAPIYARPCClient(uberCtrl) + mockYARPCClient.EXPECT(). + Heartbeat(ubergomock.Any(), ubergomock.Any(), ubergomock.Any()). + Return(&sharddistributorv1.HeartbeatResponse{}, nil). + AnyTimes() config := clientcommon.Config{ Namespaces: []clientcommon.NamespaceConfig{ @@ -60,13 +73,22 @@ func TestModule(t *testing.T) { fx.Supply( fx.Annotate(tally.NoopScope, fx.As(new(tally.Scope))), fx.Annotate(clock.NewMockedTimeSource(), fx.As(new(clock.TimeSource))), - fx.Annotate(log.NewNoop(), fx.As(new(log.Logger))), + fx.Annotate(mockLogger, fx.As(new(log.Logger))), fx.Annotate(mockClientConfigProvider, fx.As(new(yarpc.ClientConfig))), fx.Annotate(transport, fx.As(new(peer.Transport))), zaptest.NewLogger(t), config, dispatcher, ), - Module(NamespacesNames{FixedNamespace: "shard-distributor-canary", EphemeralNamespace: "shard-distributor-canary-ephemeral", ExternalAssignmentNamespace: "test-external-assignment", SharddistributorServiceName: "cadence-shard-distributor"}), + // Replace the real YARPC client with a mock so the draining heartbeat sent on shutdown is handled
+ fx.Decorate(func() sharddistributorv1.ShardDistributorExecutorAPIYARPCClient { + return mockYARPCClient + }), + Module(NamespacesNames{ + FixedNamespace: "shard-distributor-canary", + EphemeralNamespace: "shard-distributor-canary-ephemeral", + ExternalAssignmentNamespace: "test-external-assignment", + SharddistributorServiceName: "cadence-shard-distributor", + }), ).RequireStart().RequireStop() } diff --git a/service/sharddistributor/client/executorclient/client.go b/service/sharddistributor/client/executorclient/client.go index 1411007b2c5..77a527a0cb3 100644 --- a/service/sharddistributor/client/executorclient/client.go +++ b/service/sharddistributor/client/executorclient/client.go @@ -21,6 +21,7 @@ import ( ) //go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination interface_mock.go . ShardProcessorFactory,ShardProcessor,Executor +//go:generate mockgen -package $GOPACKAGE -destination yarpc_client_mock.go github.com/uber/cadence/.gen/proto/sharddistributor/v1 ShardDistributorExecutorAPIYARPCClient type ExecutorMetadata map[string]string diff --git a/service/sharddistributor/client/executorclient/client_test.go b/service/sharddistributor/client/executorclient/client_test.go index 341cf547369..1ca010f616a 100644 --- a/service/sharddistributor/client/executorclient/client_test.go +++ b/service/sharddistributor/client/executorclient/client_test.go @@ -4,14 +4,10 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/uber-go/tally" "go.uber.org/fx" "go.uber.org/fx/fxtest" - uber_gomock "go.uber.org/mock/gomock" - "go.uber.org/yarpc/api/transport/transporttest" - "go.uber.org/yarpc/transport/grpc" - "go.uber.org/yarpc/yarpctest" + "go.uber.org/mock/gomock" sharddistributorv1 "github.com/uber/cadence/.gen/proto/sharddistributor/v1" "github.com/uber/cadence/common/clock" @@ -20,21 +16,17 @@ import ( ) func TestModule(t *testing.T) { - // Create mocks ctrl := gomock.NewController(t) - uberCtrl := uber_gomock.NewController(t) mockLogger := log.NewNoop() - mockShardProcessorFactory := NewMockShardProcessorFactory[*MockShardProcessor](uberCtrl) + // Create executor yarpc client mock + mockYARPCClient := NewMockShardDistributorExecutorAPIYARPCClient(ctrl) + mockYARPCClient.EXPECT().
+ Heartbeat(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&sharddistributorv1.HeartbeatResponse{}, nil). + AnyTimes() - // Create shard distributor yarpc client - outbound := grpc.NewTransport().NewOutbound(yarpctest.NewFakePeerList()) - - mockClientConfig := transporttest.NewMockClientConfig(ctrl) - mockClientConfig.EXPECT().Caller().Return("test-executor") - mockClientConfig.EXPECT().Service().Return("shard-distributor") - mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound) - yarpcClient := sharddistributorv1.NewShardDistributorExecutorAPIYARPCClient(mockClientConfig) + mockShardProcessorFactory := NewMockShardProcessorFactory[*MockShardProcessor](ctrl) // Example config config := clientcommon.Config{ @@ -49,7 +41,7 @@ func TestModule(t *testing.T) { // Create a test app with the library, check that it starts and stops fxtest.New(t, fx.Supply( - fx.Annotate(yarpcClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))), + fx.Annotate(mockYARPCClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))), fx.Annotate(tally.NoopScope, fx.As(new(tally.Scope))), fx.Annotate(mockLogger, fx.As(new(log.Logger))), fx.Annotate(mockShardProcessorFactory, fx.As(new(ShardProcessorFactory[*MockShardProcessor]))), @@ -70,22 +62,18 @@ type MockShardProcessor2 struct { } func TestModuleWithNamespace(t *testing.T) { - // Create mocks ctrl := gomock.NewController(t) - uberCtrl := uber_gomock.NewController(t) mockLogger := log.NewNoop() - mockFactory1 := NewMockShardProcessorFactory[*MockShardProcessor1](uberCtrl) - mockFactory2 := NewMockShardProcessorFactory[*MockShardProcessor2](uberCtrl) - - // Create shard distributor yarpc client - outbound := grpc.NewTransport().NewOutbound(yarpctest.NewFakePeerList()) + // Create executor yarpc client mock + mockYARPCClient := NewMockShardDistributorExecutorAPIYARPCClient(ctrl) + mockYARPCClient.EXPECT(). + Heartbeat(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&sharddistributorv1.HeartbeatResponse{}, nil). 
+ AnyTimes() - mockClientConfig := transporttest.NewMockClientConfig(ctrl) - mockClientConfig.EXPECT().Caller().Return("test-executor").AnyTimes() - mockClientConfig.EXPECT().Service().Return("shard-distributor").AnyTimes() - mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound).AnyTimes() - yarpcClient := sharddistributorv1.NewShardDistributorExecutorAPIYARPCClient(mockClientConfig) + mockFactory1 := NewMockShardProcessorFactory[*MockShardProcessor1](ctrl) + mockFactory2 := NewMockShardProcessorFactory[*MockShardProcessor2](ctrl) // Multi-namespace config config := clientcommon.Config{ @@ -104,7 +92,7 @@ func TestModuleWithNamespace(t *testing.T) { // Create a test app with two namespace-specific modules using different processor types fxtest.New(t, fx.Supply( - fx.Annotate(yarpcClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))), + fx.Annotate(mockYARPCClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))), fx.Annotate(tally.NoopScope, fx.As(new(tally.Scope))), fx.Annotate(mockLogger, fx.As(new(log.Logger))), fx.Annotate(clock.NewMockedTimeSource(), fx.As(new(clock.TimeSource))), diff --git a/service/sharddistributor/client/executorclient/clientimpl.go b/service/sharddistributor/client/executorclient/clientimpl.go index c7265f32367..1429a9227ef 100644 --- a/service/sharddistributor/client/executorclient/clientimpl.go +++ b/service/sharddistributor/client/executorclient/clientimpl.go @@ -34,7 +34,8 @@ const ( ) const ( - heartbeatJitterCoeff = 0.1 // 10% jitter + heartbeatJitterCoeff = 0.1 // 10% jitter + drainingHeartbeatTimeout = 5 * time.Second ) type managedProcessor[SP ShardProcessor] struct { @@ -188,12 +189,14 @@ func (e *executorImpl[SP]) heartbeatloop(ctx context.Context) { for { select { case <-ctx.Done(): - e.logger.Info("shard distributorexecutor context done, stopping") + e.logger.Info("shard distributor executor context done, stopping") e.stopShardProcessors() + e.sendDrainingHeartbeat() return case <-e.stopC: - e.logger.Info("shard distributorexecutor stopped") + e.logger.Info("shard distributor executor stopped") e.stopShardProcessors() + e.sendDrainingHeartbeat() return case <-heartBeatTimer.Chan(): heartBeatTimer.Reset(backoff.JitDuration(e.heartBeatInterval, heartbeatJitterCoeff)) @@ -273,6 +276,10 @@ func (e *executorImpl[SP]) updateShardAssignmentMetered(ctx context.Context, sha } func (e *executorImpl[SP]) heartbeat(ctx context.Context) (shardAssignments map[string]*types.ShardAssignment, migrationMode types.MigrationMode, err error) { + return e.sendHeartbeat(ctx, types.ExecutorStatusACTIVE) +} + +func (e *executorImpl[SP]) sendHeartbeat(ctx context.Context, status types.ExecutorStatus) (map[string]*types.ShardAssignment, types.MigrationMode, error) { // Fill in the shard status reports shardStatusReports := make(map[string]*types.ShardStatusReport) e.managedProcessors.Range(func(shardID string, managedProcessor *managedProcessor[SP]) bool { @@ -293,7 +300,7 @@ func (e *executorImpl[SP]) heartbeat(ctx context.Context) (shardAssignments map[ request := &types.ExecutorHeartbeatRequest{ Namespace: e.namespace, ExecutorID: e.executorID, - Status: types.ExecutorStatusACTIVE, + Status: status, ShardStatusReports: shardStatusReports, Metadata: e.metadata.Get(), } @@ -318,6 +325,16 @@ func (e *executorImpl[SP]) heartbeat(ctx context.Context) (shardAssignments map[ return response.ShardAssignments, response.MigrationMode, nil } +func (e *executorImpl[SP]) sendDrainingHeartbeat() { + ctx, cancel := 
context.WithTimeout(context.Background(), drainingHeartbeatTimeout) + defer cancel() + + _, _, err := e.sendHeartbeat(ctx, types.ExecutorStatusDRAINING) + if err != nil { + e.logger.Error("failed to send draining heartbeat", tag.Error(err)) + } +} + func (e *executorImpl[SP]) updateShardAssignment(ctx context.Context, shardAssignments map[string]*types.ShardAssignment) { wg := sync.WaitGroup{} diff --git a/service/sharddistributor/client/executorclient/clientimpl_test.go b/service/sharddistributor/client/executorclient/clientimpl_test.go index 6c260961449..ed5d2ebfba7 100755 --- a/service/sharddistributor/client/executorclient/clientimpl_test.go +++ b/service/sharddistributor/client/executorclient/clientimpl_test.go @@ -11,6 +11,7 @@ import ( "github.com/uber-go/tally" "go.uber.org/goleak" "go.uber.org/mock/gomock" + yarpc "go.uber.org/yarpc" "github.com/uber/cadence/client/sharddistributorexecutor" "github.com/uber/cadence/common/clock" @@ -19,8 +20,39 @@ import ( "github.com/uber/cadence/service/sharddistributor/client/executorclient/syncgeneric" ) -func TestHeartBeartLoop(t *testing.T) { - // Insure that there are no goroutines leaked +func expectDrainingHeartbeat(t *testing.T, mockClient *sharddistributorexecutor.MockClient) { + mockClient.EXPECT(). + Heartbeat(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, req *types.ExecutorHeartbeatRequest, _ ...yarpc.CallOption) (*types.ExecutorHeartbeatResponse, error) { + assert.Equal(t, types.ExecutorStatusDRAINING, req.Status) + return &types.ExecutorHeartbeatResponse{}, nil + }) +} + +func newTestExecutor( + client sharddistributorexecutor.Client, + factory ShardProcessorFactory[*MockShardProcessor], + timeSource clock.TimeSource, +) *executorImpl[*MockShardProcessor] { + if timeSource == nil { + timeSource = clock.NewMockedTimeSource() + } + return &executorImpl[*MockShardProcessor]{ + logger: log.NewNoop(), + metrics: tally.NoopScope, + shardDistributorClient: client, + shardProcessorFactory: factory, + namespace: "test-namespace", + stopC: make(chan struct{}), + heartBeatInterval: 10 * time.Second, + managedProcessors: syncgeneric.Map[string, *managedProcessor[*MockShardProcessor]]{}, + executorID: "test-executor-id", + timeSource: timeSource, + } +} + +func TestHeartBeatLoop(t *testing.T) { + // Ensure that there are no goroutines leaked defer goleak.VerifyNone(t) // Create mocks @@ -44,6 +76,7 @@ func TestHeartBeartLoop(t *testing.T) { }, MigrationMode: types.MigrationModeONBOARDED, }, nil) + expectDrainingHeartbeat(t, mockShardDistributorClient) // The two shards are assigned to the executor, so we expect them to be created, started and stopped mockShardProcessor1 := NewMockShardProcessor(ctrl) @@ -62,18 +95,7 @@ func TestHeartBeartLoop(t *testing.T) { mockTimeSource := clock.NewMockedTimeSource() // Create the executor - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: tally.NoopScope, - shardDistributorClient: mockShardDistributorClient, - shardProcessorFactory: mockShardProcessorFactory, - namespace: "test-namespace", - stopC: make(chan struct{}), - heartBeatInterval: 10 * time.Second, - managedProcessors: syncgeneric.Map[string, *managedProcessor[*MockShardProcessor]]{}, - executorID: "test-executor-id", - timeSource: mockTimeSource, - } + executor := newTestExecutor(mockShardDistributorClient, mockShardProcessorFactory, mockTimeSource) // Start the executor, and defer stopping it executor.Start(context.Background()) @@ -130,13 +152,7 @@ func TestHeartbeat(t 
*testing.T) { shardProcessorMock2.EXPECT().GetShardReport().Return(ShardReport{ShardLoad: 0.456, Status: types.ShardStatusREADY}) // Create the executor - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - shardDistributorClient: shardDistributorClient, - namespace: "test-namespace", - executorID: "test-executor-id", - metrics: tally.NoopScope, - } + executor := newTestExecutor(shardDistributorClient, nil, nil) executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) executor.managedProcessors.Store("test-shard-id2", newManagedProcessor(shardProcessorMock2, processorStateStarted)) @@ -152,7 +168,7 @@ func TestHeartbeat(t *testing.T) { assert.Equal(t, types.AssignmentStatusREADY, shardAssignments["test-shard-id3"].Status) } -func TestHeartBeartLoop_ShardAssignmentChange(t *testing.T) { +func TestHeartBeatLoop_ShardAssignmentChange(t *testing.T) { ctrl := gomock.NewController(t) // Setup mocks @@ -164,11 +180,7 @@ func TestHeartBeartLoop_ShardAssignmentChange(t *testing.T) { shardProcessorFactory.EXPECT().NewShardProcessor(gomock.Any()).Return(shardProcessorMock3, nil) // Create the executor currently has shards 1 and 2 assigned to it - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - shardProcessorFactory: shardProcessorFactory, - metrics: tally.NoopScope, - } + executor := newTestExecutor(nil, shardProcessorFactory, nil) executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) executor.managedProcessors.Store("test-shard-id2", newManagedProcessor(shardProcessorMock2, processorStateStarted)) @@ -214,11 +226,7 @@ func TestAssignShardsFromLocalLogic(t *testing.T) { name: "AssignShardsFromLocalLogic fails if the namespace is onboarded", params: map[string]*types.ShardAssignment{}, setup: func() *executorImpl[*MockShardProcessor] { - // Create the executor currently has shards 1 and 2 assigned to it - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: tally.NoopScope, - } + executor := newTestExecutor(nil, nil, nil) executor.setMigrationMode(types.MigrationModeONBOARDED) return executor }, @@ -238,12 +246,7 @@ func TestAssignShardsFromLocalLogic(t *testing.T) { // Setup mocks shardProcessorFactory.EXPECT().NewShardProcessor(gomock.Any()).Return(nil, assert.AnError) - // Create the executor currently has shards 1 and 2 assigned to it - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - shardProcessorFactory: shardProcessorFactory, - metrics: tally.NoopScope, - } + executor := newTestExecutor(nil, shardProcessorFactory, nil) executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) executor.managedProcessors.Store("test-shard-id2", newManagedProcessor(shardProcessorMock2, processorStateStarted)) @@ -267,12 +270,7 @@ func TestAssignShardsFromLocalLogic(t *testing.T) { shardProcessorFactory.EXPECT().NewShardProcessor(gomock.Any()).Return(shardProcessorMock3, nil) - // Create the executor currently has shards 1 and 2 assigned to it - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - shardProcessorFactory: shardProcessorFactory, - metrics: tally.NoopScope, - } + executor := newTestExecutor(nil, shardProcessorFactory, nil) executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) executor.managedProcessors.Store("test-shard-id2", 
newManagedProcessor(shardProcessorMock2, processorStateStarted)) @@ -284,7 +282,7 @@ func TestAssignShardsFromLocalLogic(t *testing.T) { shardProcessorMock3.EXPECT().GetShardReport().Return(ShardReport{Status: types.ShardStatusREADY}) return executor }, - assert: func(err error, executor *executorImpl[*MockShardProcessor]) { + assert: func(_ error, executor *executorImpl[*MockShardProcessor]) { // Assert that we now have the 3 shards in the assignment processor1, err := executor.GetShardProcess(context.Background(), "test-shard-id1") assert.NoError(t, err) @@ -323,11 +321,7 @@ func TestRemoveShardsFromLocalLogic(t *testing.T) { name: "RemoveShardsFromLocalLogic fails if the namespace is onboarded", params: []string{}, setup: func() *executorImpl[*MockShardProcessor] { - // Create the executor currently has shards 1 and 2 assigned to it - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: tally.NoopScope, - } + executor := newTestExecutor(nil, nil, nil) executor.setMigrationMode(types.MigrationModeONBOARDED) return executor }, @@ -341,11 +335,7 @@ func TestRemoveShardsFromLocalLogic(t *testing.T) { setup: func() *executorImpl[*MockShardProcessor] { shardProcessorMock1 := NewMockShardProcessor(ctrl) shardProcessorMock2 := NewMockShardProcessor(ctrl) - // Create the executor currently has shards 1 and 2 assigned to it - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: tally.NoopScope, - } + executor := newTestExecutor(nil, nil, nil) executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) executor.managedProcessors.Store("test-shard-id2", newManagedProcessor(shardProcessorMock2, processorStateStarted)) @@ -355,7 +345,7 @@ func TestRemoveShardsFromLocalLogic(t *testing.T) { shardProcessorMock1.EXPECT().GetShardReport().Return(ShardReport{Status: types.ShardStatusREADY}) return executor }, - assert: func(err error, executor *executorImpl[*MockShardProcessor]) { + assert: func(_ error, executor *executorImpl[*MockShardProcessor]) { // Assert that we now have the 1 shard in the assignment processor1, err := executor.GetShardProcess(context.Background(), "test-shard-id1") assert.NoError(t, err) @@ -397,13 +387,7 @@ func TestHeartbeat_WithMigrationMode(t *testing.T) { MigrationMode: types.MigrationModeDISTRIBUTEDPASSTHROUGH, }, nil) - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - shardDistributorClient: shardDistributorClient, - namespace: "test-namespace", - executorID: "test-executor-id", - metrics: tally.NoopScope, - } + executor := newTestExecutor(shardDistributorClient, nil, nil) executor.setMigrationMode(types.MigrationModeINVALID) shardAssignments, migrationMode, err := executor.heartbeat(context.Background()) @@ -423,13 +407,7 @@ func TestHeartbeat_MigrationModeTransition(t *testing.T) { MigrationMode: types.MigrationModeONBOARDED, }, nil) - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - shardDistributorClient: shardDistributorClient, - namespace: "test-namespace", - executorID: "test-executor-id", - metrics: tally.NoopScope, - } + executor := newTestExecutor(shardDistributorClient, nil, nil) executor.setMigrationMode(types.MigrationModeDISTRIBUTEDPASSTHROUGH) _, migrationMode, err := executor.heartbeat(context.Background()) @@ -450,17 +428,7 @@ func TestHeartbeatLoop_LocalPassthrough_SkipsHeartbeat(t *testing.T) { mockTimeSource := clock.NewMockedTimeSource() - executor := &executorImpl[*MockShardProcessor]{ 
- logger: log.NewNoop(), - metrics: tally.NoopScope, - shardDistributorClient: mockShardDistributorClient, - namespace: "test-namespace", - stopC: make(chan struct{}), - heartBeatInterval: 10 * time.Second, - managedProcessors: syncgeneric.Map[string, *managedProcessor[*MockShardProcessor]]{}, - executorID: "test-executor-id", - timeSource: mockTimeSource, - } + executor := newTestExecutor(mockShardDistributorClient, nil, mockTimeSource) executor.setMigrationMode(types.MigrationModeLOCALPASSTHROUGH) executor.Start(context.Background()) @@ -484,6 +452,7 @@ func TestHeartbeatLoop_LocalPassthroughShadow_SkipsAssignment(t *testing.T) { }, MigrationMode: types.MigrationModeLOCALPASSTHROUGHSHADOW, }, nil) + expectDrainingHeartbeat(t, mockShardDistributorClient) mockShardProcessorFactory := NewMockShardProcessorFactory[*MockShardProcessor](ctrl) // No shard processor should be created @@ -491,18 +460,7 @@ func TestHeartbeatLoop_LocalPassthroughShadow_SkipsAssignment(t *testing.T) { mockTimeSource := clock.NewMockedTimeSource() - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: tally.NoopScope, - shardDistributorClient: mockShardDistributorClient, - shardProcessorFactory: mockShardProcessorFactory, - namespace: "test-namespace", - stopC: make(chan struct{}), - heartBeatInterval: 10 * time.Second, - managedProcessors: syncgeneric.Map[string, *managedProcessor[*MockShardProcessor]]{}, - executorID: "test-executor-id", - timeSource: mockTimeSource, - } + executor := newTestExecutor(mockShardDistributorClient, mockShardProcessorFactory, mockTimeSource) executor.setMigrationMode(types.MigrationModeONBOARDED) executor.Start(context.Background()) @@ -531,6 +489,7 @@ func TestHeartbeatLoop_DistributedPassthrough_AppliesAssignment(t *testing.T) { }, MigrationMode: types.MigrationModeDISTRIBUTEDPASSTHROUGH, }, nil) + expectDrainingHeartbeat(t, mockShardDistributorClient) mockShardProcessor := NewMockShardProcessor(ctrl) mockShardProcessor.EXPECT().Start(gomock.Any()) @@ -541,18 +500,7 @@ func TestHeartbeatLoop_DistributedPassthrough_AppliesAssignment(t *testing.T) { mockTimeSource := clock.NewMockedTimeSource() - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: tally.NoopScope, - shardDistributorClient: mockShardDistributorClient, - shardProcessorFactory: mockShardProcessorFactory, - namespace: "test-namespace", - stopC: make(chan struct{}), - heartBeatInterval: 10 * time.Second, - managedProcessors: syncgeneric.Map[string, *managedProcessor[*MockShardProcessor]]{}, - executorID: "test-executor-id", - timeSource: mockTimeSource, - } + executor := newTestExecutor(mockShardDistributorClient, mockShardProcessorFactory, mockTimeSource) executor.setMigrationMode(types.MigrationModeONBOARDED) executor.Start(context.Background()) @@ -569,6 +517,47 @@ func TestHeartbeatLoop_DistributedPassthrough_AppliesAssignment(t *testing.T) { assert.Equal(t, mockShardProcessor, processor) } +func TestHeartbeatLoop_StopSignalSendsDrainingHeartbeat(t *testing.T) { + defer goleak.VerifyNone(t) + + ctrl := gomock.NewController(t) + mockShardDistributorClient := sharddistributorexecutor.NewMockClient(ctrl) + expectDrainingHeartbeat(t, mockShardDistributorClient) + + mockTimeSource := clock.NewMockedTimeSource() + + executor := newTestExecutor(mockShardDistributorClient, nil, mockTimeSource) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + go func() { + executor.heartbeatloop(ctx) + close(done) + }() + + 
close(executor.stopC) + <-done +} + +func TestHeartbeatLoop_ContextCancelSendsDrainingHeartbeat(t *testing.T) { + defer goleak.VerifyNone(t) + + ctrl := gomock.NewController(t) + mockShardDistributorClient := sharddistributorexecutor.NewMockClient(ctrl) + expectDrainingHeartbeat(t, mockShardDistributorClient) + + mockTimeSource := clock.NewMockedTimeSource() + + executor := newTestExecutor(mockShardDistributorClient, nil, mockTimeSource) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancelling the context + + executor.heartbeatloop(ctx) +} + func TestCompareAssignments_Converged(t *testing.T) { ctrl := gomock.NewController(t) @@ -576,10 +565,8 @@ func TestCompareAssignments_Converged(t *testing.T) { shardProcessorMock2 := NewMockShardProcessor(ctrl) testScope := tally.NewTestScope("test", nil) - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: testScope, - } + executor := newTestExecutor(nil, nil, nil) + executor.metrics = testScope executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) executor.managedProcessors.Store("test-shard-id2", newManagedProcessor(shardProcessorMock2, processorStateStarted)) @@ -603,10 +590,8 @@ func TestCompareAssignments_Diverged_MissingShard(t *testing.T) { shardProcessorMock2 := NewMockShardProcessor(ctrl) testScope := tally.NewTestScope("test", nil) - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: testScope, - } + executor := newTestExecutor(nil, nil, nil) + executor.metrics = testScope executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) executor.managedProcessors.Store("test-shard-id2", newManagedProcessor(shardProcessorMock2, processorStateStarted)) @@ -629,10 +614,8 @@ func TestCompareAssignments_Diverged_ExtraShard(t *testing.T) { shardProcessorMock1 := NewMockShardProcessor(ctrl) testScope := tally.NewTestScope("test", nil) - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: testScope, - } + executor := newTestExecutor(nil, nil, nil) + executor.metrics = testScope executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) @@ -655,10 +638,8 @@ func TestCompareAssignments_Diverged_WrongStatus(t *testing.T) { shardProcessorMock1 := NewMockShardProcessor(ctrl) testScope := tally.NewTestScope("test", nil) - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - metrics: testScope, - } + executor := newTestExecutor(nil, nil, nil) + executor.metrics = testScope executor.managedProcessors.Store("test-shard-id1", newManagedProcessor(shardProcessorMock1, processorStateStarted)) @@ -739,13 +720,7 @@ func TestGetShardProcess_NonOwnedShard_Fails(t *testing.T) { tc.setupMocks(shardProcessorFactory, NewMockShardProcessor(ctrl)) } - executor := &executorImpl[*MockShardProcessor]{ - logger: log.NewNoop(), - shardProcessorFactory: shardProcessorFactory, - metrics: tally.NoopScope, - shardDistributorClient: shardDistributorClient, - timeSource: clock.NewMockedTimeSource(), - } + executor := newTestExecutor(shardDistributorClient, shardProcessorFactory, clock.NewMockedTimeSource()) executor.setMigrationMode(tc.migrationMode) for _, shardID := range tc.shardsInCache { diff --git a/service/sharddistributor/client/executorclient/yarpc_client_mock.go b/service/sharddistributor/client/executorclient/yarpc_client_mock.go new file mode 100644 index 
00000000000..0bb11889590 --- /dev/null +++ b/service/sharddistributor/client/executorclient/yarpc_client_mock.go @@ -0,0 +1,64 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/uber/cadence/.gen/proto/sharddistributor/v1 (interfaces: ShardDistributorExecutorAPIYARPCClient) +// +// Generated by this command: +// +// mockgen -package executorclient -destination yarpc_client_mock.go github.com/uber/cadence/.gen/proto/sharddistributor/v1 ShardDistributorExecutorAPIYARPCClient +// + +// Package executorclient is a generated GoMock package. +package executorclient + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + yarpc "go.uber.org/yarpc" + + sharddistributorv1 "github.com/uber/cadence/.gen/proto/sharddistributor/v1" +) + +// MockShardDistributorExecutorAPIYARPCClient is a mock of ShardDistributorExecutorAPIYARPCClient interface. +type MockShardDistributorExecutorAPIYARPCClient struct { + ctrl *gomock.Controller + recorder *MockShardDistributorExecutorAPIYARPCClientMockRecorder + isgomock struct{} +} + +// MockShardDistributorExecutorAPIYARPCClientMockRecorder is the mock recorder for MockShardDistributorExecutorAPIYARPCClient. +type MockShardDistributorExecutorAPIYARPCClientMockRecorder struct { + mock *MockShardDistributorExecutorAPIYARPCClient +} + +// NewMockShardDistributorExecutorAPIYARPCClient creates a new mock instance. +func NewMockShardDistributorExecutorAPIYARPCClient(ctrl *gomock.Controller) *MockShardDistributorExecutorAPIYARPCClient { + mock := &MockShardDistributorExecutorAPIYARPCClient{ctrl: ctrl} + mock.recorder = &MockShardDistributorExecutorAPIYARPCClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockShardDistributorExecutorAPIYARPCClient) EXPECT() *MockShardDistributorExecutorAPIYARPCClientMockRecorder { + return m.recorder +} + +// Heartbeat mocks base method. +func (m *MockShardDistributorExecutorAPIYARPCClient) Heartbeat(arg0 context.Context, arg1 *sharddistributorv1.HeartbeatRequest, arg2 ...yarpc.CallOption) (*sharddistributorv1.HeartbeatResponse, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Heartbeat", varargs...) + ret0, _ := ret[0].(*sharddistributorv1.HeartbeatResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Heartbeat indicates an expected call of Heartbeat. +func (mr *MockShardDistributorExecutorAPIYARPCClientMockRecorder) Heartbeat(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockShardDistributorExecutorAPIYARPCClient)(nil).Heartbeat), varargs...) +}
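
Reviewer note on the fx.Decorate rationale in cmd/sharddistributor-canary/main.go: the ordering argument is easiest to see in isolation. Below is a minimal, self-contained Go sketch (the Server and Consumer types are hypothetical stand-ins, not part of this PR) showing that hooks appended inside fx.Decorate start before, and stop after, the hooks of any component that consumes the decorated value. This is exactly the property the final drain request relies on.

package main

import (
	"context"
	"fmt"

	"go.uber.org/fx"
)

// Server stands in for *yarpc.Dispatcher; Consumer stands in for an executor
// that must still be able to reach the Server during its own shutdown.
type Server struct{}
type Consumer struct{}

func main() {
	app := fx.New(
		fx.Provide(func() *Server { return &Server{} }),
		// The decorator runs when *Server is first requested, so its hooks are
		// appended before any dependent's hooks. fx starts hooks in append
		// order and stops them in reverse: the Server starts first, stops last.
		fx.Decorate(func(lc fx.Lifecycle, s *Server) *Server {
			lc.Append(fx.Hook{
				OnStart: func(context.Context) error { fmt.Println("server start"); return nil },
				OnStop:  func(context.Context) error { fmt.Println("server stop"); return nil },
			})
			return s
		}),
		fx.Provide(func(lc fx.Lifecycle, s *Server) *Consumer {
			lc.Append(fx.Hook{
				OnStart: func(context.Context) error { fmt.Println("consumer start"); return nil },
				// The server is still running here, so a final "drain" request
				// sent from this hook would still be delivered.
				OnStop: func(context.Context) error { fmt.Println("consumer stop"); return nil },
			})
			return &Consumer{}
		}),
		fx.Invoke(func(*Consumer) {}),
	)

	ctx := context.Background()
	_ = app.Start(ctx) // server start, consumer start
	_ = app.Stop(ctx)  // consumer stop, server stop
}

With a plain fx.Invoke, as in the removed code, the dispatcher's hooks are appended at invoke time, after all constructors have run, so its Stop could fire before the executors have drained.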
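A second note, on sendDrainingHeartbeat in clientimpl.go: it deliberately builds its context from context.Background() with drainingHeartbeatTimeout rather than reusing the heartbeat loop's context, because the loop only reaches that call once its own context is cancelled or stopC is closed. A minimal sketch of the pattern under that assumption (the drain helper and its names are illustrative, not the PR's API):

package main

import (
	"context"
	"fmt"
	"time"
)

// drain mirrors the shape of sendDrainingHeartbeat: shutdown work gets a
// fresh, bounded context instead of the already-cancelled loop context.
func drain(call func(context.Context) error, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return call(ctx)
}

func main() {
	loopCtx, cancel := context.WithCancel(context.Background())
	cancel() // simulate shutdown: the heartbeat loop's context is done

	rpc := func(ctx context.Context) error { return ctx.Err() }

	// Reusing the loop context would fail immediately...
	fmt.Println(rpc(loopCtx)) // context canceled

	// ...while the detached drain context allows up to `timeout` to finish.
	fmt.Println(drain(rpc, 5*time.Second)) // <nil>
}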