21 changes: 15 additions & 6 deletions cmd/sharddistributor-canary/main.go
@@ -110,14 +110,23 @@ func opts(fixedNamespace, ephemeralNamespace, endpoint string, canaryGRPCPort in
fx.Provide(zap.NewDevelopment),
fx.Provide(log.NewLogger),

- // Register canary procedures with dispatcher
- fx.Invoke(func(dispatcher *yarpc.Dispatcher, server sharddistributorv1.ShardDistributorExecutorCanaryAPIYARPCServer) {
+ // We use Decorate instead of Invoke so that fx has enough information to
+ // start and stop the dispatcher at the correct time: the dispatcher starts
+ // before any component that depends on it and stops only after all of them
+ // have stopped.
+ //
+ // This ordering is critical because the executors need to be able to send a
+ // final "drain" request to the shard distributor before the application is
+ // stopped.
+ fx.Decorate(func(
+ lc fx.Lifecycle,
+ dispatcher *yarpc.Dispatcher,
+ server sharddistributorv1.ShardDistributorExecutorCanaryAPIYARPCServer,
+ ) *yarpc.Dispatcher {
+ // Register canary procedures and ensure dispatcher lifecycle is managed by fx.
dispatcher.Register(sharddistributorv1.BuildShardDistributorExecutorCanaryAPIYARPCProcedures(server))
- }),

- // Start the YARPC dispatcher
- fx.Invoke(func(lc fx.Lifecycle, dispatcher *yarpc.Dispatcher) {
lc.Append(fx.StartStopHook(dispatcher.Start, dispatcher.Stop))
+ return dispatcher
}),

// Include the canary module - it will set up spectator peer choosers and canary client
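For readers unfamiliar with the fx lifecycle mechanics the new comment relies on: hooks run in the order they are appended, and OnStop hooks run in reverse, so appending the dispatcher's hook inside a decorator makes it start before its dependents and stop after them. Below is a minimal, self-contained sketch of the decorate-and-hook pattern; the `dispatcher` and `executor` types are toy stand-ins, not the canary's real wiring.

```go
package main

import (
	"context"
	"fmt"

	"go.uber.org/fx"
)

// dispatcher is a toy stand-in for *yarpc.Dispatcher: anything with Start/Stop.
type dispatcher struct{}

func (d *dispatcher) Start() error { fmt.Println("dispatcher started"); return nil }
func (d *dispatcher) Stop() error  { fmt.Println("dispatcher stopped"); return nil }

// executor stands in for a component that must drain before the dispatcher stops.
type executor struct{}

func newExecutor(lc fx.Lifecycle, d *dispatcher) *executor {
	lc.Append(fx.Hook{
		OnStart: func(context.Context) error { fmt.Println("executor started"); return nil },
		OnStop:  func(context.Context) error { fmt.Println("executor drained"); return nil },
	})
	return &executor{}
}

func main() {
	app := fx.New(
		fx.Provide(func() *dispatcher { return &dispatcher{} }),
		// The decorator runs when the dispatcher is first demanded, so its
		// hook lands before any dependent's hook: start first, stop last.
		fx.Decorate(func(lc fx.Lifecycle, d *dispatcher) *dispatcher {
			lc.Append(fx.StartStopHook(d.Start, d.Stop))
			return d
		}),
		fx.Provide(newExecutor),
		fx.Invoke(func(*executor) {}),
	)
	if err := app.Start(context.Background()); err != nil {
		panic(err)
	}
	// Prints: dispatcher started, executor started, executor drained,
	// dispatcher stopped.
	_ = app.Stop(context.Background())
}
```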
34 changes: 28 additions & 6 deletions service/sharddistributor/canary/module_test.go
@@ -8,31 +8,44 @@ import (
"github.com/uber-go/tally"
"go.uber.org/fx"
"go.uber.org/fx/fxtest"
+ ubergomock "go.uber.org/mock/gomock"
"go.uber.org/yarpc"
"go.uber.org/yarpc/api/peer"
"go.uber.org/yarpc/api/transport/transporttest"
"go.uber.org/yarpc/transport/grpc"
"go.uber.org/yarpc/yarpctest"
"go.uber.org/zap/zaptest"

+ sharddistributorv1 "github.com/uber/cadence/.gen/proto/sharddistributor/v1"
"github.com/uber/cadence/common/clock"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/service/sharddistributor/client/clientcommon"
"github.com/uber/cadence/service/sharddistributor/client/executorclient"
)

func TestModule(t *testing.T) {
// Create mocks
ctrl := gomock.NewController(t)
+ uberCtrl := ubergomock.NewController(t)
+ mockLogger := log.NewNoop()

mockClientConfig := transporttest.NewMockClientConfig(ctrl)
transport := grpc.NewTransport()
outbound := transport.NewOutbound(yarpctest.NewFakePeerList())

- mockClientConfig.EXPECT().Caller().Return("test-executor").Times(2)
- mockClientConfig.EXPECT().Service().Return("shard-distributor").Times(2)
- mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound).Times(2)
+ mockClientConfig.EXPECT().Caller().Return("test-executor").AnyTimes()
+ mockClientConfig.EXPECT().Service().Return("shard-distributor").AnyTimes()
+ mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound).AnyTimes()

mockClientConfigProvider := transporttest.NewMockClientConfigProvider(ctrl)
- mockClientConfigProvider.EXPECT().ClientConfig("cadence-shard-distributor").Return(mockClientConfig).Times(2)
+ mockClientConfigProvider.EXPECT().ClientConfig("cadence-shard-distributor").Return(mockClientConfig).AnyTimes()

+ // Create executor yarpc client mock
+ mockYARPCClient := executorclient.NewMockShardDistributorExecutorAPIYARPCClient(uberCtrl)
+ mockYARPCClient.EXPECT().
+ Heartbeat(ubergomock.Any(), ubergomock.Any(), ubergomock.Any()).
+ Return(&sharddistributorv1.HeartbeatResponse{}, nil).
+ AnyTimes()

config := clientcommon.Config{
Namespaces: []clientcommon.NamespaceConfig{
@@ -60,13 +73,22 @@ func TestModule(t *testing.T) {
fx.Supply(
fx.Annotate(tally.NoopScope, fx.As(new(tally.Scope))),
fx.Annotate(clock.NewMockedTimeSource(), fx.As(new(clock.TimeSource))),
- fx.Annotate(log.NewNoop(), fx.As(new(log.Logger))),
+ fx.Annotate(mockLogger, fx.As(new(log.Logger))),
fx.Annotate(mockClientConfigProvider, fx.As(new(yarpc.ClientConfig))),
fx.Annotate(transport, fx.As(new(peer.Transport))),
zaptest.NewLogger(t),
config,
dispatcher,
),
- Module(NamespacesNames{FixedNamespace: "shard-distributor-canary", EphemeralNamespace: "shard-distributor-canary-ephemeral", ExternalAssignmentNamespace: "test-external-assignment", SharddistributorServiceName: "cadence-shard-distributor"}),
+ // Replace the real YARPC client with a mock to handle the draining heartbeat.
+ fx.Decorate(func() sharddistributorv1.ShardDistributorExecutorAPIYARPCClient {
+ return mockYARPCClient
+ }),
+ Module(NamespacesNames{
+ FixedNamespace: "shard-distributor-canary",
+ EphemeralNamespace: "shard-distributor-canary-ephemeral",
+ ExternalAssignmentNamespace: "test-external-assignment",
+ SharddistributorServiceName: "cadence-shard-distributor",
+ }),
).RequireStart().RequireStop()
}
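The `fx.Decorate` override of the YARPC client above is the usual fx pattern for swapping a real dependency for a test double without touching the production wiring. A minimal sketch of that pattern follows; the `Client` interface and its implementations here are hypothetical, not the real generated client.

```go
package example

import (
	"errors"
	"testing"

	"go.uber.org/fx"
	"go.uber.org/fx/fxtest"
)

// Client is a hypothetical dependency consumed by the module under test.
type Client interface{ Heartbeat() error }

type realClient struct{}

func (realClient) Heartbeat() error { return errors.New("would hit the network") }

type stubClient struct{}

func (stubClient) Heartbeat() error { return nil }

func TestDecorateOverride(t *testing.T) {
	fxtest.New(t,
		fx.Provide(func() Client { return realClient{} }),
		// Decorate replaces the provided Client for every downstream
		// consumer, so the test never touches the real transport.
		fx.Decorate(func(Client) Client { return stubClient{} }),
		fx.Invoke(func(c Client) {
			if err := c.Heartbeat(); err != nil {
				t.Fatalf("expected stub, got: %v", err)
			}
		}),
	).RequireStart().RequireStop()
}
```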
1 change: 1 addition & 0 deletions service/sharddistributor/client/executorclient/client.go
@@ -21,6 +21,7 @@ import (
)

//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination interface_mock.go . ShardProcessorFactory,ShardProcessor,Executor
+ //go:generate mockgen -package $GOPACKAGE -destination yarpc_client_mock.go github.com/uber/cadence/.gen/proto/sharddistributor/v1 ShardDistributorExecutorAPIYARPCClient

type ExecutorMetadata map[string]string

48 changes: 18 additions & 30 deletions service/sharddistributor/client/executorclient/client_test.go
@@ -4,14 +4,10 @@ import (
"testing"
"time"

"github.com/golang/mock/gomock"
"github.com/uber-go/tally"
"go.uber.org/fx"
"go.uber.org/fx/fxtest"
- uber_gomock "go.uber.org/mock/gomock"
- "go.uber.org/yarpc/api/transport/transporttest"
- "go.uber.org/yarpc/transport/grpc"
- "go.uber.org/yarpc/yarpctest"
+ "go.uber.org/mock/gomock"

sharddistributorv1 "github.com/uber/cadence/.gen/proto/sharddistributor/v1"
"github.com/uber/cadence/common/clock"
@@ -20,21 +16,17 @@
)

func TestModule(t *testing.T) {
// Create mocks
ctrl := gomock.NewController(t)
- uberCtrl := uber_gomock.NewController(t)
mockLogger := log.NewNoop()

- mockShardProcessorFactory := NewMockShardProcessorFactory[*MockShardProcessor](uberCtrl)
+ // Create executor yarpc client mock
+ mockYARPCClient := NewMockShardDistributorExecutorAPIYARPCClient(ctrl)
+ mockYARPCClient.EXPECT().
+ Heartbeat(gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(&sharddistributorv1.HeartbeatResponse{}, nil).
+ AnyTimes()

- // Create shard distributor yarpc client
- outbound := grpc.NewTransport().NewOutbound(yarpctest.NewFakePeerList())
-
- mockClientConfig := transporttest.NewMockClientConfig(ctrl)
- mockClientConfig.EXPECT().Caller().Return("test-executor")
- mockClientConfig.EXPECT().Service().Return("shard-distributor")
- mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound)
- yarpcClient := sharddistributorv1.NewShardDistributorExecutorAPIYARPCClient(mockClientConfig)
+ mockShardProcessorFactory := NewMockShardProcessorFactory[*MockShardProcessor](ctrl)

// Example config
config := clientcommon.Config{
@@ -49,7 +41,7 @@ func TestModule(t *testing.T) {
// Create a test app with the library, check that it starts and stops
fxtest.New(t,
fx.Supply(
- fx.Annotate(yarpcClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))),
+ fx.Annotate(mockYARPCClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))),
fx.Annotate(tally.NoopScope, fx.As(new(tally.Scope))),
fx.Annotate(mockLogger, fx.As(new(log.Logger))),
fx.Annotate(mockShardProcessorFactory, fx.As(new(ShardProcessorFactory[*MockShardProcessor]))),
@@ -70,22 +62,18 @@ type MockShardProcessor2 struct {
}

func TestModuleWithNamespace(t *testing.T) {
// Create mocks
ctrl := gomock.NewController(t)
- uberCtrl := uber_gomock.NewController(t)
mockLogger := log.NewNoop()

- mockFactory1 := NewMockShardProcessorFactory[*MockShardProcessor1](uberCtrl)
- mockFactory2 := NewMockShardProcessorFactory[*MockShardProcessor2](uberCtrl)

- // Create shard distributor yarpc client
- outbound := grpc.NewTransport().NewOutbound(yarpctest.NewFakePeerList())
+ // Create executor yarpc client mock
+ mockYARPCClient := NewMockShardDistributorExecutorAPIYARPCClient(ctrl)
+ mockYARPCClient.EXPECT().
+ Heartbeat(gomock.Any(), gomock.Any(), gomock.Any()).
+ Return(&sharddistributorv1.HeartbeatResponse{}, nil).
+ AnyTimes()

- mockClientConfig := transporttest.NewMockClientConfig(ctrl)
- mockClientConfig.EXPECT().Caller().Return("test-executor").AnyTimes()
- mockClientConfig.EXPECT().Service().Return("shard-distributor").AnyTimes()
- mockClientConfig.EXPECT().GetUnaryOutbound().Return(outbound).AnyTimes()
- yarpcClient := sharddistributorv1.NewShardDistributorExecutorAPIYARPCClient(mockClientConfig)
+ mockFactory1 := NewMockShardProcessorFactory[*MockShardProcessor1](ctrl)
+ mockFactory2 := NewMockShardProcessorFactory[*MockShardProcessor2](ctrl)

// Multi-namespace config
config := clientcommon.Config{
@@ -104,7 +92,7 @@ func TestModuleWithNamespace(t *testing.T) {
// Create a test app with two namespace-specific modules using different processor types
fxtest.New(t,
fx.Supply(
- fx.Annotate(yarpcClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))),
+ fx.Annotate(mockYARPCClient, fx.As(new(sharddistributorv1.ShardDistributorExecutorAPIYARPCClient))),
fx.Annotate(tally.NoopScope, fx.As(new(tally.Scope))),
fx.Annotate(mockLogger, fx.As(new(log.Logger))),
fx.Annotate(clock.NewMockedTimeSource(), fx.As(new(clock.TimeSource))),
25 changes: 21 additions & 4 deletions service/sharddistributor/client/executorclient/clientimpl.go
@@ -34,7 +34,8 @@
)

const (
- heartbeatJitterCoeff = 0.1 // 10% jitter
+ heartbeatJitterCoeff     = 0.1 // 10% jitter
+ drainingHeartbeatTimeout = 5 * time.Second
)

type managedProcessor[SP ShardProcessor] struct {
@@ -188,12 +189,14 @@ func (e *executorImpl[SP]) heartbeatloop(ctx context.Context) {
for {
select {
case <-ctx.Done():
e.logger.Info("shard distributorexecutor context done, stopping")
e.logger.Info("shard distributor executor context done, stopping")
e.stopShardProcessors()
e.sendDrainingHeartbeat()
return
case <-e.stopC:
e.logger.Info("shard distributorexecutor stopped")
e.logger.Info("shard distributor executor stopped")
e.stopShardProcessors()
e.sendDrainingHeartbeat()
return
case <-heartBeatTimer.Chan():
heartBeatTimer.Reset(backoff.JitDuration(e.heartBeatInterval, heartbeatJitterCoeff))
@@ -273,6 +276,10 @@ func (e *executorImpl[SP]) updateShardAssignmentMetered(ctx context.Context, sha
}

func (e *executorImpl[SP]) heartbeat(ctx context.Context) (shardAssignments map[string]*types.ShardAssignment, migrationMode types.MigrationMode, err error) {
+ return e.sendHeartbeat(ctx, types.ExecutorStatusACTIVE)
+ }

+ func (e *executorImpl[SP]) sendHeartbeat(ctx context.Context, status types.ExecutorStatus) (map[string]*types.ShardAssignment, types.MigrationMode, error) {
// Fill in the shard status reports
shardStatusReports := make(map[string]*types.ShardStatusReport)
e.managedProcessors.Range(func(shardID string, managedProcessor *managedProcessor[SP]) bool {
@@ -293,7 +300,7 @@ func (e *executorImpl[SP]) heartbeat(ctx context.Context) (shardAssignments map[
request := &types.ExecutorHeartbeatRequest{
Namespace: e.namespace,
ExecutorID: e.executorID,
- Status: types.ExecutorStatusACTIVE,
+ Status: status,
ShardStatusReports: shardStatusReports,
Metadata: e.metadata.Get(),
}
@@ -318,6 +325,16 @@ func (e *executorImpl[SP]) heartbeat(ctx context.Context) (shardAssignments map[
return response.ShardAssignments, response.MigrationMode, nil
}

+ func (e *executorImpl[SP]) sendDrainingHeartbeat() {
+ ctx, cancel := context.WithTimeout(context.Background(), drainingHeartbeatTimeout)
+ defer cancel()
+
+ _, _, err := e.sendHeartbeat(ctx, types.ExecutorStatusDRAINING)
+ if err != nil {
+ e.logger.Error("failed to send draining heartbeat", tag.Error(err))
+ }
+ }

func (e *executorImpl[SP]) updateShardAssignment(ctx context.Context, shardAssignments map[string]*types.ShardAssignment) {
wg := sync.WaitGroup{}

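One subtlety in the drain logic above: by the time the heartbeat loop observes `ctx.Done()`, that context is already cancelled, so the final DRAINING report must run on a fresh context derived from `context.Background()` with its own short timeout, which is exactly what `sendDrainingHeartbeat` does. Below is a stripped-down sketch of this shutdown pattern, with a toy `send` function standing in for the real executor client.

```go
package example

import (
	"context"
	"log"
	"time"
)

const drainTimeout = 5 * time.Second

// heartbeatLoop reports ACTIVE on every tick and sends one final DRAINING
// report on shutdown, mirroring the shape of the executor's loop above.
func heartbeatLoop(ctx context.Context, interval time.Duration, send func(context.Context, string) error) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			// ctx is already cancelled here; a context derived from it would
			// fail immediately, so the goodbye gets its own bounded context.
			drainCtx, cancel := context.WithTimeout(context.Background(), drainTimeout)
			defer cancel()
			if err := send(drainCtx, "DRAINING"); err != nil {
				log.Printf("failed to send draining heartbeat: %v", err)
			}
			return
		case <-ticker.C:
			if err := send(ctx, "ACTIVE"); err != nil {
				log.Printf("heartbeat failed: %v", err)
			}
		}
	}
}
```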