From f2e581e2d78cb4bee70dda1723af9df1bfe1cb91 Mon Sep 17 00:00:00 2001
From: Zhen Ye
Date: Thu, 9 Jan 2025 16:46:56 +0800
Subject: [PATCH] enhance: add broadcast for streaming service (#39047)

issue: #38399
pr: #39020

- Add a new RPC to transfer broadcast messages to the streaming coord
- Add a broadcast service at the streaming coord so that broadcast messages are sent atomically

Also cherry-picks PR #38400

---------

Signed-off-by: chyezh
---
 internal/.mockery.yaml | 12 +-
 internal/distributed/streaming/append.go | 17 +-
 internal/distributed/streaming/streaming.go | 4 +
 .../distributed/streaming/streaming_test.go | 38 ++--
 internal/distributed/streaming/wal.go | 26 ++-
 internal/distributed/streaming/wal_test.go | 54 ++++-
 internal/distributed/streamingnode/service.go | 89 +++++---
 internal/metastore/catalog.go | 9 +
 .../metastore/kv/streamingcoord/constant.go | 5 +-
 .../metastore/kv/streamingcoord/kv_catalog.go | 47 +++-
 .../kv/streamingcoord/kv_catalog_test.go | 59 ++++-
 .../mock_streaming/mock_WALAccesser.go | 59 +++++
 .../mock_StreamingCoordCataLog.go | 105 +++++++++
 .../mock_client/mock_BroadcastService.go | 98 +++++++++
 .../streamingcoord/mock_client/mock_Client.go | 47 ++++
 .../mock_broadcaster/mock_AppendOperator.go | 100 +++++++++
 .../client/broadcast/broadcast_impl.go | 56 +++++
 internal/streamingcoord/client/client.go | 13 ++
 internal/streamingcoord/client/client_impl.go | 6 +
 .../server/broadcaster/append_operator.go | 14 ++
 .../server/broadcaster/broadcaster.go | 24 ++
 .../server/broadcaster/broadcaster_impl.go | 207 ++++++++++++++++++
 .../server/broadcaster/broadcaster_test.go | 142 ++++++++++++
 .../streamingcoord/server/broadcaster/task.go | 126 +++++++++++
 internal/streamingcoord/server/builder.go | 4 +
 .../server/resource/resource.go | 9 +
 internal/streamingcoord/server/server.go | 33 ++-
 .../server/service/broadcast.go | 44 ++++
 internal/streamingnode/server/builder.go | 9 +-
 .../flusher/flusherimpl/channel_lifetime.go | 15 +-
 .../flusher/flusherimpl/flusher_impl.go | 93 +++++---
 .../flusher/flusherimpl/flusher_impl_test.go | 38 ++--
 .../flusher/flusherimpl/pipeline_params.go | 51 -----
 .../streamingnode/server/resource/resource.go | 15 +-
 .../server/resource/resource_test.go | 9 +-
 .../server/resource/test_utility.go | 8 +-
 .../handler/producer/produce_server.go | 16 +-
 .../server/wal/adaptor/wal_test.go | 15 +-
 .../segment/manager/partition_manager.go | 6 +-
 .../segment/manager/pchannel_manager.go | 6 +-
 .../segment/manager/pchannel_manager_test.go | 11 +-
 .../wal/interceptors/timetick/ack/ack_test.go | 10 +-
 .../interceptors/timetick/timetick_message.go | 2 +-
 .../server/walmanager/manager_impl_test.go | 10 +-
 .../server/walmanager/wal_lifetime_test.go | 10 +-
 .../resource => util}/idalloc/allocator.go | 5 +-
 .../idalloc/allocator_test.go | 12 +-
 .../idalloc/basic_allocator.go | 21 +-
 .../idalloc/basic_allocator_test.go | 31 ++-
 .../resource => util}/idalloc/mallocator.go | 0
 .../idalloc/test_mock_root_coord_client.go | 0
 pkg/streaming/proto/messages.proto | 5 +
 pkg/streaming/proto/streaming.proto | 38 +++-
 pkg/streaming/util/message/builder.go | 90 +++++++-
 pkg/streaming/util/message/message.go | 28 ++-
 pkg/streaming/util/message/message_impl.go | 59 ++++-
 pkg/streaming/util/message/properties.go | 1 +
 pkg/streaming/util/types/streaming_node.go | 23 ++
 pkg/util/contextutil/context_util.go | 12 +
 pkg/util/retry/options.go | 6 +
 pkg/util/retry/retry.go | 2 +-
 pkg/util/retry/retry_test.go | 11 +
 pkg/util/typeutil/backoff_timer.go | 46 ++++
 63 files changed, 1884 insertions(+), 277 deletions(-)
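Taken together, the builder and accessor APIs below give clients an end-to-end broadcast path. The following is a minimal sketch of that flow, assembled from the APIs this patch adds; it is not code from the patch itself, and the vchannel names and CollectionID are illustrative:

```go
package main

import (
	"context"
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus/internal/distributed/streaming"
	"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

func main() {
	// Build one broadcast message that targets several vchannels at once.
	msg, err := message.NewDropCollectionMessageBuilderV1().
		WithBroadcast([]string{"by-dev-rootcoord-dml_0", "by-dev-rootcoord-dml_1"}).
		WithHeader(&message.DropCollectionMessageHeader{}).
		WithBody(&msgpb.DropCollectionRequest{CollectionID: 1}).
		BuildBroadcast()
	if err != nil {
		panic(err)
	}
	// BroadcastAppend blocks until the streaming coord has persisted the
	// broadcast task and appended the message to every target vchannel.
	result, err := streaming.WAL().BroadcastAppend(context.Background(), msg)
	if err != nil {
		panic(err)
	}
	for vchannel, r := range result.AppendResults {
		fmt.Printf("%s -> msgID=%v, timetick=%d\n", vchannel, r.MessageID, r.TimeTick)
	}
}
```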
create mode 100644 internal/mocks/streamingcoord/mock_client/mock_BroadcastService.go create mode 100644 internal/mocks/streamingcoord/server/mock_broadcaster/mock_AppendOperator.go create mode 100644 internal/streamingcoord/client/broadcast/broadcast_impl.go create mode 100644 internal/streamingcoord/server/broadcaster/append_operator.go create mode 100644 internal/streamingcoord/server/broadcaster/broadcaster.go create mode 100644 internal/streamingcoord/server/broadcaster/broadcaster_impl.go create mode 100644 internal/streamingcoord/server/broadcaster/broadcaster_test.go create mode 100644 internal/streamingcoord/server/broadcaster/task.go create mode 100644 internal/streamingcoord/server/service/broadcast.go delete mode 100644 internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go rename internal/{streamingnode/server/resource => util}/idalloc/allocator.go (94%) rename internal/{streamingnode/server/resource => util}/idalloc/allocator_test.go (81%) rename internal/{streamingnode/server/resource => util}/idalloc/basic_allocator.go (83%) rename internal/{streamingnode/server/resource => util}/idalloc/basic_allocator_test.go (84%) rename internal/{streamingnode/server/resource => util}/idalloc/mallocator.go (100%) rename internal/{streamingnode/server/resource => util}/idalloc/test_mock_root_coord_client.go (100%) diff --git a/internal/.mockery.yaml b/internal/.mockery.yaml index 2179959e87d3a..8804105c23322 100644 --- a/internal/.mockery.yaml +++ b/internal/.mockery.yaml @@ -12,12 +12,16 @@ packages: github.com/milvus-io/milvus/internal/streamingcoord/server/balancer: interfaces: Balancer: - github.com/milvus-io/milvus/internal/streamingnode/client/manager: + github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster: interfaces: - ManagerClient: + AppendOperator: github.com/milvus-io/milvus/internal/streamingcoord/client: interfaces: Client: + BroadcastService: + github.com/milvus-io/milvus/internal/streamingnode/client/manager: + interfaces: + ManagerClient: github.com/milvus-io/milvus/internal/streamingnode/client/handler: interfaces: HandlerClient: @@ -46,10 +50,10 @@ packages: InterceptorWithReady: InterceptorBuilder: github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector: - interfaces: + interfaces: SealOperator: github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector: - interfaces: + interfaces: TimeTickSyncOperator: google.golang.org/grpc: interfaces: diff --git a/internal/distributed/streaming/append.go b/internal/distributed/streaming/append.go index 2fd0820e549b2..b4193d8b94446 100644 --- a/internal/distributed/streaming/append.go +++ b/internal/distributed/streaming/append.go @@ -17,6 +17,12 @@ func (w *walAccesserImpl) appendToWAL(ctx context.Context, msg message.MutableMe return p.Produce(ctx, msg) } +func (w *walAccesserImpl) broadcastToWAL(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + // The broadcast operation will be sent to the coordinator. + // The coordinator will dispatch the message to all the vchannels with an eventual consistency guarantee. + return w.streamingCoordClient.Broadcast().Broadcast(ctx, msg) +} + // createOrGetProducer creates or get a producer. // vchannel in same pchannel can share the same producer. 
func (w *walAccesserImpl) getProducer(pchannel string) *producer.ResumableProducer {
@@ -40,14 +46,19 @@ func assertValidMessage(msgs ...message.MutableMessage) {
 		if msg.MessageType().IsSystem() {
 			panic("system message is not allowed to append from client")
 		}
-	}
-	for _, msg := range msgs {
 		if msg.VChannel() == "" {
-			panic("vchannel is empty")
+			panic("we don't support sending a message to all vchannels from the client now")
 		}
 	}
 }
 
+// assertValidBroadcastMessage asserts that the message is not a system message.
+func assertValidBroadcastMessage(msg message.BroadcastMutableMessage) {
+	if msg.MessageType().IsSystem() {
+		panic("system message is not allowed to broadcast append from client")
+	}
+}
+
 // We only support delete and insert message for txn now.
 func assertIsDmlMessage(msgs ...message.MutableMessage) {
 	for _, msg := range msgs {
diff --git a/internal/distributed/streaming/streaming.go b/internal/distributed/streaming/streaming.go
index 8ef6df73619d0..8f45b66a3f518 100644
--- a/internal/distributed/streaming/streaming.go
+++ b/internal/distributed/streaming/streaming.go
@@ -85,6 +85,10 @@ type WALAccesser interface {
 	// RawAppend writes a records to the log.
 	RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error)
 
+	// BroadcastAppend sends a broadcast message to all target vchannels.
+	// BroadcastAppend guarantees the atomic write of the messages and eventual consistency.
+	BroadcastAppend(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)
+
 	// Read returns a scanner for reading records from the wal.
 	Read(ctx context.Context, opts ReadOption) Scanner
 
diff --git a/internal/distributed/streaming/streaming_test.go b/internal/distributed/streaming/streaming_test.go
index c24f65261636d..e44e18e7c2005 100644
--- a/internal/distributed/streaming/streaming_test.go
+++ b/internal/distributed/streaming/streaming_test.go
@@ -14,7 +14,10 @@ import (
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
-const vChannel = "by-dev-rootcoord-dml_4"
+var vChannels = []string{
+	"by-dev-rootcoord-dml_4",
+	"by-dev-rootcoord-dml_5",
+}
 
 func TestMain(m *testing.M) {
 	paramtable.Init()
@@ -33,10 +36,11 @@ func TestStreamingProduce(t *testing.T) {
 		WithBody(&msgpb.CreateCollectionRequest{
 			CollectionID: 1,
 		}).
-		WithVChannel(vChannel).
-		BuildMutable()
-	resp, err := streaming.WAL().RawAppend(context.Background(), msg)
-	fmt.Printf("%+v\t%+v\n", resp, err)
+		WithBroadcast(vChannels).
+		BuildBroadcast()
+
+	resp, err := streaming.WAL().BroadcastAppend(context.Background(), msg)
+	t.Logf("CreateCollection: %+v\t%+v\n", resp, err)
 
 	for i := 0; i < 500; i++ {
 		time.Sleep(time.Millisecond * 1)
@@ -47,17 +51,17 @@ func TestStreamingProduce(t *testing.T) {
 			WithBody(&msgpb.InsertRequest{
 				CollectionID: 1,
 			}).
-			WithVChannel(vChannel).
+			WithVChannel(vChannels[0]).
 			BuildMutable()
 		resp, err := streaming.WAL().RawAppend(context.Background(), msg)
-		fmt.Printf("%+v\t%+v\n", resp, err)
+		t.Logf("Insert: %+v\t%+v\n", resp, err)
 	}
 
 	for i := 0; i < 500; i++ {
 		time.Sleep(time.Millisecond * 1)
 		txn, err := streaming.WAL().Txn(context.Background(), streaming.TxnOption{
-			VChannel:  vChannel,
-			Keepalive: 100 * time.Millisecond,
+			VChannel:  vChannels[0],
+			Keepalive: 500 * time.Millisecond,
 		})
 		if err != nil {
 			t.Errorf("txn failed: %v", err)
@@ -71,7 +75,7 @@ func TestStreamingProduce(t *testing.T) {
 				WithBody(&msgpb.InsertRequest{
 					CollectionID: 1,
 				}).
-				WithVChannel(vChannel).
+				WithVChannel(vChannels[0]).
BuildMutable() err := txn.Append(context.Background(), msg) fmt.Printf("%+v\n", err) @@ -80,7 +84,7 @@ func TestStreamingProduce(t *testing.T) { if err != nil { t.Errorf("txn failed: %v", err) } - fmt.Printf("%+v\n", result) + t.Logf("txn commit: %+v\n", result) } msg, _ = message.NewDropCollectionMessageBuilderV1(). @@ -90,10 +94,10 @@ func TestStreamingProduce(t *testing.T) { WithBody(&msgpb.DropCollectionRequest{ CollectionID: 1, }). - WithVChannel(vChannel). - BuildMutable() - resp, err = streaming.WAL().RawAppend(context.Background(), msg) - fmt.Printf("%+v\t%+v\n", resp, err) + WithBroadcast(vChannels). + BuildBroadcast() + resp, err = streaming.WAL().BroadcastAppend(context.Background(), msg) + t.Logf("DropCollection: %+v\t%+v\n", resp, err) } func TestStreamingConsume(t *testing.T) { @@ -102,7 +106,7 @@ func TestStreamingConsume(t *testing.T) { defer streaming.Release() ch := make(message.ChanMessageHandler, 10) s := streaming.WAL().Read(context.Background(), streaming.ReadOption{ - VChannel: vChannel, + VChannel: vChannels[0], DeliverPolicy: options.DeliverPolicyAll(), MessageHandler: ch, }) @@ -115,7 +119,7 @@ func TestStreamingConsume(t *testing.T) { time.Sleep(10 * time.Millisecond) select { case msg := <-ch: - fmt.Printf("msgID=%+v, msgType=%+v, tt=%d, lca=%+v, body=%s, idx=%d\n", + t.Logf("msgID=%+v, msgType=%+v, tt=%d, lca=%+v, body=%s, idx=%d\n", msg.MessageID(), msg.MessageType(), msg.TimeTick(), diff --git a/internal/distributed/streaming/wal.go b/internal/distributed/streaming/wal.go index 8caba0186bc05..e61f3edba7144 100644 --- a/internal/distributed/streaming/wal.go +++ b/internal/distributed/streaming/wal.go @@ -28,11 +28,11 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl { // Create a new streamingnode handler client. handlerClient := handler.NewHandlerClient(streamingCoordClient.Assignment()) return &walAccesserImpl{ - lifetime: typeutil.NewLifetime(), - streamingCoordAssignmentClient: streamingCoordClient, - handlerClient: handlerClient, - producerMutex: sync.Mutex{}, - producers: make(map[string]*producer.ResumableProducer), + lifetime: typeutil.NewLifetime(), + streamingCoordClient: streamingCoordClient, + handlerClient: handlerClient, + producerMutex: sync.Mutex{}, + producers: make(map[string]*producer.ResumableProducer), // TODO: optimize the pool size, use the streaming api but not goroutines. appendExecutionPool: conc.NewPool[struct{}](10), @@ -45,8 +45,8 @@ type walAccesserImpl struct { lifetime *typeutil.Lifetime // All services - streamingCoordAssignmentClient client.Client - handlerClient handler.HandlerClient + streamingCoordClient client.Client + handlerClient handler.HandlerClient producerMutex sync.Mutex producers map[string]*producer.ResumableProducer @@ -66,6 +66,16 @@ func (w *walAccesserImpl) RawAppend(ctx context.Context, msg message.MutableMess return w.appendToWAL(ctx, msg) } +func (w *walAccesserImpl) BroadcastAppend(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + assertValidBroadcastMessage(msg) + if !w.lifetime.Add(typeutil.LifetimeStateWorking) { + return nil, ErrWALAccesserClosed + } + defer w.lifetime.Done() + + return w.broadcastToWAL(ctx, msg) +} + // Read returns a scanner for reading records from the wal. 
 func (w *walAccesserImpl) Read(_ context.Context, opts ReadOption) Scanner {
 	if !w.lifetime.Add(typeutil.LifetimeStateWorking) {
@@ -144,7 +154,7 @@ func (w *walAccesserImpl) Close() {
 	w.producerMutex.Unlock()
 
 	w.handlerClient.Close()
-	w.streamingCoordAssignmentClient.Close()
+	w.streamingCoordClient.Close()
 }
 
 // newErrScanner creates a scanner that returns an error.
diff --git a/internal/distributed/streaming/wal_test.go b/internal/distributed/streaming/wal_test.go
index db527c044eddb..a850b9cce3a07 100644
--- a/internal/distributed/streaming/wal_test.go
+++ b/internal/distributed/streaming/wal_test.go
@@ -30,19 +30,33 @@ const (
 func TestWAL(t *testing.T) {
 	coordClient := mock_client.NewMockClient(t)
 	coordClient.EXPECT().Close().Return()
+	broadcastService := mock_client.NewMockBroadcastService(t)
+	broadcastService.EXPECT().Broadcast(mock.Anything, mock.Anything).RunAndReturn(
+		func(ctx context.Context, bmm message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
+			result := make(map[string]*types.AppendResult)
+			for idx, msg := range bmm.SplitIntoMutableMessage() {
+				result[msg.VChannel()] = &types.AppendResult{
+					MessageID: walimplstest.NewTestMessageID(int64(idx)),
+					TimeTick:  uint64(time.Now().UnixMilli()),
+				}
+			}
+			return &types.BroadcastAppendResult{
+				AppendResults: result,
+			}, nil
+		})
+	coordClient.EXPECT().Broadcast().Return(broadcastService)
 	handler := mock_handler.NewMockHandlerClient(t)
 	handler.EXPECT().Close().Return()
 
 	w := &walAccesserImpl{
-		lifetime:                       typeutil.NewLifetime(),
-		streamingCoordAssignmentClient: coordClient,
-		handlerClient:                  handler,
-		producerMutex:                  sync.Mutex{},
-		producers:                      make(map[string]*producer.ResumableProducer),
-		appendExecutionPool:            conc.NewPool[struct{}](10),
-		dispatchExecutionPool:          conc.NewPool[struct{}](10),
+		lifetime:              typeutil.NewLifetime(),
+		streamingCoordClient:  coordClient,
+		handlerClient:         handler,
+		producerMutex:         sync.Mutex{},
+		producers:             make(map[string]*producer.ResumableProducer),
+		appendExecutionPool:   conc.NewPool[struct{}](10),
+		dispatchExecutionPool: conc.NewPool[struct{}](10),
 	}
-	defer w.Close()
 
 	ctx := context.Background()
 
@@ -114,6 +128,18 @@ func TestWAL(t *testing.T) {
 		newInsertMessage(vChannel3),
 	)
 	assert.NoError(t, resp.UnwrapFirstError())
+
+	r, err := w.BroadcastAppend(ctx, newBroadcastMessage([]string{vChannel1, vChannel2, vChannel3}))
+	assert.NoError(t, err)
+	assert.Len(t, r.AppendResults, 3)
+
+	w.Close()
+
+	resp = w.AppendMessages(ctx, newInsertMessage(vChannel1))
+	assert.Error(t, resp.UnwrapFirstError())
+	r, err = w.BroadcastAppend(ctx, newBroadcastMessage([]string{vChannel1, vChannel2, vChannel3}))
+	assert.Error(t, err)
+	assert.Nil(t, r)
 }
 
 func newInsertMessage(vChannel string) message.MutableMessage {
@@ -127,3 +153,15 @@ func newInsertMessage(vChannel string) message.MutableMessage {
 	}
 	return msg
 }
+
+func newBroadcastMessage(vchannels []string) message.BroadcastMutableMessage {
+	msg, err := message.NewDropCollectionMessageBuilderV1().
+		WithBroadcast(vchannels).
+		WithHeader(&message.DropCollectionMessageHeader{}).
+		WithBody(&msgpb.DropCollectionRequest{}).
+ BuildBroadcast() + if err != nil { + panic(err) + } + return msg +} diff --git a/internal/distributed/streamingnode/service.go b/internal/distributed/streamingnode/service.go index 59cbc3c9a26d7..2e50721ed2d38 100644 --- a/internal/distributed/streamingnode/service.go +++ b/internal/distributed/streamingnode/service.go @@ -55,6 +55,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/netutil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/tikv" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -83,8 +85,8 @@ type Server struct { // component client etcdCli *clientv3.Client tikvCli *txnkv.Client - rootCoord types.RootCoordClient - dataCoord types.DataCoordClient + rootCoord *syncutil.Future[types.RootCoordClient] + dataCoord *syncutil.Future[types.DataCoordClient] chunkManager storage.ChunkManager componentState *componentutil.ComponentStateService } @@ -95,6 +97,8 @@ func NewServer(ctx context.Context, f dependency.Factory) (*Server, error) { return &Server{ stopOnce: sync.Once{}, factory: f, + dataCoord: syncutil.NewFuture[types.DataCoordClient](), + rootCoord: syncutil.NewFuture[types.RootCoordClient](), grpcServerChan: make(chan struct{}), componentState: componentutil.NewComponentStateService(typeutil.StreamingNodeRole), ctx: ctx1, @@ -166,8 +170,17 @@ func (s *Server) stop() { // Stop rootCoord client. log.Info("streamingnode stop rootCoord client...") - if err := s.rootCoord.Close(); err != nil { - log.Warn("streamingnode stop rootCoord client failed", zap.Error(err)) + if s.rootCoord.Ready() { + if err := s.rootCoord.Get().Close(); err != nil { + log.Warn("streamingnode stop rootCoord client failed", zap.Error(err)) + } + } + + log.Info("streamingnode stop dataCoord client...") + if s.dataCoord.Ready() { + if err := s.dataCoord.Get().Close(); err != nil { + log.Warn("streamingnode stop dataCoord client failed", zap.Error(err)) + } } // Stop tikv @@ -216,12 +229,8 @@ func (s *Server) init() (err error) { if err := s.initSession(); err != nil { return err } - if err := s.initRootCoord(); err != nil { - return err - } - if err := s.initDataCoord(); err != nil { - return err - } + s.initRootCoord() + s.initDataCoord() s.initGRPCServer() // Create StreamingNode service. 
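The service change above swaps direct coordinator clients for syncutil futures, so the streaming node can start serving before RootCoord/DataCoord are reachable. A condensed sketch of the pattern used in the next hunk, with the health-wait step elided (`ctx` and `dcc` stand for the server's context and the DataCoord client package, as in the patch):

```go
// Condensed sketch of the future-based wiring: a background goroutine keeps
// retrying until it can resolve a healthy client, then fulfills the future.
dataCoord := syncutil.NewFuture[types.DataCoordClient]()

go func() {
	// retry.AttemptAlways() retries until ctx is cancelled.
	retry.Do(ctx, func() error {
		cli, err := dcc.NewClient(ctx)
		if err != nil {
			return err
		}
		dataCoord.Set(cli) // unblocks every pending dataCoord.Get()
		return nil
	}, retry.AttemptAlways())
}()

// Consumers block on Get() until the client is ready; shutdown only closes
// the client if the future was ever fulfilled.
if dataCoord.Ready() {
	_ = dataCoord.Get().Close()
}
```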
@@ -300,36 +309,48 @@ func (s *Server) initMeta() error {
 	return nil
 }
 
-func (s *Server) initRootCoord() (err error) {
+func (s *Server) initRootCoord() {
 	log := log.Ctx(s.ctx)
-	log.Info("StreamingNode connect to rootCoord...")
-	s.rootCoord, err = rcc.NewClient(s.ctx)
-	if err != nil {
-		return errors.Wrap(err, "StreamingNode try to new RootCoord client failed")
-	}
+	go func() {
+		retry.Do(s.ctx, func() error {
+			log.Info("StreamingNode connect to rootCoord...")
+			rootCoord, err := rcc.NewClient(s.ctx)
+			if err != nil {
+				return errors.Wrap(err, "StreamingNode try to new RootCoord client failed")
+			}
 
-	log.Info("StreamingNode try to wait for RootCoord ready")
-	err = componentutil.WaitForComponentHealthy(s.ctx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200)
-	if err != nil {
-		return errors.Wrap(err, "StreamingNode wait for RootCoord ready failed")
-	}
-	return nil
+			log.Info("StreamingNode try to wait for RootCoord ready")
+			err = componentutil.WaitForComponentHealthy(s.ctx, rootCoord, "RootCoord", 1000000, time.Millisecond*200)
+			if err != nil {
+				return errors.Wrap(err, "StreamingNode wait for RootCoord ready failed")
+			}
+			log.Info("StreamingNode wait for RootCoord done")
+			s.rootCoord.Set(rootCoord)
+			return nil
+		}, retry.AttemptAlways())
+	}()
 }
 
-func (s *Server) initDataCoord() (err error) {
+func (s *Server) initDataCoord() {
 	log := log.Ctx(s.ctx)
-	log.Info("StreamingNode connect to dataCoord...")
-	s.dataCoord, err = dcc.NewClient(s.ctx)
-	if err != nil {
-		return errors.Wrap(err, "StreamingNode try to new DataCoord client failed")
-	}
+	go func() {
+		retry.Do(s.ctx, func() error {
+			log.Info("StreamingNode connect to dataCoord...")
+			dataCoord, err := dcc.NewClient(s.ctx)
+			if err != nil {
+				return errors.Wrap(err, "StreamingNode try to new DataCoord client failed")
+			}
 
-	log.Info("StreamingNode try to wait for DataCoord ready")
-	err = componentutil.WaitForComponentHealthy(s.ctx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200)
-	if err != nil {
-		return errors.Wrap(err, "StreamingNode wait for DataCoord ready failed")
-	}
-	return nil
+			log.Info("StreamingNode try to wait for DataCoord ready")
+			err = componentutil.WaitForComponentHealthy(s.ctx, dataCoord, "DataCoord", 1000000, time.Millisecond*200)
+			if err != nil {
+				return errors.Wrap(err, "StreamingNode wait for DataCoord ready failed")
+			}
+			log.Info("StreamingNode wait for DataCoord done")
+			s.dataCoord.Set(dataCoord)
+			return nil
+		}, retry.AttemptAlways())
+	}()
 }
 
 func (s *Server) initChunkManager() (err error) {
diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go
index 090296d11bf1d..c7a2042dd73f0 100644
--- a/internal/metastore/catalog.go
+++ b/internal/metastore/catalog.go
@@ -210,6 +210,15 @@ type StreamingCoordCataLog interface {
 	// SavePChannel save a pchannel info to metastore.
 	SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error
+
+	// ListBroadcastTask lists all broadcast tasks.
+	// Used to recover the broadcast tasks after restart.
+	ListBroadcastTask(ctx context.Context) ([]*streamingpb.BroadcastTask, error)
+
+	// SaveBroadcastTask saves the broadcast task to the metastore,
+	// making the task recoverable after restart.
+	// When the broadcast task is done, it is removed from the metastore.
+	SaveBroadcastTask(ctx context.Context, task *streamingpb.BroadcastTask) error
 }
 
 // StreamingNodeCataLog is the interface for streamingnode catalog
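A short sketch of the contract these two catalog methods encode; `catalog` stands for any StreamingCoordCataLog implementation and `ctx` for a request context (both assumed, not from the patch):

```go
// Pending tasks are persisted so the broadcaster can replay them after a
// coord restart; marking a task DONE deletes its key instead of storing it.
task := &streamingpb.BroadcastTask{
	TaskId: 7,
	State:  streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
}
_ = catalog.SaveBroadcastTask(ctx, task) // saved; ListBroadcastTask returns it on recovery

task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
_ = catalog.SaveBroadcastTask(ctx, task) // key removed; nothing left to recover
```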
diff --git a/internal/metastore/kv/streamingcoord/constant.go b/internal/metastore/kv/streamingcoord/constant.go
index 5ae1f85b7d6bc..1f92dc9977da5 100644
--- a/internal/metastore/kv/streamingcoord/constant.go
+++ b/internal/metastore/kv/streamingcoord/constant.go
@@ -1,6 +1,7 @@
 package streamingcoord
 
 const (
-	MetaPrefix   = "streamingcoord-meta"
-	PChannelMeta = MetaPrefix + "/pchannel"
+	MetaPrefix          = "streamingcoord-meta/"
+	PChannelMetaPrefix  = MetaPrefix + "pchannel/"
+	BroadcastTaskPrefix = MetaPrefix + "broadcast-task/"
 )
diff --git a/internal/metastore/kv/streamingcoord/kv_catalog.go b/internal/metastore/kv/streamingcoord/kv_catalog.go
index d3d804052e026..c0a16a525106e 100644
--- a/internal/metastore/kv/streamingcoord/kv_catalog.go
+++ b/internal/metastore/kv/streamingcoord/kv_catalog.go
@@ -2,6 +2,7 @@ package streamingcoord
 
 import (
 	"context"
+	"strconv"
 
 	"github.com/cockroachdb/errors"
 	"google.golang.org/protobuf/proto"
@@ -14,6 +15,14 @@ import (
 )
 
 // NewCataLog creates a new catalog instance
+// streamingcoord-meta
+// ├── broadcast-task
+// │   ├── task-1
+// │   └── task-2
+// └── pchannel
+//     ├── pchannel-1
+//     └── pchannel-2
 func NewCataLog(metaKV kv.MetaKv) metastore.StreamingCoordCataLog {
 	return &catalog{
 		metaKV: metaKV,
@@ -27,7 +36,7 @@ type catalog struct {
 
 // ListPChannels returns all pchannels
 func (c *catalog) ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error) {
-	keys, values, err := c.metaKV.LoadWithPrefix(ctx, PChannelMeta)
+	keys, values, err := c.metaKV.LoadWithPrefix(ctx, PChannelMetaPrefix)
 	if err != nil {
 		return nil, err
 	}
@@ -60,7 +69,41 @@ func (c *catalog) SavePChannels(ctx context.Context, infos []*streamingpb.PChann
 	})
 }
 
+func (c *catalog) ListBroadcastTask(ctx context.Context) ([]*streamingpb.BroadcastTask, error) {
+	keys, values, err := c.metaKV.LoadWithPrefix(ctx, BroadcastTaskPrefix)
+	if err != nil {
+		return nil, err
+	}
+	infos := make([]*streamingpb.BroadcastTask, 0, len(values))
+	for k, value := range values {
+		info := &streamingpb.BroadcastTask{}
+		err = proto.Unmarshal([]byte(value), info)
+		if err != nil {
+			return nil, errors.Wrapf(err, "unmarshal broadcast task %s failed", keys[k])
+		}
+		infos = append(infos, info)
+	}
+	return infos, nil
+}
+
+func (c *catalog) SaveBroadcastTask(ctx context.Context, task *streamingpb.BroadcastTask) error {
+	key := buildBroadcastTaskPath(task.TaskId)
+	if task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE {
+		return c.metaKV.Remove(ctx, key)
+	}
+	v, err := proto.Marshal(task)
+	if err != nil {
+		return errors.Wrapf(err, "marshal broadcast task failed")
+	}
+	return c.metaKV.Save(ctx, key, string(v))
+}
+
 // buildPChannelInfoPath builds the path for pchannel info.
 func buildPChannelInfoPath(name string) string {
-	return PChannelMeta + "/" + name
+	return PChannelMetaPrefix + name
+}
+
+// buildBroadcastTaskPath builds the path for broadcast task.
+func buildBroadcastTaskPath(id int64) string { + return BroadcastTaskPrefix + strconv.FormatInt(id, 10) } diff --git a/internal/metastore/kv/streamingcoord/kv_catalog_test.go b/internal/metastore/kv/streamingcoord/kv_catalog_test.go index 227ad0469bca3..215aee3d15ee3 100644 --- a/internal/metastore/kv/streamingcoord/kv_catalog_test.go +++ b/internal/metastore/kv/streamingcoord/kv_catalog_test.go @@ -2,6 +2,7 @@ package streamingcoord import ( "context" + "strings" "testing" "github.com/cockroachdb/errors" @@ -20,8 +21,10 @@ func TestCatalog(t *testing.T) { keys := make([]string, 0, len(kvStorage)) vals := make([]string, 0, len(kvStorage)) for k, v := range kvStorage { - keys = append(keys, k) - vals = append(vals, v) + if strings.HasPrefix(k, s) { + keys = append(keys, k) + vals = append(vals, v) + } } return keys, vals, nil }) @@ -31,12 +34,21 @@ func TestCatalog(t *testing.T) { } return nil }) + kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, key, value string) error { + kvStorage[key] = value + return nil + }) + kv.EXPECT().Remove(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, key string) error { + delete(kvStorage, key) + return nil + }) catalog := NewCataLog(kv) metas, err := catalog.ListPChannel(context.Background()) assert.NoError(t, err) assert.Empty(t, metas) + // PChannel test err = catalog.SavePChannels(context.Background(), []*streamingpb.PChannelMeta{ { Channel: &streamingpb.PChannelInfo{Name: "test", Term: 1}, @@ -53,6 +65,37 @@ func TestCatalog(t *testing.T) { assert.NoError(t, err) assert.Len(t, metas, 2) + // BroadcastTask test + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{ + TaskId: 1, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, + }) + assert.NoError(t, err) + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{ + TaskId: 2, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, + }) + assert.NoError(t, err) + + tasks, err := catalog.ListBroadcastTask(context.Background()) + assert.NoError(t, err) + assert.Len(t, tasks, 2) + for _, task := range tasks { + assert.Equal(t, streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, task.State) + } + + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{ + TaskId: 1, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE, + }) + assert.NoError(t, err) + tasks, err = catalog.ListBroadcastTask(context.Background()) + assert.NoError(t, err) + assert.Len(t, tasks, 1) + for _, task := range tasks { + assert.Equal(t, streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, task.State) + } + // error path. 
kv.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Unset() kv.EXPECT().LoadWithPrefix(mock.Anything, mock.Anything).Return(nil, nil, errors.New("load error")) @@ -60,7 +103,19 @@ func TestCatalog(t *testing.T) { assert.Error(t, err) assert.Nil(t, metas) + tasks, err = catalog.ListBroadcastTask(context.Background()) + assert.Error(t, err) + assert.Nil(t, tasks) + kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Unset() kv.EXPECT().MultiSave(mock.Anything, mock.Anything).Return(errors.New("save error")) + kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Unset() + kv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("save error")) + err = catalog.SavePChannels(context.Background(), []*streamingpb.PChannelMeta{{ + Channel: &streamingpb.PChannelInfo{Name: "test", Term: 1}, + Node: &streamingpb.StreamingNodeInfo{ServerId: 1}, + }}) + assert.Error(t, err) + err = catalog.SaveBroadcastTask(context.Background(), &streamingpb.BroadcastTask{}) assert.Error(t, err) } diff --git a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go index 6b391ff89a062..43cf731783f38 100644 --- a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go +++ b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go @@ -149,6 +149,65 @@ func (_c *MockWALAccesser_AppendMessagesWithOption_Call) RunAndReturn(run func(c return _c } +// BroadcastAppend provides a mock function with given fields: ctx, msg +func (_m *MockWALAccesser) BroadcastAppend(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + ret := _m.Called(ctx, msg) + + if len(ret) == 0 { + panic("no return value specified for BroadcastAppend") + } + + var r0 *types.BroadcastAppendResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok { + r0 = rf(ctx, msg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.BroadcastAppendResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockWALAccesser_BroadcastAppend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BroadcastAppend' +type MockWALAccesser_BroadcastAppend_Call struct { + *mock.Call +} + +// BroadcastAppend is a helper method to define mock.On call +// - ctx context.Context +// - msg message.BroadcastMutableMessage +func (_e *MockWALAccesser_Expecter) BroadcastAppend(ctx interface{}, msg interface{}) *MockWALAccesser_BroadcastAppend_Call { + return &MockWALAccesser_BroadcastAppend_Call{Call: _e.mock.On("BroadcastAppend", ctx, msg)} +} + +func (_c *MockWALAccesser_BroadcastAppend_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockWALAccesser_BroadcastAppend_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage)) + }) + return _c +} + +func (_c *MockWALAccesser_BroadcastAppend_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockWALAccesser_BroadcastAppend_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockWALAccesser_BroadcastAppend_Call) RunAndReturn(run func(context.Context, 
message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockWALAccesser_BroadcastAppend_Call { + _c.Call.Return(run) + return _c +} + // RawAppend provides a mock function with given fields: ctx, msgs, opts func (_m *MockWALAccesser) RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption) (*types.AppendResult, error) { _va := make([]interface{}, len(opts)) diff --git a/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go index b0bc3b77756d0..651554d48b3f3 100644 --- a/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go +++ b/internal/mocks/mock_metastore/mock_StreamingCoordCataLog.go @@ -23,6 +23,64 @@ func (_m *MockStreamingCoordCataLog) EXPECT() *MockStreamingCoordCataLog_Expecte return &MockStreamingCoordCataLog_Expecter{mock: &_m.Mock} } +// ListBroadcastTask provides a mock function with given fields: ctx +func (_m *MockStreamingCoordCataLog) ListBroadcastTask(ctx context.Context) ([]*streamingpb.BroadcastTask, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListBroadcastTask") + } + + var r0 []*streamingpb.BroadcastTask + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*streamingpb.BroadcastTask, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*streamingpb.BroadcastTask); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*streamingpb.BroadcastTask) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamingCoordCataLog_ListBroadcastTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListBroadcastTask' +type MockStreamingCoordCataLog_ListBroadcastTask_Call struct { + *mock.Call +} + +// ListBroadcastTask is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockStreamingCoordCataLog_Expecter) ListBroadcastTask(ctx interface{}) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + return &MockStreamingCoordCataLog_ListBroadcastTask_Call{Call: _e.mock.On("ListBroadcastTask", ctx)} +} + +func (_c *MockStreamingCoordCataLog_ListBroadcastTask_Call) Run(run func(ctx context.Context)) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockStreamingCoordCataLog_ListBroadcastTask_Call) Return(_a0 []*streamingpb.BroadcastTask, _a1 error) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamingCoordCataLog_ListBroadcastTask_Call) RunAndReturn(run func(context.Context) ([]*streamingpb.BroadcastTask, error)) *MockStreamingCoordCataLog_ListBroadcastTask_Call { + _c.Call.Return(run) + return _c +} + // ListPChannel provides a mock function with given fields: ctx func (_m *MockStreamingCoordCataLog) ListPChannel(ctx context.Context) ([]*streamingpb.PChannelMeta, error) { ret := _m.Called(ctx) @@ -81,6 +139,53 @@ func (_c *MockStreamingCoordCataLog_ListPChannel_Call) RunAndReturn(run func(con return _c } +// SaveBroadcastTask provides a mock function with given fields: ctx, task +func (_m *MockStreamingCoordCataLog) SaveBroadcastTask(ctx context.Context, task *streamingpb.BroadcastTask) error { + ret := _m.Called(ctx, task) + + if len(ret) == 0 { + panic("no return value specified for SaveBroadcastTask") + } 
+ + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *streamingpb.BroadcastTask) error); ok { + r0 = rf(ctx, task) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingCoordCataLog_SaveBroadcastTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveBroadcastTask' +type MockStreamingCoordCataLog_SaveBroadcastTask_Call struct { + *mock.Call +} + +// SaveBroadcastTask is a helper method to define mock.On call +// - ctx context.Context +// - task *streamingpb.BroadcastTask +func (_e *MockStreamingCoordCataLog_Expecter) SaveBroadcastTask(ctx interface{}, task interface{}) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + return &MockStreamingCoordCataLog_SaveBroadcastTask_Call{Call: _e.mock.On("SaveBroadcastTask", ctx, task)} +} + +func (_c *MockStreamingCoordCataLog_SaveBroadcastTask_Call) Run(run func(ctx context.Context, task *streamingpb.BroadcastTask)) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*streamingpb.BroadcastTask)) + }) + return _c +} + +func (_c *MockStreamingCoordCataLog_SaveBroadcastTask_Call) Return(_a0 error) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingCoordCataLog_SaveBroadcastTask_Call) RunAndReturn(run func(context.Context, *streamingpb.BroadcastTask) error) *MockStreamingCoordCataLog_SaveBroadcastTask_Call { + _c.Call.Return(run) + return _c +} + // SavePChannels provides a mock function with given fields: ctx, info func (_m *MockStreamingCoordCataLog) SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error { ret := _m.Called(ctx, info) diff --git a/internal/mocks/streamingcoord/mock_client/mock_BroadcastService.go b/internal/mocks/streamingcoord/mock_client/mock_BroadcastService.go new file mode 100644 index 0000000000000..3c84e0cce1f5d --- /dev/null +++ b/internal/mocks/streamingcoord/mock_client/mock_BroadcastService.go @@ -0,0 +1,98 @@ +// Code generated by mockery v2.46.0. DO NOT EDIT. 
+ +package mock_client + +import ( + context "context" + + message "github.com/milvus-io/milvus/pkg/streaming/util/message" + mock "github.com/stretchr/testify/mock" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// MockBroadcastService is an autogenerated mock type for the BroadcastService type +type MockBroadcastService struct { + mock.Mock +} + +type MockBroadcastService_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBroadcastService) EXPECT() *MockBroadcastService_Expecter { + return &MockBroadcastService_Expecter{mock: &_m.Mock} +} + +// Broadcast provides a mock function with given fields: ctx, msg +func (_m *MockBroadcastService) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + ret := _m.Called(ctx, msg) + + if len(ret) == 0 { + panic("no return value specified for Broadcast") + } + + var r0 *types.BroadcastAppendResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastMutableMessage) *types.BroadcastAppendResult); ok { + r0 = rf(ctx, msg) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.BroadcastAppendResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, message.BroadcastMutableMessage) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroadcastService_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast' +type MockBroadcastService_Broadcast_Call struct { + *mock.Call +} + +// Broadcast is a helper method to define mock.On call +// - ctx context.Context +// - msg message.BroadcastMutableMessage +func (_e *MockBroadcastService_Expecter) Broadcast(ctx interface{}, msg interface{}) *MockBroadcastService_Broadcast_Call { + return &MockBroadcastService_Broadcast_Call{Call: _e.mock.On("Broadcast", ctx, msg)} +} + +func (_c *MockBroadcastService_Broadcast_Call) Run(run func(ctx context.Context, msg message.BroadcastMutableMessage)) *MockBroadcastService_Broadcast_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(message.BroadcastMutableMessage)) + }) + return _c +} + +func (_c *MockBroadcastService_Broadcast_Call) Return(_a0 *types.BroadcastAppendResult, _a1 error) *MockBroadcastService_Broadcast_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroadcastService_Broadcast_Call) RunAndReturn(run func(context.Context, message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)) *MockBroadcastService_Broadcast_Call { + _c.Call.Return(run) + return _c +} + +// NewMockBroadcastService creates a new instance of MockBroadcastService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockBroadcastService(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBroadcastService { + mock := &MockBroadcastService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingcoord/mock_client/mock_Client.go b/internal/mocks/streamingcoord/mock_client/mock_Client.go index 02923644d235c..574e08d01533e 100644 --- a/internal/mocks/streamingcoord/mock_client/mock_Client.go +++ b/internal/mocks/streamingcoord/mock_client/mock_Client.go @@ -67,6 +67,53 @@ func (_c *MockClient_Assignment_Call) RunAndReturn(run func() client.AssignmentS return _c } +// Broadcast provides a mock function with given fields: +func (_m *MockClient) Broadcast() client.BroadcastService { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Broadcast") + } + + var r0 client.BroadcastService + if rf, ok := ret.Get(0).(func() client.BroadcastService); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.BroadcastService) + } + } + + return r0 +} + +// MockClient_Broadcast_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Broadcast' +type MockClient_Broadcast_Call struct { + *mock.Call +} + +// Broadcast is a helper method to define mock.On call +func (_e *MockClient_Expecter) Broadcast() *MockClient_Broadcast_Call { + return &MockClient_Broadcast_Call{Call: _e.mock.On("Broadcast")} +} + +func (_c *MockClient_Broadcast_Call) Run(run func()) *MockClient_Broadcast_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Broadcast_Call) Return(_a0 client.BroadcastService) *MockClient_Broadcast_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_Broadcast_Call) RunAndReturn(run func() client.BroadcastService) *MockClient_Broadcast_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: func (_m *MockClient) Close() { _m.Called() diff --git a/internal/mocks/streamingcoord/server/mock_broadcaster/mock_AppendOperator.go b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_AppendOperator.go new file mode 100644 index 0000000000000..8f049c5616cf6 --- /dev/null +++ b/internal/mocks/streamingcoord/server/mock_broadcaster/mock_AppendOperator.go @@ -0,0 +1,100 @@ +// Code generated by mockery v2.46.0. DO NOT EDIT. + +package mock_broadcaster + +import ( + context "context" + + message "github.com/milvus-io/milvus/pkg/streaming/util/message" + mock "github.com/stretchr/testify/mock" + + streaming "github.com/milvus-io/milvus/internal/distributed/streaming" +) + +// MockAppendOperator is an autogenerated mock type for the AppendOperator type +type MockAppendOperator struct { + mock.Mock +} + +type MockAppendOperator_Expecter struct { + mock *mock.Mock +} + +func (_m *MockAppendOperator) EXPECT() *MockAppendOperator_Expecter { + return &MockAppendOperator_Expecter{mock: &_m.Mock} +} + +// AppendMessages provides a mock function with given fields: ctx, msgs +func (_m *MockAppendOperator) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { + _va := make([]interface{}, len(msgs)) + for _i := range msgs { + _va[_i] = msgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + if len(ret) == 0 { + panic("no return value specified for AppendMessages") + } + + var r0 streaming.AppendResponses + if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) streaming.AppendResponses); ok { + r0 = rf(ctx, msgs...) + } else { + r0 = ret.Get(0).(streaming.AppendResponses) + } + + return r0 +} + +// MockAppendOperator_AppendMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessages' +type MockAppendOperator_AppendMessages_Call struct { + *mock.Call +} + +// AppendMessages is a helper method to define mock.On call +// - ctx context.Context +// - msgs ...message.MutableMessage +func (_e *MockAppendOperator_Expecter) AppendMessages(ctx interface{}, msgs ...interface{}) *MockAppendOperator_AppendMessages_Call { + return &MockAppendOperator_AppendMessages_Call{Call: _e.mock.On("AppendMessages", + append([]interface{}{ctx}, msgs...)...)} +} + +func (_c *MockAppendOperator_AppendMessages_Call) Run(run func(ctx context.Context, msgs ...message.MutableMessage)) *MockAppendOperator_AppendMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]message.MutableMessage, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(message.MutableMessage) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockAppendOperator_AppendMessages_Call) Return(_a0 streaming.AppendResponses) *MockAppendOperator_AppendMessages_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockAppendOperator_AppendMessages_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) streaming.AppendResponses) *MockAppendOperator_AppendMessages_Call { + _c.Call.Return(run) + return _c +} + +// NewMockAppendOperator creates a new instance of MockAppendOperator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockAppendOperator(t interface { + mock.TestingT + Cleanup(func()) +}) *MockAppendOperator { + mock := &MockAppendOperator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/streamingcoord/client/broadcast/broadcast_impl.go b/internal/streamingcoord/client/broadcast/broadcast_impl.go new file mode 100644 index 0000000000000..b6296748d1eba --- /dev/null +++ b/internal/streamingcoord/client/broadcast/broadcast_impl.go @@ -0,0 +1,56 @@ +package broadcast + +import ( + "context" + + "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// NewBroadcastService creates a new broadcast service. +func NewBroadcastService(walName string, service lazygrpc.Service[streamingpb.StreamingCoordBroadcastServiceClient]) *BroadcastServiceImpl { + return &BroadcastServiceImpl{ + walName: walName, + service: service, + } +} + +// BroadcastServiceImpl is the implementation of BroadcastService. +type BroadcastServiceImpl struct { + walName string + service lazygrpc.Service[streamingpb.StreamingCoordBroadcastServiceClient] +} + +// Broadcast sends a broadcast message to the streaming coord to perform a broadcast. 
+func (c *BroadcastServiceImpl) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { + client, err := c.service.GetService(ctx) + if err != nil { + return nil, err + } + resp, err := client.Broadcast(ctx, &streamingpb.BroadcastRequest{ + Message: &messagespb.Message{ + Payload: msg.Payload(), + Properties: msg.Properties().ToRawMap(), + }, + }) + if err != nil { + return nil, err + } + results := make(map[string]*types.AppendResult, len(resp.Results)) + for channel, result := range resp.Results { + msgID, err := message.UnmarshalMessageID(c.walName, result.Id.Id) + if err != nil { + return nil, err + } + results[channel] = &types.AppendResult{ + MessageID: msgID, + TimeTick: result.GetTimetick(), + TxnCtx: message.NewTxnContextFromProto(result.GetTxnContext()), + Extra: result.GetExtra(), + } + } + return &types.BroadcastAppendResult{AppendResults: results}, nil +} diff --git a/internal/streamingcoord/client/client.go b/internal/streamingcoord/client/client.go index 83a55fd107159..07f0937360bfd 100644 --- a/internal/streamingcoord/client/client.go +++ b/internal/streamingcoord/client/client.go @@ -11,12 +11,15 @@ import ( "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/streamingcoord/client/assignment" + "github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker" streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor" "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" "github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver" + "github.com/milvus-io/milvus/internal/util/streamingutil/util" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/interceptor" @@ -32,8 +35,16 @@ type AssignmentService interface { types.AssignmentDiscoverWatcher } +// BroadcastService is the interface of broadcast service. +type BroadcastService interface { + // Broadcast sends a broadcast message to the streaming service. + Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) +} + // Client is the interface of log service client. type Client interface { + Broadcast() BroadcastService + // Assignment access assignment service. 
 	Assignment() AssignmentService
@@ -58,10 +69,12 @@ func NewClient(etcdCli *clientv3.Client) Client {
 		)
 	})
 	assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient)
+	broadcastService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordBroadcastServiceClient)
 	return &clientImpl{
 		conn:              conn,
 		rb:                rb,
 		assignmentService: assignment.NewAssignmentService(assignmentService),
+		broadcastService:  broadcast.NewBroadcastService(util.MustSelectWALName(), broadcastService),
 	}
 }
diff --git a/internal/streamingcoord/client/client_impl.go b/internal/streamingcoord/client/client_impl.go
index ffb0b0355a3a5..88c94794e1c4d 100644
--- a/internal/streamingcoord/client/client_impl.go
+++ b/internal/streamingcoord/client/client_impl.go
@@ -2,6 +2,7 @@ package client
 
 import (
 	"github.com/milvus-io/milvus/internal/streamingcoord/client/assignment"
+	"github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
 )
@@ -11,6 +12,11 @@ type clientImpl struct {
 	conn              lazygrpc.Conn
 	rb                resolver.Builder
 	assignmentService *assignment.AssignmentServiceImpl
+	broadcastService  *broadcast.BroadcastServiceImpl
+}
+
+func (c *clientImpl) Broadcast() BroadcastService {
+	return c.broadcastService
 }
 
 // Assignment access assignment service.
diff --git a/internal/streamingcoord/server/broadcaster/append_operator.go b/internal/streamingcoord/server/broadcaster/append_operator.go
new file mode 100644
index 0000000000000..ec849ea2be917
--- /dev/null
+++ b/internal/streamingcoord/server/broadcaster/append_operator.go
@@ -0,0 +1,14 @@
+package broadcaster
+
+import (
+	"github.com/milvus-io/milvus/internal/distributed/streaming"
+	"github.com/milvus-io/milvus/internal/util/streamingutil"
+)
+
+// NewAppendOperator creates an append operator to handle the incoming messages for broadcaster.
+func NewAppendOperator() AppendOperator {
+	if streamingutil.IsStreamingServiceEnabled() {
+		return streaming.WAL()
+	}
+	return nil
+}
diff --git a/internal/streamingcoord/server/broadcaster/broadcaster.go b/internal/streamingcoord/server/broadcaster/broadcaster.go
new file mode 100644
index 0000000000000..79e77bb8829cf
--- /dev/null
+++ b/internal/streamingcoord/server/broadcaster/broadcaster.go
@@ -0,0 +1,24 @@
+package broadcaster
+
+import (
+	"context"
+
+	"github.com/milvus-io/milvus/internal/distributed/streaming"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/util/types"
+)
+
+type Broadcaster interface {
+	// Broadcast broadcasts the message to all channels.
+	Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error)
+
+	// Close closes the broadcaster.
+	Close()
+}
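Keeping AppendOperator (defined just below) as a one-method seam means the broadcaster never depends on a concrete WAL. A sketch of the two wirings this enables, production and mocked test; the variable names are illustrative and error handling is elided:

```go
// Production: RecoverBroadcaster receives the real WAL-backed operator.
op := broadcaster.NewAppendOperator() // streaming.WAL() when the streaming service is enabled
bc, _ := broadcaster.RecoverBroadcaster(ctx, op)

// Tests: substitute a generated mock so no real WAL is required
// (this is what broadcaster_test.go does further down).
mockOp := mock_broadcaster.NewMockAppendOperator(t)
bcUnderTest, _ := broadcaster.RecoverBroadcaster(ctx, mockOp)
_, _ = bc, bcUnderTest
```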
+
+// AppendOperator is used to append messages. There are only two implementations of this interface:
+// 1. streaming.WAL()
+// 2. old msgstream interface
+type AppendOperator interface {
+	AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses
+}
diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_impl.go b/internal/streamingcoord/server/broadcaster/broadcaster_impl.go
new file mode 100644
index 0000000000000..2da0e0679f907
--- /dev/null
+++ b/internal/streamingcoord/server/broadcaster/broadcaster_impl.go
@@ -0,0 +1,207 @@
+package broadcaster
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
+	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/util/contextutil"
+	"github.com/milvus-io/milvus/pkg/util/syncutil"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+func RecoverBroadcaster(
+	ctx context.Context,
+	appendOperator AppendOperator,
+) (Broadcaster, error) {
+	logger := resource.Resource().Logger().With(log.FieldComponent("broadcaster"))
+	tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx)
+	if err != nil {
+		return nil, err
+	}
+	pendings := make([]*broadcastTask, 0, len(tasks))
+	for _, task := range tasks {
+		if task.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING {
+			// recover the pending tasks
+			t := newTask(task, logger)
+			pendings = append(pendings, t)
+		}
+	}
+	b := &broadcasterImpl{
+		logger:                 logger,
+		lifetime:               typeutil.NewLifetime(),
+		backgroundTaskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](),
+		pendings:               pendings,
+		backoffs:               typeutil.NewHeap[*broadcastTask](&broadcastTaskArray{}),
+		backoffChan:            make(chan *broadcastTask),
+		pendingChan:            make(chan *broadcastTask),
+		workerChan:             make(chan *broadcastTask),
+		appendOperator:         appendOperator,
+	}
+	go b.execute()
+	return b, nil
+}
+
+// broadcasterImpl is the implementation of Broadcaster
+type broadcasterImpl struct {
+	logger                 *log.MLogger
+	lifetime               *typeutil.Lifetime
+	backgroundTaskNotifier *syncutil.AsyncTaskNotifier[struct{}]
+	pendings               []*broadcastTask
+	backoffs               typeutil.Heap[*broadcastTask]
+	pendingChan            chan *broadcastTask
+	backoffChan            chan *broadcastTask
+	workerChan             chan *broadcastTask
+	appendOperator         AppendOperator
+}
+
+// Broadcast broadcasts the message to all channels.
+func (b *broadcasterImpl) Broadcast(ctx context.Context, msg message.BroadcastMutableMessage) (result *types.BroadcastAppendResult, err error) {
+	if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
+		return nil, status.NewOnShutdownError("broadcaster is closing")
+	}
+	defer func() {
+		if err != nil {
+			b.logger.Warn("broadcast message failed", zap.Error(err))
+			return
+		}
+	}()
+
+	// Once the task is persisted, it is guaranteed to eventually succeed.
+	task, err := b.persistBroadcastTask(ctx, msg)
+	if err != nil {
+		return nil, err
+	}
+	t := newTask(task, b.logger)
+	select {
+	case <-b.backgroundTaskNotifier.Context().Done():
+		// We can only check the background context but not the request context here.
+		// Because every new incoming task must be delivered to the background task queue;
+		// if it cannot be, the broadcaster is closing.
+		return nil, status.NewOnShutdownError("broadcaster is closing")
+	case b.pendingChan <- t:
+	}
+
+	// Wait on both the request context and the background task context.
+	ctx, _ = contextutil.MergeContext(ctx, b.backgroundTaskNotifier.Context())
+	return t.BlockUntilTaskDone(ctx)
+}
+
+// persistBroadcastTask persists the broadcast task into catalog.
+func (b *broadcasterImpl) persistBroadcastTask(ctx context.Context, msg message.BroadcastMutableMessage) (*streamingpb.BroadcastTask, error) {
+	defer b.lifetime.Done()
+
+	id, err := resource.Resource().IDAllocator().Allocate(ctx)
+	if err != nil {
+		return nil, status.NewInner("allocate new id failed, %s", err.Error())
+	}
+	task := &streamingpb.BroadcastTask{
+		TaskId:  int64(id),
+		Message: &messagespb.Message{Payload: msg.Payload(), Properties: msg.Properties().ToRawMap()},
+		State:   streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
+	}
+	// Save the task into the catalog so it can be recovered after a restart.
+	if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, task); err != nil {
+		return nil, status.NewInner("save broadcast task failed, %s", err.Error())
+	}
+	return task, nil
+}
+
+func (b *broadcasterImpl) Close() {
+	b.lifetime.SetState(typeutil.LifetimeStateStopped)
+	b.lifetime.Wait()
+
+	b.backgroundTaskNotifier.Cancel()
+	b.backgroundTaskNotifier.BlockUntilFinish()
+}
+
+// execute the broadcaster
+func (b *broadcasterImpl) execute() {
+	b.logger.Info("broadcaster start to execute")
+	defer func() {
+		b.backgroundTaskNotifier.Finish(struct{}{})
+		b.logger.Info("broadcaster execute exit")
+	}()
+
+	// Start n workers to handle the broadcast tasks.
+	wg := sync.WaitGroup{}
+	for i := 0; i < 4; i++ {
+		i := i
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			b.worker(i)
+		}()
+	}
+	defer wg.Wait()
+
+	b.dispatch()
+}
+
+func (b *broadcasterImpl) dispatch() {
+	for {
+		var workerChan chan *broadcastTask
+		var nextTask *broadcastTask
+		var nextBackOff <-chan time.Time
+		// Wait for a new task.
+		if len(b.pendings) > 0 {
+			workerChan = b.workerChan
+			nextTask = b.pendings[0]
+		}
+		if b.backoffs.Len() > 0 {
+			var nextInterval time.Duration
+			nextBackOff, nextInterval = b.backoffs.Peek().NextTimer()
+			b.logger.Info("backoff task", zap.Duration("nextInterval", nextInterval))
+		}
+
+		select {
+		case <-b.backgroundTaskNotifier.Context().Done():
+			return
+		case task := <-b.pendingChan:
+			b.pendings = append(b.pendings, task)
+		case task := <-b.backoffChan:
+			// the task is backing off; push it into the backoff queue for a delayed retry.
+			b.backoffs.Push(task)
+		case <-nextBackOff:
+			// the backoff expired; move every task whose backoff has elapsed into pending to retry.
+			for b.backoffs.Len() > 0 && b.backoffs.Peek().NextInterval() < time.Millisecond {
+				b.pendings = append(b.pendings, b.backoffs.Pop())
+			}
+		case workerChan <- nextTask:
+			// The task was sent to a worker; remove it from the pending list.
+			b.pendings = b.pendings[1:]
+		}
+	}
+}
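dispatch leans on a standard Go idiom: a send on a nil channel blocks forever, so a select arm is effectively disabled until there is work for it. Reduced to its core (a sketch with shortened names, not code from the patch):

```go
// Reduced form of the select in dispatch(): each arm participates only when
// it can make progress, because nil channels are never ready.
var workerChan chan *broadcastTask // nil => the "hand task to worker" arm is disabled
var nextTask *broadcastTask
if len(pendings) > 0 {
	workerChan = realWorkerChan // enable the arm only while a task is pending
	nextTask = pendings[0]
}
select {
case task := <-pendingChan: // a new broadcast arrived
	pendings = append(pendings, task)
case workerChan <- nextTask: // fires only when workerChan was set above
	pendings = pendings[1:]
}
```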
+ select { + case <-b.backgroundTaskNotifier.Context().Done(): + return + case b.backoffChan <- task: + } + } + } + } +} diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_test.go b/internal/streamingcoord/server/broadcaster/broadcaster_test.go new file mode 100644 index 0000000000000..624535f1c8755 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/broadcaster_test.go @@ -0,0 +1,142 @@ +package broadcaster + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/atomic" + + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster" + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + internaltypes "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" + "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +func TestBroadcaster(t *testing.T) { + meta := mock_metastore.NewMockStreamingCoordCataLog(t) + meta.EXPECT().ListBroadcastTask(mock.Anything). + RunAndReturn(func(ctx context.Context) ([]*streamingpb.BroadcastTask, error) { + return []*streamingpb.BroadcastTask{ + createNewBroadcastTask(1, []string{"v1"}), + createNewBroadcastTask(2, []string{"v1", "v2"}), + createNewBroadcastTask(3, []string{"v1", "v2", "v3"}), + }, nil + }).Times(1) + done := atomic.NewInt64(0) + meta.EXPECT().SaveBroadcastTask(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, bt *streamingpb.BroadcastTask) error { + // may fail + if rand.Int31n(10) < 5 { + return errors.New("save task failed") + } + if bt.State == streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE { + done.Inc() + } + return nil + }) + rc := idalloc.NewMockRootCoordClient(t) + f := syncutil.NewFuture[internaltypes.RootCoordClient]() + f.Set(rc) + resource.InitForTest(resource.OptStreamingCatalog(meta), resource.OptRootCoordClient(f)) + + operator, appended := createOperator(t) + bc, err := RecoverBroadcaster(context.Background(), operator) + assert.NoError(t, err) + assert.NotNil(t, bc) + assert.Eventually(t, func() bool { + return appended.Load() == 6 && done.Load() == 3 + }, 10*time.Second, 10*time.Millisecond) + + var result *types.BroadcastAppendResult + for { + var err error + result, err = bc.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"})) + if err == nil { + break + } + } + assert.Equal(t, int(appended.Load()), 9) + assert.Equal(t, len(result.AppendResults), 3) + + assert.Eventually(t, func() bool { + return done.Load() == 4 + }, 10*time.Second, 10*time.Millisecond) + + // TODO: error path.
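The dispatch loop above gates the `workerChan <- nextTask` case with a classic Go idiom: sending on a nil channel blocks forever, so the case is only eligible while the pending queue is non-empty. A standalone sketch of the idiom (illustrative names, not Milvus code):

```go
package main

import "fmt"

func main() {
	incoming := make(chan string)
	workerChan := make(chan string)
	done := make(chan struct{})

	go func() { // a single worker for the sketch
		for task := range workerChan {
			fmt.Println("worker handled", task)
		}
		close(done)
	}()

	go func() { // producer feeding the dispatcher
		for _, t := range []string{"a", "b", "c"} {
			incoming <- t
		}
		close(incoming)
	}()

	var pendings []string
	for incoming != nil || len(pendings) > 0 {
		var send chan string // nil disables the send case below
		var head string
		if len(pendings) > 0 {
			send = workerChan
			head = pendings[0]
		}
		select {
		case t, ok := <-incoming:
			if !ok {
				incoming = nil // receive on nil also blocks: case disabled
				continue
			}
			pendings = append(pendings, t)
		case send <- head: // only selectable while send != nil
			pendings = pendings[1:]
		}
	}
	close(workerChan)
	<-done
}
```

The same trick drives the `nextBackOff` timer case: while the backoff heap is empty, `nextBackOff` stays nil and that case simply never fires.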
+ bc.Close() + + result, err = bc.Broadcast(context.Background(), createNewBroadcastMsg([]string{"v1", "v2", "v3"})) + assert.Error(t, err) + assert.Nil(t, result) +} + +func createOperator(t *testing.T) (AppendOperator, *atomic.Int64) { + id := atomic.NewInt64(1) + appended := atomic.NewInt64(0) + operator := mock_broadcaster.NewMockAppendOperator(t) + f := func(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { + resps := streaming.AppendResponses{ + Responses: make([]streaming.AppendResponse, len(msgs)), + } + for idx := range msgs { + newID := walimplstest.NewTestMessageID(id.Inc()) + if rand.Int31n(10) < 5 { + resps.Responses[idx] = streaming.AppendResponse{ + Error: errors.New("append failed"), + } + continue + } + resps.Responses[idx] = streaming.AppendResponse{ + AppendResult: &types.AppendResult{ + MessageID: newID, + TimeTick: uint64(time.Now().UnixMilli()), + }, + Error: nil, + } + appended.Inc() + } + return resps + } + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything).RunAndReturn(f) + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) + operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) + return operator, appended +} + +func createNewBroadcastMsg(vchannels []string) message.BroadcastMutableMessage { + msg, err := message.NewDropCollectionMessageBuilderV1(). + WithHeader(&messagespb.DropCollectionMessageHeader{}). + WithBody(&msgpb.DropCollectionRequest{}). + WithBroadcast(vchannels). + BuildBroadcast() + if err != nil { + panic(err) + } + return msg +} + +func createNewBroadcastTask(taskID int64, vchannels []string) *streamingpb.BroadcastTask { + msg := createNewBroadcastMsg(vchannels) + return &streamingpb.BroadcastTask{ + TaskId: taskID, + Message: &messagespb.Message{ + Payload: msg.Payload(), + Properties: msg.Properties().ToRawMap(), + }, + State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, + } +} diff --git a/internal/streamingcoord/server/broadcaster/task.go b/internal/streamingcoord/server/broadcaster/task.go new file mode 100644 index 0000000000000..52a2b0e77d0c6 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/task.go @@ -0,0 +1,126 @@ +package broadcaster + +import ( + "context" + "time" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var errBroadcastTaskIsNotDone = errors.New("broadcast task is not done") + +// newTask creates a new broadcast task. +func newTask(task *streamingpb.BroadcastTask, logger *log.MLogger) *broadcastTask { + bt := message.NewBroadcastMutableMessage(task.Message.Payload, task.Message.Properties) + msgs := bt.SplitIntoMutableMessage() + return &broadcastTask{ + logger: logger.With(zap.Int64("taskID", task.TaskId), zap.Int("broadcastTotal", len(msgs))), + task: task, + pendingMessages: msgs, + appendResult: make(map[string]*types.AppendResult, len(msgs)), + future: syncutil.NewFuture[*types.BroadcastAppendResult](), + BackoffWithInstant:
typeutil.NewBackoffWithInstant(typeutil.BackoffTimerConfig{ + Default: 10 * time.Second, + Backoff: typeutil.BackoffConfig{ + InitialInterval: 10 * time.Millisecond, + Multiplier: 2.0, + MaxInterval: 10 * time.Second, + }, + }), + } +} + +// broadcastTask is the task for broadcasting messages. +type broadcastTask struct { + logger *log.MLogger + task *streamingpb.BroadcastTask + pendingMessages []message.MutableMessage + appendResult map[string]*types.AppendResult + future *syncutil.Future[*types.BroadcastAppendResult] + *typeutil.BackoffWithInstant +} + +// Poll polls the task; it returns nil once the task is done and a non-nil error otherwise. +// Poll can be called repeatedly until the task is done. +func (b *broadcastTask) Poll(ctx context.Context, operator AppendOperator) error { + if len(b.pendingMessages) > 0 { + b.logger.Debug("broadcast task is polling to send pending messages...", zap.Int("pendingMessages", len(b.pendingMessages))) + resps := operator.AppendMessages(ctx, b.pendingMessages...) + newPendings := make([]message.MutableMessage, 0) + for idx, resp := range resps.Responses { + if resp.Error != nil { + newPendings = append(newPendings, b.pendingMessages[idx]) + continue + } + b.appendResult[b.pendingMessages[idx].VChannel()] = resp.AppendResult + } + b.pendingMessages = newPendings + if len(newPendings) == 0 { + b.future.Set(&types.BroadcastAppendResult{AppendResults: b.appendResult}) + } + b.logger.Info("broadcast task finished one round of broadcast", zap.Int("pendingMessages", len(b.pendingMessages))) + } + if len(b.pendingMessages) == 0 { + // There are no more pending messages; mark the task as done. + b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE + if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.task); err != nil { + b.logger.Warn("save broadcast task failed", zap.Error(err)) + b.UpdateInstantWithNextBackOff() + return err + } + return nil + } + b.UpdateInstantWithNextBackOff() + return errBroadcastTaskIsNotDone +} + +// BlockUntilTaskDone blocks until the task is done. +func (b *broadcastTask) BlockUntilTaskDone(ctx context.Context) (*types.BroadcastAppendResult, error) { + return b.future.GetWithContext(ctx) +} + +type broadcastTaskArray []*broadcastTask + +// Len returns the length of the heap. +func (h broadcastTaskArray) Len() int { + return len(h) +} + +// Less returns true if the element at index i is less than the element at index j. +func (h broadcastTaskArray) Less(i, j int) bool { + return h[i].NextInstant().Before(h[j].NextInstant()) +} + +// Swap swaps the elements at indexes i and j. +func (h broadcastTaskArray) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +// Push appends x as the element at index Len(). +func (h *broadcastTaskArray) Push(x interface{}) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *h = append(*h, x.(*broadcastTask)) +} + +// Pop removes and returns the element at index Len()-1. +func (h *broadcastTaskArray) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +// Peek returns the element at the top of the heap. +// Panics if the heap is empty.
+func (h *broadcastTaskArray) Peek() interface{} { + return (*h)[0] +} diff --git a/internal/streamingcoord/server/builder.go b/internal/streamingcoord/server/builder.go index 4d2215b6df638..dcbb5eeb4c0c7 100644 --- a/internal/streamingcoord/server/builder.go +++ b/internal/streamingcoord/server/builder.go @@ -5,6 +5,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/kv/streamingcoord" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/streamingcoord/server/service" "github.com/milvus-io/milvus/internal/types" @@ -52,10 +53,13 @@ func (s *ServerBuilder) Build() *Server { resource.OptRootCoordClient(s.rootCoordClient), ) balancer := syncutil.NewFuture[balancer.Balancer]() + broadcaster := syncutil.NewFuture[broadcaster.Broadcaster]() return &Server{ logger: resource.Resource().Logger().With(log.FieldComponent("server")), session: s.session, assignmentService: service.NewAssignmentService(balancer), + broadcastService: service.NewBroadcastService(broadcaster), balancer: balancer, + broadcaster: broadcaster, } } diff --git a/internal/streamingcoord/server/resource/resource.go b/internal/streamingcoord/server/resource/resource.go index 89b8dee5730c1..96a92e3727125 100644 --- a/internal/streamingcoord/server/resource/resource.go +++ b/internal/streamingcoord/server/resource/resource.go @@ -8,6 +8,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/streamingnode/client/manager" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -29,6 +30,7 @@ func OptETCD(etcd *clientv3.Client) optResourceInit { func OptRootCoordClient(rootCoordClient *syncutil.Future[types.RootCoordClient]) optResourceInit { return func(r *resourceImpl) { r.rootCoordClient = rootCoordClient + r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } } @@ -48,6 +50,7 @@ func Init(opts ...optResourceInit) { for _, opt := range opts { opt(newR) } + assertNotNil(newR.IDAllocator()) assertNotNil(newR.RootCoordClient()) assertNotNil(newR.ETCD()) assertNotNil(newR.StreamingCatalog()) @@ -64,6 +67,7 @@ func Resource() *resourceImpl { // resourceImpl is a basic resource dependency for streamingnode server. // All utility on it is concurrent-safe and singleton. type resourceImpl struct { + idAllocator idalloc.Allocator rootCoordClient *syncutil.Future[types.RootCoordClient] etcdClient *clientv3.Client streamingCatalog metastore.StreamingCoordCataLog @@ -76,6 +80,11 @@ func (r *resourceImpl) RootCoordClient() *syncutil.Future[types.RootCoordClient] return r.rootCoordClient } +// IDAllocator returns the IDAllocator client. +func (r *resourceImpl) IDAllocator() idalloc.Allocator { + return r.idAllocator +} + // StreamingCatalog returns the StreamingCatalog client. 
func (r *resourceImpl) StreamingCatalog() metastore.StreamingCoordCataLog { return r.streamingCatalog diff --git a/internal/streamingcoord/server/server.go b/internal/streamingcoord/server/server.go index 2b9e50f3c2be4..f465d1b4b6cfc 100644 --- a/internal/streamingcoord/server/server.go +++ b/internal/streamingcoord/server/server.go @@ -8,6 +8,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" _ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" "github.com/milvus-io/milvus/internal/streamingcoord/server/service" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/streamingutil" @@ -27,9 +28,11 @@ type Server struct { // service level variables. assignmentService service.AssignmentService + broadcastService service.BroadcastService // basic component variables can be used at service level. - balancer *syncutil.Future[balancer.Balancer] + balancer *syncutil.Future[balancer.Balancer] + broadcaster *syncutil.Future[broadcaster.Broadcaster] } // Init initializes the streamingcoord server. @@ -46,8 +49,9 @@ func (s *Server) Start(ctx context.Context) (err error) { // initBasicComponent initialize all underlying dependency for streamingcoord. func (s *Server) initBasicComponent(ctx context.Context) (err error) { + futures := make([]*conc.Future[struct{}], 0) if streamingutil.IsStreamingServiceEnabled() { - fBalancer := conc.Go(func() (struct{}, error) { + futures = append(futures, conc.Go(func() (struct{}, error) { s.logger.Info("start recovery balancer...") // Read new incoming topics from configuration, and register it into balancer. newIncomingTopics := util.GetAllTopicsFromConfiguration() @@ -59,10 +63,22 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) { s.balancer.Set(balancer) s.logger.Info("recover balancer done") return struct{}{}, nil - }) - return conc.AwaitAll(fBalancer) + })) } - return nil + // The broadcaster for the msgstream is implemented on the current streamingcoord to reduce development complexity, + // so it needs to be recovered here. + futures = append(futures, conc.Go(func() (struct{}, error) { + s.logger.Info("start recovery broadcaster...") + broadcaster, err := broadcaster.RecoverBroadcaster(ctx, broadcaster.NewAppendOperator()) + if err != nil { + s.logger.Warn("recover broadcaster failed", zap.Error(err)) + return struct{}{}, err + } + s.broadcaster.Set(broadcaster) + s.logger.Info("recover broadcaster done") + return struct{}{}, nil + })) + return conc.AwaitAll(futures...) } // RegisterGRPCService register all grpc service to grpc server. @@ -70,6 +86,7 @@ func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) { if streamingutil.IsStreamingServiceEnabled() { streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService) } + streamingpb.RegisterStreamingCoordBroadcastServiceServer(grpcServer, s.broadcastService) } // Close closes the streamingcoord server.
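The builder hands `service.NewBroadcastService` an unfilled `*syncutil.Future[broadcaster.Broadcaster]`; `initBasicComponent` fills it once `RecoverBroadcaster` returns, and the service blocks in `GetWithContext` until then. A stripped-down sketch of that write-once handoff (illustrative only, not the actual `syncutil` implementation):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// Future is a write-once cell: Set publishes the value exactly once,
// GetWithContext blocks until Set is called or the context expires.
type Future[T any] struct {
	ready chan struct{}
	v     T
}

func NewFuture[T any]() *Future[T] { return &Future[T]{ready: make(chan struct{})} }

func (f *Future[T]) Set(v T) { f.v = v; close(f.ready) }

func (f *Future[T]) GetWithContext(ctx context.Context) (T, error) {
	select {
	case <-f.ready:
		return f.v, nil
	case <-ctx.Done():
		var zero T
		return zero, ctx.Err()
	}
}

func main() {
	f := NewFuture[string]()
	go func() { // "recovery" finishes later and publishes the dependency
		time.Sleep(10 * time.Millisecond)
		f.Set("broadcaster")
	}()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	v, err := f.GetWithContext(ctx)
	fmt.Println(v, err) // broadcaster <nil>
}
```

RPCs that arrive before recovery completes therefore wait, bounded by their own deadline, instead of racing a nil dependency.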
@@ -80,5 +97,11 @@ func (s *Server) Stop() { } else { s.logger.Info("balancer not ready, skip close") } + if s.broadcaster.Ready() { + s.logger.Info("start close broadcaster...") + s.broadcaster.Get().Close() + } else { + s.logger.Info("broadcaster not ready, skip close") + } s.logger.Info("streamingcoord server stopped") } diff --git a/internal/streamingcoord/server/service/broadcast.go b/internal/streamingcoord/server/service/broadcast.go new file mode 100644 index 0000000000000..6d192615e32d4 --- /dev/null +++ b/internal/streamingcoord/server/service/broadcast.go @@ -0,0 +1,44 @@ +package service + +import ( + "context" + + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +// BroadcastService is the interface of the broadcast service. +type BroadcastService interface { + streamingpb.StreamingCoordBroadcastServiceServer +} + +// NewBroadcastService creates a new broadcast service. +func NewBroadcastService(bc *syncutil.Future[broadcaster.Broadcaster]) BroadcastService { + return &broadcastServiceImpl{ + broadcaster: bc, + } +} + +// broadcastServiceImpl is the implementation of the broadcast service. +type broadcastServiceImpl struct { + broadcaster *syncutil.Future[broadcaster.Broadcaster] +} + +// Broadcast broadcasts the message to all channels. +func (s *broadcastServiceImpl) Broadcast(ctx context.Context, req *streamingpb.BroadcastRequest) (*streamingpb.BroadcastResponse, error) { + broadcaster, err := s.broadcaster.GetWithContext(ctx) + if err != nil { + return nil, err + } + results, err := broadcaster.Broadcast(ctx, message.NewBroadcastMutableMessage(req.Message.Payload, req.Message.Properties)) + if err != nil { + return nil, err + } + protoResult := make(map[string]*streamingpb.ProduceMessageResponseResult, len(results.AppendResults)) + for vchannel, result := range results.AppendResults { + protoResult[vchannel] = result.IntoProto() + } + return &streamingpb.BroadcastResponse{Results: protoResult}, nil +} diff --git a/internal/streamingnode/server/builder.go b/internal/streamingnode/server/builder.go index cdf725df55d01..f35e76b233375 100644 --- a/internal/streamingnode/server/builder.go +++ b/internal/streamingnode/server/builder.go @@ -11,6 +11,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/kv" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) // ServerBuilder is used to build a server. @@ -18,8 +19,8 @@ type ServerBuilder struct { etcdClient *clientv3.Client grpcServer *grpc.Server - rc types.RootCoordClient - dc types.DataCoordClient + rc *syncutil.Future[types.RootCoordClient] + dc *syncutil.Future[types.DataCoordClient] session *sessionutil.Session kv kv.MetaKv chunkManager storage.ChunkManager @@ -49,13 +50,13 @@ func (b *ServerBuilder) WithGRPCServer(svr *grpc.Server) *ServerBuilder { } // WithRootCoordClient sets root coord client to the server builder. -func (b *ServerBuilder) WithRootCoordClient(rc types.RootCoordClient) *ServerBuilder { +func (b *ServerBuilder) WithRootCoordClient(rc *syncutil.Future[types.RootCoordClient]) *ServerBuilder { b.rc = rc return b } // WithDataCoordClient sets data coord client to the server builder.
-func (b *ServerBuilder) WithDataCoordClient(dc types.DataCoordClient) *ServerBuilder { +func (b *ServerBuilder) WithDataCoordClient(dc *syncutil.Future[types.DataCoordClient]) *ServerBuilder { b.dc = dc return b } diff --git a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go index caa1979b3f343..cfe69d68d4e28 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go +++ b/internal/streamingnode/server/flusher/flusherimpl/channel_lifetime.go @@ -86,8 +86,17 @@ func (c *channelLifetime) Run() error { // Get recovery info from datacoord. ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - resp, err := resource.Resource().DataCoordClient(). - GetChannelRecoveryInfo(ctx, &datapb.GetChannelRecoveryInfoRequest{Vchannel: c.vchannel}) + + pipelineParams, err := c.f.getPipelineParams(ctx) + if err != nil { + return err + } + + dc, err := resource.Resource().DataCoordClient().GetWithContext(ctx) + if err != nil { + return errors.Wrap(err, "At Get DataCoordClient") + } + resp, err := dc.GetChannelRecoveryInfo(ctx, &datapb.GetChannelRecoveryInfoRequest{Vchannel: c.vchannel}) if err = merr.CheckRPCCall(resp, err); err != nil { return err } @@ -115,7 +124,7 @@ func (c *channelLifetime) Run() error { } // Build and add pipeline. - ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, c.f.pipelineParams, + ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, pipelineParams, // TODO fubang add the db properties &datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan(), func(t syncmgr.Task, err error) { if err != nil || t == nil { diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go index f87acd7353c5e..a5c417b64b212 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go @@ -18,7 +18,6 @@ package flusherimpl import ( "context" - "sync" "time" "github.com/cockroachdb/errors" @@ -35,55 +34,54 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) var _ flusher.Flusher = (*flusherImpl)(nil) type flusherImpl struct { - broker broker.Broker - fgMgr pipeline.FlowgraphManager - syncMgr syncmgr.SyncManager - wbMgr writebuffer.BufferManager - cpUpdater *util.ChannelCheckpointUpdater + fgMgr pipeline.FlowgraphManager + wbMgr writebuffer.BufferManager + syncMgr syncmgr.SyncManager + cpUpdater *syncutil.Future[*util.ChannelCheckpointUpdater] + chunkManager storage.ChunkManager channelLifetimes *typeutil.ConcurrentMap[string, ChannelLifetime] - notifyCh chan struct{} - stopChan lifetime.SafeChan - stopWg sync.WaitGroup - pipelineParams *util.PipelineParams + notifyCh chan struct{} + notifier *syncutil.AsyncTaskNotifier[struct{}] } func NewFlusher(chunkManager 
storage.ChunkManager) flusher.Flusher { - params := getPipelineParams(chunkManager) - return newFlusherWithParam(params) -} - -func newFlusherWithParam(params *util.PipelineParams) flusher.Flusher { - fgMgr := pipeline.NewFlowgraphManager() + syncMgr := syncmgr.NewSyncManager(chunkManager) + wbMgr := writebuffer.NewManager(syncMgr) return &flusherImpl{ - broker: params.Broker, - fgMgr: fgMgr, - syncMgr: params.SyncMgr, - wbMgr: params.WriteBufferManager, - cpUpdater: params.CheckpointUpdater, + fgMgr: pipeline.NewFlowgraphManager(), + wbMgr: wbMgr, + syncMgr: syncMgr, + cpUpdater: syncutil.NewFuture[*util.ChannelCheckpointUpdater](), + chunkManager: chunkManager, channelLifetimes: typeutil.NewConcurrentMap[string, ChannelLifetime](), notifyCh: make(chan struct{}, 1), - stopChan: lifetime.NewSafeChan(), - pipelineParams: params, + notifier: syncutil.NewAsyncTaskNotifier[struct{}](), } } func (f *flusherImpl) RegisterPChannel(pchannel string, wal wal.WAL) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - resp, err := resource.Resource().RootCoordClient().GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ + rc, err := resource.Resource().RootCoordClient().GetWithContext(ctx) + if err != nil { + return errors.Wrap(err, "At Get RootCoordClient") + } + resp, err := rc.GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ Pchannel: pchannel, }) if err = merr.CheckRPCCall(resp, err); err != nil { @@ -126,11 +124,18 @@ func (f *flusherImpl) notify() { } func (f *flusherImpl) Start() { - f.stopWg.Add(1) f.wbMgr.Start() - go f.cpUpdater.Start() go func() { - defer f.stopWg.Done() + defer f.notifier.Finish(struct{}{}) + dc, err := resource.Resource().DataCoordClient().GetWithContext(f.notifier.Context()) + if err != nil { + return + } + broker := broker.NewCoordBroker(dc, paramtable.GetNodeID()) + cpUpdater := util.NewChannelCheckpointUpdater(broker) + go cpUpdater.Start() + f.cpUpdater.Set(cpUpdater) + backoff := typeutil.NewBackoffTimer(typeutil.BackoffTimerConfig{ Default: 5 * time.Second, Backoff: typeutil.BackoffConfig{ @@ -143,7 +148,7 @@ func (f *flusherImpl) Start() { var nextTimer <-chan time.Time for { select { - case <-f.stopChan.CloseCh(): + case <-f.notifier.Context().Done(): log.Info("flusher exited") return case <-f.notifyCh: @@ -190,13 +195,37 @@ func (f *flusherImpl) handle(backoff *typeutil.BackoffTimer) <-chan time.Time { } func (f *flusherImpl) Stop() { - f.stopChan.Close() - f.stopWg.Wait() + f.notifier.Cancel() + f.notifier.BlockUntilFinish() f.channelLifetimes.Range(func(vchannel string, lifetime ChannelLifetime) bool { lifetime.Cancel() return true }) f.fgMgr.ClearFlowgraphs() f.wbMgr.Stop() - f.cpUpdater.Close() + if f.cpUpdater.Ready() { + f.cpUpdater.Get().Close() + } +} + +func (f *flusherImpl) getPipelineParams(ctx context.Context) (*util.PipelineParams, error) { + dc, err := resource.Resource().DataCoordClient().GetWithContext(ctx) + if err != nil { + return nil, err + } + + cpUpdater, err := f.cpUpdater.GetWithContext(ctx) + if err != nil { + return nil, err + } + return &util.PipelineParams{ + Ctx: context.Background(), + Broker: broker.NewCoordBroker(dc, paramtable.GetNodeID()), + SyncMgr: f.syncMgr, + ChunkManager: f.chunkManager, + WriteBufferManager: f.wbMgr, + CheckpointUpdater: cpUpdater, + Allocator: idalloc.NewMAllocator(resource.Resource().IDAllocator()), + MsgHandler: newMsgHandler(f.wbMgr), + }, nil } diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go 
b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go index aef723e7a59f6..f4f0116231962 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go @@ -30,8 +30,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" - "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -39,9 +37,11 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func init() { @@ -106,22 +106,8 @@ func newMockWAL(t *testing.T, vchannels []string, maybe bool) *mock_wal.MockWAL } func newTestFlusher(t *testing.T, maybe bool) flusher.Flusher { - wbMgr := writebuffer.NewMockBufferManager(t) - register := wbMgr.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - removeChannel := wbMgr.EXPECT().RemoveChannel(mock.Anything).Return() - start := wbMgr.EXPECT().Start().Return() - stop := wbMgr.EXPECT().Stop().Return() - if maybe { - register.Maybe() - removeChannel.Maybe() - start.Maybe() - stop.Maybe() - } m := mocks.NewChunkManager(t) - params := getPipelineParams(m) - params.SyncMgr = syncmgr.NewMockSyncManager(t) - params.WriteBufferManager = wbMgr - return newFlusherWithParam(params) + return NewFlusher(m) } func TestFlusher_RegisterPChannel(t *testing.T) { @@ -146,10 +132,16 @@ func TestFlusher_RegisterPChannel(t *testing.T) { rootcoord.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything). 
Return(&rootcoordpb.GetPChannelInfoResponse{Collections: collectionsInfo}, nil) datacoord := newMockDatacoord(t, maybe) + + fDatacoord := syncutil.NewFuture[types.DataCoordClient]() + fDatacoord.Set(datacoord) + + fRootcoord := syncutil.NewFuture[types.RootCoordClient]() + fRootcoord.Set(rootcoord) resource.InitForTest( t, - resource.OptRootCoordClient(rootcoord), - resource.OptDataCoordClient(datacoord), + resource.OptRootCoordClient(fRootcoord), + resource.OptDataCoordClient(fDatacoord), ) f := newTestFlusher(t, maybe) @@ -182,9 +174,11 @@ func TestFlusher_RegisterVChannel(t *testing.T) { } datacoord := newMockDatacoord(t, maybe) + fDatacoord := syncutil.NewFuture[types.DataCoordClient]() + fDatacoord.Set(datacoord) resource.InitForTest( t, - resource.OptDataCoordClient(datacoord), + resource.OptDataCoordClient(fDatacoord), ) f := newTestFlusher(t, maybe) @@ -220,9 +214,11 @@ func TestFlusher_Concurrency(t *testing.T) { } datacoord := newMockDatacoord(t, maybe) + fDatacoord := syncutil.NewFuture[types.DataCoordClient]() + fDatacoord.Set(datacoord) resource.InitForTest( t, - resource.OptDataCoordClient(datacoord), + resource.OptDataCoordClient(fDatacoord), ) f := newTestFlusher(t, maybe) diff --git a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go b/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go deleted file mode 100644 index 79751dff73444..0000000000000 --- a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package flusherimpl - -import ( - "context" - - "github.com/milvus-io/milvus/internal/flushcommon/broker" - "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" - "github.com/milvus-io/milvus/internal/flushcommon/util" - "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" - "github.com/milvus-io/milvus/pkg/util/paramtable" -) - -// getPipelineParams initializes the pipeline parameters. 
-func getPipelineParams(chunkManager storage.ChunkManager) *util.PipelineParams { - var ( - rsc = resource.Resource() - syncMgr = syncmgr.NewSyncManager(chunkManager) - wbMgr = writebuffer.NewManager(syncMgr) - coordBroker = broker.NewCoordBroker(rsc.DataCoordClient(), paramtable.GetNodeID()) - cpUpdater = util.NewChannelCheckpointUpdater(coordBroker) - ) - return &util.PipelineParams{ - Ctx: context.Background(), - Broker: coordBroker, - SyncMgr: syncMgr, - ChunkManager: chunkManager, - WriteBufferManager: wbMgr, - CheckpointUpdater: cpUpdater, - Allocator: idalloc.NewMAllocator(rsc.IDAllocator()), - MsgHandler: newMsgHandler(wbMgr), - } -} diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index 23ff6316052b9..5a73fd0d0a8ea 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -8,10 +8,11 @@ import ( "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) var r = &resourceImpl{} // singleton resource instance @@ -41,7 +42,7 @@ func OptChunkManager(chunkManager storage.ChunkManager) optResourceInit { } // OptRootCoordClient provides the root coordinator client to the resource. -func OptRootCoordClient(rootCoordClient types.RootCoordClient) optResourceInit { +func OptRootCoordClient(rootCoordClient *syncutil.Future[types.RootCoordClient]) optResourceInit { return func(r *resourceImpl) { r.rootCoordClient = rootCoordClient r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) @@ -50,7 +51,7 @@ func OptRootCoordClient(rootCoordClient types.RootCoordClient) optResourceInit { } // OptDataCoordClient provides the data coordinator client to the resource. -func OptDataCoordClient(dataCoordClient types.DataCoordClient) optResourceInit { +func OptDataCoordClient(dataCoordClient *syncutil.Future[types.DataCoordClient]) optResourceInit { return func(r *resourceImpl) { r.dataCoordClient = dataCoordClient } @@ -96,8 +97,8 @@ type resourceImpl struct { idAllocator idalloc.Allocator etcdClient *clientv3.Client chunkManager storage.ChunkManager - rootCoordClient types.RootCoordClient - dataCoordClient types.DataCoordClient + rootCoordClient *syncutil.Future[types.RootCoordClient] + dataCoordClient *syncutil.Future[types.DataCoordClient] streamingNodeCatalog metastore.StreamingNodeCataLog segmentAssignStatsManager *stats.StatsManager timeTickInspector tinspector.TimeTickSyncInspector @@ -129,12 +130,12 @@ func (r *resourceImpl) ChunkManager() storage.ChunkManager { } // RootCoordClient returns the root coordinator client. -func (r *resourceImpl) RootCoordClient() types.RootCoordClient { +func (r *resourceImpl) RootCoordClient() *syncutil.Future[types.RootCoordClient] { return r.rootCoordClient } // DataCoordClient returns the data coordinator client. 
-func (r *resourceImpl) DataCoordClient() types.DataCoordClient { +func (r *resourceImpl) DataCoordClient() *syncutil.Future[types.DataCoordClient] { return r.dataCoordClient } diff --git a/internal/streamingnode/server/resource/resource_test.go b/internal/streamingnode/server/resource/resource_test.go index 1d8d4f976f784..8c219d86ff0c8 100644 --- a/internal/streamingnode/server/resource/resource_test.go +++ b/internal/streamingnode/server/resource/resource_test.go @@ -6,9 +6,10 @@ import ( "github.com/stretchr/testify/assert" clientv3 "go.etcd.io/etcd/client/v3" - "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestApply(t *testing.T) { @@ -16,7 +17,7 @@ func TestApply(t *testing.T) { Apply() Apply(OptETCD(&clientv3.Client{})) - Apply(OptRootCoordClient(mocks.NewMockRootCoordClient(t))) + Apply(OptRootCoordClient(syncutil.NewFuture[types.RootCoordClient]())) assert.Panics(t, func() { Done() @@ -24,8 +25,8 @@ func TestApply(t *testing.T) { Apply( OptETCD(&clientv3.Client{}), - OptRootCoordClient(mocks.NewMockRootCoordClient(t)), - OptDataCoordClient(mocks.NewMockDataCoordClient(t)), + OptRootCoordClient(syncutil.NewFuture[types.RootCoordClient]()), + OptDataCoordClient(syncutil.NewFuture[types.DataCoordClient]()), OptStreamingNodeCatalog(mock_metastore.NewMockStreamingNodeCataLog(t)), ) Done() diff --git a/internal/streamingnode/server/resource/test_utility.go b/internal/streamingnode/server/resource/test_utility.go index bad9e0f4bf1de..68e670dde7ca1 100644 --- a/internal/streamingnode/server/resource/test_utility.go +++ b/internal/streamingnode/server/resource/test_utility.go @@ -6,9 +6,11 @@ package resource import ( "testing" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) // InitForTest initializes the singleton of resources for test. 
@@ -21,7 +23,9 @@ func InitForTest(t *testing.T, opts ...optResourceInit) { r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } else { - r.rootCoordClient = idalloc.NewMockRootCoordClient(t) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(idalloc.NewMockRootCoordClient(t)) + r.rootCoordClient = f r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } diff --git a/internal/streamingnode/server/service/handler/producer/produce_server.go b/internal/streamingnode/server/service/handler/producer/produce_server.go index 366e79534d0d1..4f06b76dce620 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server.go @@ -12,7 +12,6 @@ import ( "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -217,20 +216,9 @@ func (p *ProduceServer) sendProduceResult(reqID int64, appendResult *wal.AppendR } if err != nil { p.logger.Warn("append message to wal failed", zap.Int64("requestID", reqID), zap.Error(err)) - resp.Response = &streamingpb.ProduceMessageResponse_Error{ - Error: status.AsStreamingError(err).AsPBError(), - } + resp.Response = &streamingpb.ProduceMessageResponse_Error{Error: status.AsStreamingError(err).AsPBError()} } else { - resp.Response = &streamingpb.ProduceMessageResponse_Result{ - Result: &streamingpb.ProduceMessageResponseResult{ - Id: &messagespb.MessageID{ - Id: appendResult.MessageID.Marshal(), - }, - Timetick: appendResult.TimeTick, - TxnContext: appendResult.TxnCtx.IntoProto(), - Extra: appendResult.Extra, - }, - } + resp.Response = &streamingpb.ProduceMessageResponse_Result{Result: appendResult.IntoProto()} } // If server context is canceled, it means the stream has been closed. 
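The test updates in this patch repeat the same three-line dance to adapt each mock to the future-based resource options: build a future, set the mock, pass it in. A tiny generic helper could fold that into one call; `futureOf` below is a hypothetical suggestion, not an existing Milvus utility:

```go
package testutil // hypothetical package; placement is illustrative

import "github.com/milvus-io/milvus/pkg/util/syncutil"

// futureOf wraps an already-constructed value (such as a mock client) in a
// pre-filled Future, so future-based resource options never block in tests.
func futureOf[T any](v T) *syncutil.Future[T] {
	f := syncutil.NewFuture[T]()
	f.Set(v)
	return f
}
```

With it, the wiring in these tests would collapse to e.g. `resource.OptDataCoordClient(futureOf[types.DataCoordClient](datacoord))`.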
diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index f9f1fb80be165..8a222b04b35c0 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -21,13 +21,15 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/registry" + internaltypes "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) const testVChannel = "v1" @@ -53,8 +55,15 @@ func initResourceForTest(t *testing.T) { rc := idalloc.NewMockRootCoordClient(t) rc.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything).Return(&rootcoordpb.GetPChannelInfoResponse{}, nil) + fRootCoordClient := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootCoordClient.Set(rc) + dc := mocks.NewMockDataCoordClient(t) dc.EXPECT().AllocSegment(mock.Anything, mock.Anything).Return(&datapb.AllocSegmentResponse{}, nil) + + fDataCoordClient := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDataCoordClient.Set(dc) + catalog := mock_metastore.NewMockStreamingNodeCataLog(t) catalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return(nil, nil) catalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -67,8 +76,8 @@ func initResourceForTest(t *testing.T) { resource.InitForTest( t, - resource.OptRootCoordClient(rc), - resource.OptDataCoordClient(dc), + resource.OptRootCoordClient(fRootCoordClient), + resource.OptDataCoordClient(fDataCoordClient), resource.OptFlusher(flusher), resource.OptStreamingNodeCatalog(catalog), ) diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go index def30b9575115..bce92f57960d6 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go @@ -225,7 +225,11 @@ func (m *partitionSegmentManager) allocNewGrowingSegment(ctx context.Context) (* // Transfer the pending segment into growing state. // Alloc the growing segment at datacoord first. 
- resp, err := resource.Resource().DataCoordClient().AllocSegment(ctx, &datapb.AllocSegmentRequest{ + dc, err := resource.Resource().DataCoordClient().GetWithContext(ctx) + if err != nil { + return nil, err + } + resp, err := dc.AllocSegment(ctx, &datapb.AllocSegmentRequest{ CollectionId: pendingSegment.GetCollectionID(), PartitionId: pendingSegment.GetPartitionID(), SegmentId: pendingSegment.GetSegmentID(), diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go index fe30a7e2fbde2..e942ffae35c55 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go @@ -32,7 +32,11 @@ func RecoverPChannelSegmentAllocManager( return nil, errors.Wrap(err, "failed to list segment assignment from catalog") } // get collection and parition info from rootcoord. - resp, err := resource.Resource().RootCoordClient().GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ + rc, err := resource.Resource().RootCoordClient().GetWithContext(ctx) + if err != nil { + return nil, err + } + resp, err := rc.GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{ Pchannel: pchannel.Name, }) if err := merr.CheckRPCCall(resp, err); err != nil { diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go index 33597cce87c25..4497551c2bc58 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go @@ -15,11 +15,12 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" + internaltypes "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq" @@ -311,6 +312,8 @@ func initializeTestState(t *testing.T) { Status: merr.Success(), }, nil }) + fDataCoordClient := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDataCoordClient.Set(dataCoordClient) rootCoordClient := idalloc.NewMockRootCoordClient(t) rootCoordClient.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything).Return(&rootcoordpb.GetPChannelInfoResponse{ @@ -325,11 +328,13 @@ func initializeTestState(t *testing.T) { }, }, }, nil) + fRootCoordClient := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootCoordClient.Set(rootCoordClient) resource.InitForTest(t, resource.OptStreamingNodeCatalog(streamingNodeCatalog), - resource.OptDataCoordClient(dataCoordClient), - resource.OptRootCoordClient(rootCoordClient), + resource.OptDataCoordClient(fDataCoordClient), + 
resource.OptRootCoordClient(fRootCoordClient), ) streamingNodeCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return( []*streamingpb.SegmentAssignmentMeta{ diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go index 0803931c3b909..0ba11fd88b499 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go @@ -17,9 +17,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/metricsutil" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestAck(t *testing.T) { @@ -43,7 +45,9 @@ func TestAck(t *testing.T) { }, nil }, ) - resource.InitForTest(t, resource.OptRootCoordClient(rc)) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(rc) + resource.InitForTest(t, resource.OptRootCoordClient(f)) ackManager := NewAckManager(0, nil, metricsutil.NewTimeTickMetrics("test")) @@ -160,7 +164,9 @@ func TestAckManager(t *testing.T) { }, nil }, ) - resource.InitForTest(t, resource.OptRootCoordClient(rc)) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(rc) + resource.InitForTest(t, resource.OptRootCoordClient(f)) ackManager := NewAckManager(0, walimplstest.NewTestMessageID(0), metricsutil.NewTimeTickMetrics("test")) diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go index 2ed586859fcb1..a705663ad8a77 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_message.go @@ -21,7 +21,7 @@ func NewTimeTickMsg(ts uint64, lastConfirmedMessageID message.MessageID, sourceI commonpbutil.WithSourceID(sourceID), ), }). - WithBroadcast(). + WithAllVChannel(). 
BuildMutable() if err != nil { return nil, err diff --git a/internal/streamingnode/server/walmanager/manager_impl_test.go b/internal/streamingnode/server/walmanager/manager_impl_test.go index 35b269cc04a85..cdaa931e3c51d 100644 --- a/internal/streamingnode/server/walmanager/manager_impl_test.go +++ b/internal/streamingnode/server/walmanager/manager_impl_test.go @@ -12,10 +12,12 @@ import ( "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + internaltypes "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestMain(m *testing.M) { @@ -25,7 +27,11 @@ func TestMain(m *testing.M) { func TestManager(t *testing.T) { rootcoord := mocks.NewMockRootCoordClient(t) + fRootcoord := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootcoord.Set(rootcoord) datacoord := mocks.NewMockDataCoordClient(t) + fDatacoord := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDatacoord.Set(datacoord) flusher := mock_flusher.NewMockFlusher(t) flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil) @@ -33,8 +39,8 @@ func TestManager(t *testing.T) { resource.InitForTest( t, resource.OptFlusher(flusher), - resource.OptRootCoordClient(rootcoord), - resource.OptDataCoordClient(datacoord), + resource.OptRootCoordClient(fRootcoord), + resource.OptDataCoordClient(fDatacoord), ) opener := mock_wal.NewMockOpener(t) diff --git a/internal/streamingnode/server/walmanager/wal_lifetime_test.go b/internal/streamingnode/server/walmanager/wal_lifetime_test.go index d34bfe4f88896..a14464df8b594 100644 --- a/internal/streamingnode/server/walmanager/wal_lifetime_test.go +++ b/internal/streamingnode/server/walmanager/wal_lifetime_test.go @@ -12,14 +12,20 @@ import ( "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + internaltypes "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestWALLifetime(t *testing.T) { channel := "test" rootcoord := mocks.NewMockRootCoordClient(t) + fRootcoord := syncutil.NewFuture[internaltypes.RootCoordClient]() + fRootcoord.Set(rootcoord) datacoord := mocks.NewMockDataCoordClient(t) + fDatacoord := syncutil.NewFuture[internaltypes.DataCoordClient]() + fDatacoord.Set(datacoord) flusher := mock_flusher.NewMockFlusher(t) flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil) @@ -28,8 +34,8 @@ func TestWALLifetime(t *testing.T) { resource.InitForTest( t, resource.OptFlusher(flusher), - resource.OptRootCoordClient(rootcoord), - resource.OptDataCoordClient(datacoord), + resource.OptRootCoordClient(fRootcoord), + resource.OptDataCoordClient(fDatacoord), ) opener := mock_wal.NewMockOpener(t) diff --git a/internal/streamingnode/server/resource/idalloc/allocator.go b/internal/util/idalloc/allocator.go similarity index 94% rename from internal/streamingnode/server/resource/idalloc/allocator.go rename to internal/util/idalloc/allocator.go index 3e8b7bdb59d23..f614d6f5ec3d6 
100644 --- a/internal/streamingnode/server/resource/idalloc/allocator.go +++ b/internal/util/idalloc/allocator.go @@ -22,6 +22,7 @@ import ( "time" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) // batchAllocateSize is the size of batch allocate from remote allocator. @@ -30,7 +31,7 @@ const batchAllocateSize = 1000 var _ Allocator = (*allocatorImpl)(nil) // NewTSOAllocator creates a new allocator. -func NewTSOAllocator(rc types.RootCoordClient) Allocator { +func NewTSOAllocator(rc *syncutil.Future[types.RootCoordClient]) Allocator { return &allocatorImpl{ mu: sync.Mutex{}, remoteAllocator: newTSOAllocator(rc), @@ -39,7 +40,7 @@ func NewTSOAllocator(rc types.RootCoordClient) Allocator { } // NewIDAllocator creates a new allocator. -func NewIDAllocator(rc types.RootCoordClient) Allocator { +func NewIDAllocator(rc *syncutil.Future[types.RootCoordClient]) Allocator { return &allocatorImpl{ mu: sync.Mutex{}, remoteAllocator: newIDAllocator(rc), diff --git a/internal/streamingnode/server/resource/idalloc/allocator_test.go b/internal/util/idalloc/allocator_test.go similarity index 81% rename from internal/streamingnode/server/resource/idalloc/allocator_test.go rename to internal/util/idalloc/allocator_test.go index c4db2e520a578..26eb9e90c2b1a 100644 --- a/internal/streamingnode/server/resource/idalloc/allocator_test.go +++ b/internal/util/idalloc/allocator_test.go @@ -11,7 +11,9 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestTimestampAllocator(t *testing.T) { @@ -19,7 +21,10 @@ func TestTimestampAllocator(t *testing.T) { paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) - allocator := NewTSOAllocator(client) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator := NewTSOAllocator(f) for i := 0; i < 5000; i++ { ts, err := allocator.Allocate(context.Background()) @@ -46,7 +51,10 @@ func TestTimestampAllocator(t *testing.T) { }, nil }, ) - allocator = NewTSOAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = NewTSOAllocator(f) _, err := allocator.Allocate(context.Background()) assert.Error(t, err) } diff --git a/internal/streamingnode/server/resource/idalloc/basic_allocator.go b/internal/util/idalloc/basic_allocator.go similarity index 83% rename from internal/streamingnode/server/resource/idalloc/basic_allocator.go rename to internal/util/idalloc/basic_allocator.go index 8e0ad90e63d1c..8b9e220cc410a 100644 --- a/internal/streamingnode/server/resource/idalloc/basic_allocator.go +++ b/internal/util/idalloc/basic_allocator.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) var errExhausted = errors.New("exhausted") @@ -56,12 +57,12 @@ func (a *localAllocator) exhausted() { // tsoAllocator allocate timestamp from remote root coordinator. type tsoAllocator struct { - rc types.RootCoordClient + rc *syncutil.Future[types.RootCoordClient] nodeID int64 } // newTSOAllocator creates a new remote allocator. 
-func newTSOAllocator(rc types.RootCoordClient) *tsoAllocator { +func newTSOAllocator(rc *syncutil.Future[types.RootCoordClient]) *tsoAllocator { a := &tsoAllocator{ nodeID: paramtable.GetNodeID(), rc: rc, @@ -80,8 +81,12 @@ func (ta *tsoAllocator) batchAllocate(ctx context.Context, count uint32) (uint64 ), Count: count, } + rc, err := ta.rc.GetWithContext(ctx) + if err != nil { + return 0, 0, fmt.Errorf("get root coordinator client timeout: %w", err) + } - resp, err := ta.rc.AllocTimestamp(ctx, req) + resp, err := rc.AllocTimestamp(ctx, req) if err != nil { return 0, 0, fmt.Errorf("syncTimestamp Failed:%w", err) } @@ -96,12 +101,12 @@ func (ta *tsoAllocator) batchAllocate(ctx context.Context, count uint32) (uint64 // idAllocator allocate timestamp from remote root coordinator. type idAllocator struct { - rc types.RootCoordClient + rc *syncutil.Future[types.RootCoordClient] nodeID int64 } // newIDAllocator creates a new remote allocator. -func newIDAllocator(rc types.RootCoordClient) *idAllocator { +func newIDAllocator(rc *syncutil.Future[types.RootCoordClient]) *idAllocator { a := &idAllocator{ nodeID: paramtable.GetNodeID(), rc: rc, @@ -120,8 +125,12 @@ func (ta *idAllocator) batchAllocate(ctx context.Context, count uint32) (uint64, ), Count: count, } + rc, err := ta.rc.GetWithContext(ctx) + if err != nil { + return 0, 0, fmt.Errorf("get root coordinator client timeout: %w", err) + } - resp, err := ta.rc.AllocID(ctx, req) + resp, err := rc.AllocID(ctx, req) if err != nil { return 0, 0, fmt.Errorf("AllocID Failed:%w", err) } diff --git a/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go b/internal/util/idalloc/basic_allocator_test.go similarity index 84% rename from internal/streamingnode/server/resource/idalloc/basic_allocator_test.go rename to internal/util/idalloc/basic_allocator_test.go index 081832006f017..549f78cc00d8b 100644 --- a/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go +++ b/internal/util/idalloc/basic_allocator_test.go @@ -13,7 +13,9 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" ) func TestLocalAllocator(t *testing.T) { @@ -63,8 +65,10 @@ func TestRemoteTSOAllocator(t *testing.T) { paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) + f := syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) - allocator := newTSOAllocator(client) + allocator := newTSOAllocator(f) ts, count, err := allocator.batchAllocate(context.Background(), 100) assert.NoError(t, err) assert.NotZero(t, ts) @@ -77,7 +81,10 @@ func TestRemoteTSOAllocator(t *testing.T) { return nil, errors.New("test") }, ) - allocator = newTSOAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = newTSOAllocator(f) _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) @@ -91,7 +98,10 @@ func TestRemoteTSOAllocator(t *testing.T) { }, nil }, ) - allocator = newTSOAllocator(client) + f = syncutil.NewFuture[types.RootCoordClient]() + f.Set(client) + + allocator = newTSOAllocator(f) _, _, err = allocator.batchAllocate(context.Background(), 100) assert.Error(t, err) } @@ -101,8 +111,11 @@ func TestRemoteIDAllocator(t *testing.T) { paramtable.SetNodeID(1) client := NewMockRootCoordClient(t) + f := 
diff --git a/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go b/internal/util/idalloc/basic_allocator_test.go
similarity index 84%
rename from internal/streamingnode/server/resource/idalloc/basic_allocator_test.go
rename to internal/util/idalloc/basic_allocator_test.go
index 081832006f017..549f78cc00d8b 100644
--- a/internal/streamingnode/server/resource/idalloc/basic_allocator_test.go
+++ b/internal/util/idalloc/basic_allocator_test.go
@@ -13,7 +13,9 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 
 	"github.com/milvus-io/milvus/internal/mocks"
 	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
+	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/util/syncutil"
 )
 
 func TestLocalAllocator(t *testing.T) {
@@ -63,8 +65,10 @@ func TestRemoteTSOAllocator(t *testing.T) {
 	paramtable.SetNodeID(1)
 
 	client := NewMockRootCoordClient(t)
+	f := syncutil.NewFuture[types.RootCoordClient]()
+	f.Set(client)
 
-	allocator := newTSOAllocator(client)
+	allocator := newTSOAllocator(f)
 	ts, count, err := allocator.batchAllocate(context.Background(), 100)
 	assert.NoError(t, err)
 	assert.NotZero(t, ts)
@@ -77,7 +81,10 @@ func TestRemoteTSOAllocator(t *testing.T) {
 			return nil, errors.New("test")
 		},
 	)
-	allocator = newTSOAllocator(client)
+	f = syncutil.NewFuture[types.RootCoordClient]()
+	f.Set(client)
+
+	allocator = newTSOAllocator(f)
 	_, _, err = allocator.batchAllocate(context.Background(), 100)
 	assert.Error(t, err)
 
@@ -91,7 +98,10 @@ func TestRemoteTSOAllocator(t *testing.T) {
 			}, nil
 		},
 	)
-	allocator = newTSOAllocator(client)
+	f = syncutil.NewFuture[types.RootCoordClient]()
+	f.Set(client)
+
+	allocator = newTSOAllocator(f)
 	_, _, err = allocator.batchAllocate(context.Background(), 100)
 	assert.Error(t, err)
 }
@@ -101,8 +111,11 @@ func TestRemoteIDAllocator(t *testing.T) {
 	paramtable.SetNodeID(1)
 
 	client := NewMockRootCoordClient(t)
+	f := syncutil.NewFuture[types.RootCoordClient]()
+	f.Set(client)
+
+	allocator := newIDAllocator(f)
-	allocator := newIDAllocator(client)
 	ts, count, err := allocator.batchAllocate(context.Background(), 100)
 	assert.NoError(t, err)
 	assert.NotZero(t, ts)
@@ -115,7 +128,10 @@ func TestRemoteIDAllocator(t *testing.T) {
 			return nil, errors.New("test")
 		},
 	)
-	allocator = newIDAllocator(client)
+	f = syncutil.NewFuture[types.RootCoordClient]()
+	f.Set(client)
+
+	allocator = newIDAllocator(f)
 	_, _, err = allocator.batchAllocate(context.Background(), 100)
 	assert.Error(t, err)
 
@@ -129,7 +145,10 @@ func TestRemoteIDAllocator(t *testing.T) {
 			}, nil
 		},
 	)
-	allocator = newIDAllocator(client)
+	f = syncutil.NewFuture[types.RootCoordClient]()
+	f.Set(client)
+
+	allocator = newIDAllocator(f)
 	_, _, err = allocator.batchAllocate(context.Background(), 100)
 	assert.Error(t, err)
 }
diff --git a/internal/streamingnode/server/resource/idalloc/mallocator.go b/internal/util/idalloc/mallocator.go
similarity index 100%
rename from internal/streamingnode/server/resource/idalloc/mallocator.go
rename to internal/util/idalloc/mallocator.go
diff --git a/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go b/internal/util/idalloc/test_mock_root_coord_client.go
similarity index 100%
rename from internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go
rename to internal/util/idalloc/test_mock_root_coord_client.go
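The tests above fulfil the future with a mock before any allocation, so GetWithContext returns immediately; with an unset future the allocator simply blocks until the caller's deadline expires. A sketch of that failure mode, assuming the same Future semantics as in this patch:

    // An unset future makes Allocate fail with a deadline error instead of panicking.
    func demoUnsetFuture() {
        f := syncutil.NewFuture[types.RootCoordClient]()
        allocator := idalloc.NewTSOAllocator(f) // no f.Set(...) anywhere

        ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
        defer cancel()
        _, err := allocator.Allocate(ctx)
        fmt.Println(err) // "get root coordinator client timeout: ..."
    }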
diff --git a/pkg/streaming/proto/messages.proto b/pkg/streaming/proto/messages.proto
index 62b84f98a2c80..4417aa461af9e 100644
--- a/pkg/streaming/proto/messages.proto
+++ b/pkg/streaming/proto/messages.proto
@@ -242,3 +242,8 @@ enum TxnState {
     // the transaction is rollbacked.
     TxnRollbacked = 6;
 }
+
+// VChannels is a layout to represent the virtual channels for broadcast.
+message VChannels {
+    repeated string vchannels = 1;
+}
\ No newline at end of file
diff --git a/pkg/streaming/proto/streaming.proto b/pkg/streaming/proto/streaming.proto
index e4a6943ae2645..0a7debc9dad5c 100644
--- a/pkg/streaming/proto/streaming.proto
+++ b/pkg/streaming/proto/streaming.proto
@@ -60,18 +60,48 @@ message VersionPair {
     int64 local = 2;
 }
 
+// BroadcastTaskState is the state of the broadcast task.
+enum BroadcastTaskState {
+    BROADCAST_TASK_STATE_UNKNOWN = 0; // should never be used.
+    BROADCAST_TASK_STATE_PENDING = 1; // task is pending.
+    BROADCAST_TASK_STATE_DONE = 2;    // task is done, the message is broadcasted, and the persisted task can be cleared.
+}
+
+// BroadcastTask is the task to broadcast the message.
+message BroadcastTask {
+    int64 task_id = 1;             // task id.
+    messages.Message message = 2;  // message to be broadcast.
+    BroadcastTaskState state = 3;  // state of the task.
+}
+
 //
 // Milvus Service
 //
-service StreamingCoordStateService {
+service StreamingNodeStateService {
     rpc GetComponentStates(milvus.GetComponentStatesRequest)
         returns (milvus.ComponentStates) {}
 }
 
-service StreamingNodeStateService {
-    rpc GetComponentStates(milvus.GetComponentStatesRequest)
-        returns (milvus.ComponentStates) {}
+//
+// StreamingCoordBroadcastService
+//
+
+// StreamingCoordBroadcastService is the broadcast service for streaming coord.
+service StreamingCoordBroadcastService {
+    // Broadcast receives broadcast messages from other components and makes sure that the message is broadcast to every wal.
+    // It performs an atomic broadcast to all wals, achieving eventual consistency.
+    rpc Broadcast(BroadcastRequest) returns (BroadcastResponse) {}
+}
+
+// BroadcastRequest is the request of the Broadcast RPC.
+message BroadcastRequest {
+    messages.Message message = 1; // message to be broadcast.
+}
+
+// BroadcastResponse is the response of the Broadcast RPC.
+message BroadcastResponse {
+    map<string, ProduceMessageResponseResult> results = 1;
 }
 
 //
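Assuming the standard protoc-gen-go/protoc-gen-go-grpc output for the service above (a NewStreamingCoordBroadcastServiceClient constructor and a unary Broadcast method), a caller with an established connection would look roughly like the sketch below; how the messagespb.Message is produced is out of scope here:

    // broadcastOnce sends one already-encoded message through the new RPC and
    // returns the per-vchannel results map declared in BroadcastResponse.
    func broadcastOnce(ctx context.Context, conn *grpc.ClientConn, msg *messagespb.Message) (*streamingpb.BroadcastResponse, error) {
        client := streamingpb.NewStreamingCoordBroadcastServiceClient(conn)
        return client.Broadcast(ctx, &streamingpb.BroadcastRequest{Message: msg})
    }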
diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go
index 32bdad9db6482..0f941c6851bd4 100644
--- a/pkg/streaming/util/message/builder.go
+++ b/pkg/streaming/util/message/builder.go
@@ -7,16 +7,32 @@ import (
 	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
 	"github.com/milvus-io/milvus/pkg/util/tsoutil"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 // NewMutableMessage creates a new mutable message.
 // !!! Only used at server side for streamingnode internal service, don't use it at client side.
 func NewMutableMessage(payload []byte, properties map[string]string) MutableMessage {
-	return &messageImpl{
+	m := &messageImpl{
 		payload:    payload,
 		properties: properties,
 	}
+	// make an assertion on the vchannel properties: this must not be a broadcast message.
+	m.assertNotBroadcast()
+	return m
+}
+
+// NewBroadcastMutableMessage creates a new broadcast mutable message.
+// !!! Only used at server side for streamingcoord internal service, don't use it at client side.
+func NewBroadcastMutableMessage(payload []byte, properties map[string]string) BroadcastMutableMessage {
+	m := &messageImpl{
+		payload:    payload,
+		properties: properties,
+	}
+	m.assertBroadcast()
+	return m
 }
 
 // NewImmutableMessage creates a new immutable message.
@@ -82,10 +98,10 @@ func newMutableMessageBuilder[H proto.Message, B proto.Message](v Version) *muta
 
 // mutableMesasgeBuilder is the builder for message.
 type mutableMesasgeBuilder[H proto.Message, B proto.Message] struct {
-	header     H
-	body       B
-	properties propertiesImpl
-	broadcast  bool
+	header      H
+	body        B
+	properties  propertiesImpl
+	allVChannel bool
 }
 
 // WithMessageHeader creates a new builder with determined message type.
@@ -102,16 +118,41 @@ func (b *mutableMesasgeBuilder[H, B]) WithBody(body B) *mutableMesasgeBuilder[H,
 
 // WithVChannel creates a new builder with virtual channel.
 func (b *mutableMesasgeBuilder[H, B]) WithVChannel(vchannel string) *mutableMesasgeBuilder[H, B] {
-	if b.broadcast {
-		panic("a broadcast message cannot hold vchannel")
+	if b.allVChannel {
+		panic("an all-vchannel message cannot set up the vchannel property")
 	}
 	b.WithProperty(messageVChannel, vchannel)
 	return b
 }
 
 // WithBroadcast creates a new builder with broadcast property.
-func (b *mutableMesasgeBuilder[H, B]) WithBroadcast() *mutableMesasgeBuilder[H, B] {
-	b.broadcast = true
+func (b *mutableMesasgeBuilder[H, B]) WithBroadcast(vchannels []string) *mutableMesasgeBuilder[H, B] {
+	if len(vchannels) < 1 {
+		panic("broadcast message must have at least one vchannel")
+	}
+	if b.allVChannel {
+		panic("an all-vchannel message cannot set up the vchannel property")
+	}
+	if b.properties.Exist(messageVChannel) {
+		panic("a broadcast message cannot set up the vchannel property")
+	}
+	deduplicated := typeutil.NewSet(vchannels...)
+	vcs, err := EncodeProto(&messagespb.VChannels{
+		Vchannels: deduplicated.Collect(),
+	})
+	if err != nil {
+		panic("failed to encode vchannels")
+	}
+	b.properties.Set(messageVChannels, vcs)
+	return b
+}
+
+// WithAllVChannel creates a new builder with the all-vchannel property.
+func (b *mutableMesasgeBuilder[H, B]) WithAllVChannel() *mutableMesasgeBuilder[H, B] {
+	if b.properties.Exist(messageVChannel) || b.properties.Exist(messageVChannels) {
+		panic("a vchannel or broadcast message cannot set up the all-vchannel property")
+	}
+	b.allVChannel = true
 	return b
 }
 
@@ -135,6 +176,34 @@ func (b *mutableMesasgeBuilder[H, B]) WithProperties(kvs map[string]string) *mut
 // Panic if not set payload and message type.
 // should only used at client side.
 func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) {
+	if !b.allVChannel && !b.properties.Exist(messageVChannel) {
+		panic("a non-broadcast message builder is not ready for the vchannel field")
+	}
+
+	msg, err := b.build()
+	if err != nil {
+		return nil, err
+	}
+	return msg, nil
+}
+
+// BuildBroadcast builds a broadcast mutable message.
+// Panic if not set payload and message type.
+// should only be used at client side.
+func (b *mutableMesasgeBuilder[H, B]) BuildBroadcast() (BroadcastMutableMessage, error) {
+	if !b.properties.Exist(messageVChannels) {
+		panic("a broadcast message builder is not ready for the vchannel field")
+	}
+
+	msg, err := b.build()
+	if err != nil {
+		return nil, err
+	}
+	return msg, nil
+}
+
+// build builds a message.
+func (b *mutableMesasgeBuilder[H, B]) build() (*messageImpl, error) {
 	// payload and header must be a pointer
 	if reflect.ValueOf(b.header).IsNil() {
 		panic("message builder not ready for header field")
@@ -142,9 +211,6 @@ func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) {
 	if reflect.ValueOf(b.body).IsNil() {
 		panic("message builder not ready for body field")
 	}
-	if !b.broadcast && !b.properties.Exist(messageVChannel) {
-		panic("a non broadcast message builder not ready for vchannel field")
-	}
 
 	// setup header.
 	sp, err := EncodeProto(b.header)
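Taken together, the builder rules enforce exactly one routing mode per message: WithVChannel, WithBroadcast, or WithAllVChannel, with BuildMutable/BuildBroadcast checking that the matching property is present. A fragment-level sketch of the broadcast path; the typed builder, header, and body are stand-ins, since they depend on the concrete message type being built:

    // someBuilder stands in for a concrete typed message builder; only the
    // broadcast-specific calls are the point of this sketch.
    msg, err := someBuilder.
        WithBroadcast([]string{"vchannel-1", "vchannel-2"}). // deduplicated and encoded into "_vcs"
        BuildBroadcast()
    if err != nil {
        return err
    }
    _ = msg.BroadcastVChannels() // the deduplicated set of target vchannels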
diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go
index 733ed568d8450..49a7361c82fcd 100644
--- a/pkg/streaming/util/message/message.go
+++ b/pkg/streaming/util/message/message.go
@@ -29,11 +29,6 @@ type BasicMessage interface {
 	// Should be used with read-only promise.
 	Properties() RProperties
 
-	// VChannel returns the virtual channel of current message.
-	// Available only when the message's version greater than 0.
-	// Return "" if message is broadcasted.
-	VChannel() string
-
 	// TimeTick returns the time tick of current message.
 	// Available only when the message's version greater than 0.
 	// Otherwise, it will panic.
@@ -52,6 +47,11 @@ type MutableMessage interface {
 	BasicMessage
 
+	// VChannel returns the virtual channel of current message.
+	// Available only when the message's version is greater than 0.
+	// Return "" if the message can be seen by all vchannels on the pchannel.
+	VChannel() string
+
 	// WithBarrierTimeTick sets the barrier time tick of current message.
 	// these time tick is used to promised the message will be sent after that time tick.
 	// and the message which timetick is less than it will never concurrent append with it.
@@ -82,6 +82,19 @@ type MutableMessage interface {
 	IntoImmutableMessage(msgID MessageID) ImmutableMessage
 }
 
+// BroadcastMutableMessage is the broadcast message interface.
+// It indicates that the message is broadcast across multiple vchannels.
+type BroadcastMutableMessage interface {
+	BasicMessage
+
+	// BroadcastVChannels returns the target vchannels of the message broadcast.
+	// Those vchannels can be on multiple pchannels.
+	BroadcastVChannels() []string
+
+	// SplitIntoMutableMessage splits the broadcast message into multiple mutable messages.
+	SplitIntoMutableMessage() []MutableMessage
+}
+
 // ImmutableMessage is the read-only message interface.
 // Once a message is persistent by wal or temporary generated by wal, it will be immutable.
 type ImmutableMessage interface {
@@ -90,6 +103,11 @@ type ImmutableMessage interface {
 	// WALName returns the name of message related wal.
 	WALName() string
 
+	// VChannel returns the virtual channel of current message.
+	// Available only when the message's version is greater than 0.
+	// Return "" if the message can be seen by all vchannels on the pchannel.
+	VChannel() string
+
 	// MessageID returns the message id of current message.
 	MessageID() MessageID
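The interface split implies the server-side flow: a BroadcastMutableMessage is fanned out with SplitIntoMutableMessage, and each per-vchannel MutableMessage is then appended like any ordinary message. A sketch, with the append function left as a placeholder for the wal append path:

    // broadcastToWALs appends one split message per target vchannel.
    func broadcastToWALs(ctx context.Context, bmsg message.BroadcastMutableMessage,
        appendFn func(context.Context, message.MutableMessage) error, // placeholder
    ) error {
        for _, m := range bmsg.SplitIntoMutableMessage() {
            // Each split message carries exactly one vchannel and no longer has
            // the broadcast property, so m.VChannel() is safe to call on it.
            if err := appendFn(ctx, m); err != nil {
                return err
            }
        }
        return nil
    }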
diff --git a/pkg/streaming/util/message/message_impl.go b/pkg/streaming/util/message/message_impl.go
index 41e9ac0379af2..7e4a4c0be2498 100644
--- a/pkg/streaming/util/message/message_impl.go
+++ b/pkg/streaming/util/message/message_impl.go
@@ -141,8 +141,11 @@ func (m *messageImpl) BarrierTimeTick() uint64 {
 }
 
 // VChannel returns the vchannel of current message.
-// If the message is broadcasted, the vchannel will be empty.
+// If the message is an all-vchannel message, it will return "".
+// If the message is a broadcast message, it will panic.
 func (m *messageImpl) VChannel() string {
+	m.assertNotBroadcast()
+
 	value, ok := m.properties.Get(messageVChannel)
 	if !ok {
 		return ""
@@ -150,6 +153,60 @@ func (m *messageImpl) VChannel() string {
 	return value
 }
 
+// BroadcastVChannels returns the vchannels that the current message broadcasts to.
+// If the message is not a broadcast message, it will panic.
+func (m *messageImpl) BroadcastVChannels() []string {
+	m.assertBroadcast()
+
+	value, _ := m.properties.Get(messageVChannels)
+	vcs := &messagespb.VChannels{}
+	if err := DecodeProto(value, vcs); err != nil {
+		panic("cannot decode vchannels")
+	}
+	return vcs.Vchannels
+}
+
+// SplitIntoMutableMessage splits the current broadcast message into multiple messages.
+func (m *messageImpl) SplitIntoMutableMessage() []MutableMessage {
+	vchannels := m.BroadcastVChannels()
+
+	vchannelExist := make(map[string]struct{}, len(vchannels))
+	msgs := make([]MutableMessage, 0, len(vchannels))
+	for _, vchannel := range vchannels {
+		newPayload := make([]byte, len(m.payload))
+		copy(newPayload, m.payload)
+
+		newProperties := make(propertiesImpl, len(m.properties))
+		for key, val := range m.properties {
+			if key != messageVChannels {
+				newProperties.Set(key, val)
+			}
+		}
+		newProperties.Set(messageVChannel, vchannel)
+		if _, ok := vchannelExist[vchannel]; ok {
+			panic("there's a bug in the message code, duplicate vchannel in broadcast message")
+		}
+		msgs = append(msgs, &messageImpl{
+			payload:    newPayload,
+			properties: newProperties,
+		})
+		vchannelExist[vchannel] = struct{}{}
+	}
+	return msgs
+}
+
+func (m *messageImpl) assertNotBroadcast() {
+	if m.properties.Exist(messageVChannels) {
+		panic("current message is a broadcast message")
+	}
+}
+
+func (m *messageImpl) assertBroadcast() {
+	if !m.properties.Exist(messageVChannels) {
+		panic("current message is not a broadcast message")
+	}
+}
+
 type immutableMessageImpl struct {
 	messageImpl
 	id MessageID
diff --git a/pkg/streaming/util/message/properties.go b/pkg/streaming/util/message/properties.go
index 575c7d2146b80..3f0d120e32fd4 100644
--- a/pkg/streaming/util/message/properties.go
+++ b/pkg/streaming/util/message/properties.go
@@ -10,6 +10,7 @@ const (
 	messageLastConfirmed                    = "_lc"  // message last confirmed message id.
 	messageLastConfirmedIDSameWithMessageID = "_lcs" // message last confirmed message id is the same with message id.
 	messageVChannel                         = "_vc"  // message virtual channel.
+	messageVChannels                        = "_vcs" // message virtual channels for broadcast message.
 	messageHeader                           = "_h"   // specialized message header.
 	messageTxnContext                       = "_tx"  // transaction context.
 )
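The "_vc" and "_vcs" keys are mutually exclusive, which is exactly what assertBroadcast/assertNotBroadcast check; a message carrying neither key is visible to all vchannels. A self-contained sketch of the resulting classification over a raw property map:

    // classify mirrors the property layout above: "_vc" marks a single-vchannel
    // message, "_vcs" a broadcast message, and neither means all vchannels.
    func classify(props map[string]string) string {
        _, vc := props["_vc"]
        _, vcs := props["_vcs"]
        switch {
        case vcs && !vc:
            return "broadcast" // VChannel() would panic, BroadcastVChannels() works
        case vc:
            return "single vchannel"
        default:
            return "all vchannels" // VChannel() returns ""
        }
    }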
diff --git a/pkg/streaming/util/types/streaming_node.go b/pkg/streaming/util/types/streaming_node.go
index 4c6a13e699d17..0cca5798e19d1 100644
--- a/pkg/streaming/util/types/streaming_node.go
+++ b/pkg/streaming/util/types/streaming_node.go
@@ -7,6 +7,7 @@ import (
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/types/known/anypb"
 
+	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
 	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
@@ -88,6 +89,16 @@ func (n *StreamingNodeStatus) ErrorOfNode() error {
 	return n.Err
 }
 
+// BroadcastAppendResult is the result of the broadcast append operation.
+type BroadcastAppendResult struct {
+	AppendResults map[string]*AppendResult // maps the channel name to the append result.
+}
+
+// GetAppendResult returns the append result of the given channel.
+func (r *BroadcastAppendResult) GetAppendResult(channelName string) *AppendResult {
+	return r.AppendResults[channelName]
+}
+
 // AppendResult is the result of append operation.
 type AppendResult struct {
 	// MessageID is generated by underlying walimpls.
@@ -112,3 +123,15 @@ func (r *AppendResult) GetExtra(m proto.Message) error {
 		AllowPartial: true,
 	})
 }
+
+// IntoProto converts the append result to proto.
+func (r *AppendResult) IntoProto() *streamingpb.ProduceMessageResponseResult {
+	return &streamingpb.ProduceMessageResponseResult{
+		Id: &messagespb.MessageID{
+			Id: r.MessageID.Marshal(),
+		},
+		Timetick:   r.TimeTick,
+		TxnContext: r.TxnCtx.IntoProto(),
+		Extra:      r.Extra,
+	}
+}
diff --git a/pkg/util/contextutil/context_util.go b/pkg/util/contextutil/context_util.go
index 8cf699b43079b..2bded437d1ec5 100644
--- a/pkg/util/contextutil/context_util.go
+++ b/pkg/util/contextutil/context_util.go
@@ -121,3 +121,15 @@ func WithDeadlineCause(parent context.Context, deadline time.Time, err error) (c
 		cancel(context.Canceled)
 	}
 }
+
+// MergeContext creates a cancellation context that cancels when any of the given contexts is canceled.
+func MergeContext(ctx1 context.Context, ctx2 context.Context) (context.Context, context.CancelFunc) {
+	ctx, cancel := context.WithCancelCause(ctx1)
+	stop := context.AfterFunc(ctx2, func() {
+		cancel(context.Cause(ctx2))
+	})
+	return ctx, func() {
+		stop()
+		cancel(context.Canceled)
+	}
+}
diff --git a/pkg/util/retry/options.go b/pkg/util/retry/options.go
index 80f00a9ffc8f9..852e4ec7d786e 100644
--- a/pkg/util/retry/options.go
+++ b/pkg/util/retry/options.go
@@ -31,6 +31,13 @@ func newDefaultConfig() *config {
 // Option is used to config the retry function.
 type Option func(*config)
 
+// AttemptAlways is used to config the retry to run forever until success.
+func AttemptAlways() Option {
+	return func(c *config) {
+		c.attempts = 0
+	}
+}
+
 // Attempts is used to config the max retry times.
 func Attempts(attempts uint) Option {
 	return func(c *config) {
diff --git a/pkg/util/retry/retry.go b/pkg/util/retry/retry.go
index c623bb1dbeb4d..a2a722ec13571 100644
--- a/pkg/util/retry/retry.go
+++ b/pkg/util/retry/retry.go
@@ -40,7 +40,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
 
 	var lastErr error
 
-	for i := uint(0); i < c.attempts; i++ {
+	for i := uint(0); c.attempts == 0 || i < c.attempts; i++ {
 		if err := fn(); err != nil {
 			if i%4 == 0 {
 				log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err))
diff --git a/pkg/util/retry/retry_test.go b/pkg/util/retry/retry_test.go
index d0936a70dba85..e4c86d0b7521d 100644
--- a/pkg/util/retry/retry_test.go
+++ b/pkg/util/retry/retry_test.go
@@ -50,6 +50,17 @@ func TestAttempts(t *testing.T) {
 	err := Do(ctx, testFn, Attempts(1))
 	assert.Error(t, err)
 	t.Log(err)
+
+	ctx = context.Background()
+	testOperation := 0
+	testFn = func() error {
+		testOperation++
+		return nil
+	}
+
+	err = Do(ctx, testFn, AttemptAlways())
+	assert.Equal(t, testOperation, 1)
+	assert.NoError(t, err)
 }
 
 func TestMaxSleepTime(t *testing.T) {
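MergeContext and AttemptAlways above are natural companions for a broadcaster task: retry with no attempt limit, but stop as soon as either the server's lifetime context or the task's own context is canceled. A sketch under those assumptions (retry.Do already aborts between attempts once its context is done):

    // runUntilDone retries task forever, bounded only by the two parent contexts.
    func runUntilDone(serverCtx, taskCtx context.Context, task func() error) error {
        ctx, cancel := contextutil.MergeContext(serverCtx, taskCtx)
        defer cancel() // also releases the AfterFunc registration on taskCtx
        return retry.Do(ctx, task, retry.AttemptAlways())
    }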
diff --git a/pkg/util/typeutil/backoff_timer.go b/pkg/util/typeutil/backoff_timer.go
index dd26b136fee8d..997ccb2839211 100644
--- a/pkg/util/typeutil/backoff_timer.go
+++ b/pkg/util/typeutil/backoff_timer.go
@@ -94,3 +94,49 @@ func (t *BackoffTimer) NextInterval() time.Duration {
 	}
 	return t.configFetcher.DefaultInterval()
 }
+
+// NewBackoffWithInstant creates a new backoff with an instant.
+func NewBackoffWithInstant(fetcher BackoffTimerConfigFetcher) *BackoffWithInstant {
+	cfg := fetcher.BackoffConfig()
+	defaultInterval := fetcher.DefaultInterval()
+	backoff := backoff.NewExponentialBackOff()
+	backoff.InitialInterval = cfg.InitialInterval
+	backoff.Multiplier = cfg.Multiplier
+	backoff.MaxInterval = cfg.MaxInterval
+	backoff.MaxElapsedTime = defaultInterval
+	backoff.Stop = defaultInterval
+	backoff.Reset()
+	return &BackoffWithInstant{
+		backoff:     backoff,
+		nextInstant: time.Now(),
+	}
+}
+
+// BackoffWithInstant is a backoff that tracks the next firing instant.
+// An instant can be recorded with `UpdateInstantWithNextBackOff`.
+// NextInstant can be used to make priority decisions.
+type BackoffWithInstant struct {
+	backoff     *backoff.ExponentialBackOff
+	nextInstant time.Time
+}
+
+// NextInstant returns the next instant.
+func (t *BackoffWithInstant) NextInstant() time.Time {
+	return t.nextInstant
+}
+
+// NextInterval returns the interval until the next instant.
+func (t *BackoffWithInstant) NextInterval() time.Duration {
+	return time.Until(t.nextInstant)
+}
+
+// NextTimer returns a timer that fires at the next instant, along with the duration until then.
+func (t *BackoffWithInstant) NextTimer() (<-chan time.Time, time.Duration) {
+	next := time.Until(t.nextInstant)
+	return time.After(next), next
+}
+
+// UpdateInstantWithNextBackOff updates the next instant with the next backoff interval.
+func (t *BackoffWithInstant) UpdateInstantWithNextBackOff() {
+	t.nextInstant = time.Now().Add(t.backoff.NextBackOff())
+}
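A sketch of the scheduling loop BackoffWithInstant enables: after each failure the next attempt instant is pushed forward exponentially, and NextTimer sleeps only as long as needed, while NextInstant lets several pending backoffs be compared for prioritization (the work callback is a placeholder):

    // pollWithBackoff retries work, sleeping until the recorded next instant.
    func pollWithBackoff(ctx context.Context, b *typeutil.BackoffWithInstant, work func() error) error {
        for {
            if err := work(); err == nil {
                return nil
            }
            b.UpdateInstantWithNextBackOff() // record when the next attempt may run
            timer, _ := b.NextTimer()        // fires at that instant
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-timer:
            }
        }
    }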