From a80189f8af7193273bbc270d476ed3b11a5b39ff Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Fri, 24 Jan 2025 11:39:02 +0800 Subject: [PATCH 1/6] enhance: reduce stats task cost by skipping ser/de Signed-off-by: Ted Xu --- internal/datanode/compaction/merge_sort.go | 92 ++-- .../compaction/segment_reader_from_binlogs.go | 86 --- .../compaction/segment_record_reader.go | 31 ++ internal/indexnode/task_stats.go | 510 +++++++++--------- internal/indexnode/task_stats_test.go | 10 +- internal/storage/binlog_iterator_test.go | 6 +- internal/storage/serde.go | 52 +- internal/storage/serde_events.go | 57 +- internal/storage/serde_test.go | 19 + internal/storage/sort.go | 220 ++++++++ internal/storage/sort_test.go | 127 +++++ 11 files changed, 777 insertions(+), 433 deletions(-) delete mode 100644 internal/datanode/compaction/segment_reader_from_binlogs.go create mode 100644 internal/datanode/compaction/segment_record_reader.go create mode 100644 internal/storage/sort.go create mode 100644 internal/storage/sort_test.go diff --git a/internal/datanode/compaction/merge_sort.go b/internal/datanode/compaction/merge_sort.go index 0758d660ec527..48139b74cc231 100644 --- a/internal/datanode/compaction/merge_sort.go +++ b/internal/datanode/compaction/merge_sort.go @@ -1,20 +1,21 @@ package compaction import ( - "container/heap" "context" "fmt" - sio "io" "math" "time" + "github.com/apache/arrow/go/v12/arrow/array" "github.com/samber/lo" "go.opentelemetry.io/otel" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/flushcommon/io" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/proto/datapb" @@ -22,6 +23,24 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +type segmentWriterWrapper struct { + *MultiSegmentWriter +} + +var _ storage.RecordWriter = (*segmentWriterWrapper)(nil) + +func (w *segmentWriterWrapper) GetWrittenUncompressed() uint64 { + return 0 +} + +func (w *segmentWriterWrapper) Write(record storage.Record) error { + return w.MultiSegmentWriter.WriteRecord(record) +} + +func (w *segmentWriterWrapper) Close() error { + return nil +} + func mergeSortMultipleSegments(ctx context.Context, plan *datapb.CompactionPlan, collectionID, partitionID, maxRows int64, @@ -43,6 +62,7 @@ func mergeSortMultipleSegments(ctx context.Context, logIDAlloc := allocator.NewLocalAllocator(plan.GetBeginLogID(), math.MaxInt64) compAlloc := NewCompactionAllocator(segIDAlloc, logIDAlloc) mWriter := NewMultiSegmentWriter(binlogIO, compAlloc, plan, maxRows, partitionID, collectionID, bm25FieldIds) + writer := &segmentWriterWrapper{MultiSegmentWriter: mWriter} pkField, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema()) if err != nil { @@ -50,8 +70,7 @@ func mergeSortMultipleSegments(ctx context.Context, return nil, err } - // SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID()) - segmentReaders := make([]*SegmentDeserializeReader, len(binlogs)) + segmentReaders := make([]storage.RecordReader, len(binlogs)) segmentFilters := make([]*EntityFilter, len(binlogs)) for i, s := range binlogs { var binlogBatchCount int @@ -75,7 +94,7 @@ func mergeSortMultipleSegments(ctx context.Context, } binlogPaths[idx] = batchPaths } - segmentReaders[i] = NewSegmentDeserializeReader(ctx, binlogPaths, binlogIO, pkField.GetFieldID(), bm25FieldIds) + 
segmentReaders[i] = NewSegmentRecordReader(ctx, binlogPaths, binlogIO) deltalogPaths := make([]string, 0) for _, d := range s.GetDeltalogs() { for _, l := range d.GetBinlogs() { @@ -89,57 +108,26 @@ func mergeSortMultipleSegments(ctx context.Context, segmentFilters[i] = newEntityFilter(delta, collectionTtl, currentTime) } - advanceRow := func(i int) (*storage.Value, error) { - for { - v, err := segmentReaders[i].Next() - if err != nil { - return nil, err - } - - if segmentFilters[i].Filtered(v.PK.GetValue(), uint64(v.Timestamp)) { - continue - } - - return v, nil + var predicate func(r storage.Record, ri, i int) bool + switch pkField.DataType { + case schemapb.DataType_Int64: + predicate = func(r storage.Record, ri, i int) bool { + pk := r.Column(pkField.FieldID).(*array.Int64).Value(i) + ts := r.Column(common.TimeStampField).(*array.Int64).Value(i) + return segmentFilters[ri].Filtered(pk, uint64(ts)) } - } - - pq := make(PriorityQueue, 0) - heap.Init(&pq) - - for i := range segmentReaders { - v, err := advanceRow(i) - if err != nil { - log.Warn("compact wrong, failed to advance row", zap.Error(err)) - return nil, err + case schemapb.DataType_VarChar: + predicate = func(r storage.Record, ri, i int) bool { + pk := r.Column(pkField.FieldID).(*array.String).Value(i) + ts := r.Column(common.TimeStampField).(*array.Int64).Value(i) + return segmentFilters[ri].Filtered(pk, uint64(ts)) } - heap.Push(&pq, &PQItem{ - Value: v, - Index: i, - }) + default: + log.Warn("compaction only support int64 and varchar pk field") } - for pq.Len() > 0 { - smallest := heap.Pop(&pq).(*PQItem) - v := smallest.Value - - err := mWriter.Write(v) - if err != nil { - log.Warn("compact wrong, failed to writer row", zap.Error(err)) - return nil, err - } - - iv, err := advanceRow(smallest.Index) - if err != nil && err != sio.EOF { - return nil, err - } - if err == nil { - next := &PQItem{ - Value: iv, - Index: smallest.Index, - } - heap.Push(&pq, next) - } + if _, err = storage.MergeSort(segmentReaders, pkField.FieldID, writer, predicate); err != nil { + return nil, err } res, err := mWriter.Finish() diff --git a/internal/datanode/compaction/segment_reader_from_binlogs.go b/internal/datanode/compaction/segment_reader_from_binlogs.go deleted file mode 100644 index 5d5fc3636be0a..0000000000000 --- a/internal/datanode/compaction/segment_reader_from_binlogs.go +++ /dev/null @@ -1,86 +0,0 @@ -package compaction - -import ( - "context" - "io" - - "github.com/samber/lo" - "go.uber.org/zap" - - binlogIO "github.com/milvus-io/milvus/internal/flushcommon/io" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/log" -) - -type SegmentDeserializeReader struct { - ctx context.Context - binlogIO binlogIO.BinlogIO - reader *storage.DeserializeReader[*storage.Value] - - pos int - PKFieldID int64 - binlogPaths [][]string - binlogPathPos int - - bm25FieldIDs []int64 -} - -func NewSegmentDeserializeReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO, PKFieldID int64, bm25FieldIDs []int64) *SegmentDeserializeReader { - return &SegmentDeserializeReader{ - ctx: ctx, - binlogIO: binlogIO, - reader: nil, - pos: 0, - PKFieldID: PKFieldID, - binlogPaths: binlogPaths, - binlogPathPos: 0, - bm25FieldIDs: bm25FieldIDs, - } -} - -func (r *SegmentDeserializeReader) initDeserializeReader() error { - if r.binlogPathPos >= len(r.binlogPaths) { - return io.EOF - } - allValues, err := r.binlogIO.Download(r.ctx, r.binlogPaths[r.binlogPathPos]) - if err != nil { - log.Warn("compact wrong, fail to 
download insertLogs", zap.Error(err)) - return err - } - - blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob { - return &storage.Blob{Key: r.binlogPaths[r.binlogPathPos][i], Value: v} - }) - - r.reader, err = storage.NewBinlogDeserializeReader(blobs, r.PKFieldID) - if err != nil { - log.Warn("compact wrong, failed to new insert binlogs reader", zap.Error(err)) - return err - } - r.binlogPathPos++ - return nil -} - -func (r *SegmentDeserializeReader) Next() (*storage.Value, error) { - if r.reader == nil { - if err := r.initDeserializeReader(); err != nil { - return nil, err - } - } - if err := r.reader.Next(); err != nil { - if err == io.EOF { - r.reader.Close() - if err := r.initDeserializeReader(); err != nil { - return nil, err - } - err = r.reader.Next() - return r.reader.Value(), err - } - return nil, err - } - return r.reader.Value(), nil -} - -func (r *SegmentDeserializeReader) Close() { - r.reader.Close() -} diff --git a/internal/datanode/compaction/segment_record_reader.go b/internal/datanode/compaction/segment_record_reader.go new file mode 100644 index 0000000000000..3e0897d0ceac3 --- /dev/null +++ b/internal/datanode/compaction/segment_record_reader.go @@ -0,0 +1,31 @@ +package compaction + +import ( + "context" + "io" + + "github.com/samber/lo" + + binlogIO "github.com/milvus-io/milvus/internal/flushcommon/io" + "github.com/milvus-io/milvus/internal/storage" +) + +func NewSegmentRecordReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO) storage.RecordReader { + pos := 0 + return &storage.CompositeBinlogRecordReader{ + BlobsReader: func() ([]*storage.Blob, error) { + if pos >= len(binlogPaths) { + return nil, io.EOF + } + bytesArr, err := binlogIO.Download(ctx, binlogPaths[pos]) + if err != nil { + return nil, err + } + pos++ + blobs := lo.Map(bytesArr, func(v []byte, i int) *storage.Blob { + return &storage.Blob{Key: binlogPaths[pos][i], Value: v} + }) + return blobs, nil + }, + } +} diff --git a/internal/indexnode/task_stats.go b/internal/indexnode/task_stats.go index 68c5409b32f57..65d0c63f29a2b 100644 --- a/internal/indexnode/task_stats.go +++ b/internal/indexnode/task_stats.go @@ -19,22 +19,23 @@ package indexnode import ( "context" "fmt" - sio "io" - "sort" "strconv" "time" + "github.com/apache/arrow/go/v12/arrow/array" "github.com/samber/lo" "go.opentelemetry.io/otel" "go.uber.org/zap" "google.golang.org/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/compaction" iter "github.com/milvus-io/milvus/internal/datanode/iterators" "github.com/milvus-io/milvus/internal/flushcommon/io" "github.com/milvus-io/milvus/internal/metastore/kv/binlog" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/indexcgowrapper" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/proto/datapb" "github.com/milvus-io/milvus/pkg/proto/indexcgopb" @@ -155,139 +156,314 @@ func (st *statsTask) PreExecute(ctx context.Context) error { return nil } -func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, error) { - numRows := st.req.GetNumRows() +// segmentRecordWriter is a wrapper of SegmentWriter to implement RecordWriter interface +type segmentRecordWriter struct { + sw *compaction.SegmentWriter + binlogMaxSize uint64 + rootPath string + logID int64 + maxLogID int64 + binlogIO io.BinlogIO + ctx context.Context + numRows int64 + bm25FieldIds []int64 + + lastUploads 
[]*conc.Future[any] + binlogs map[typeutil.UniqueID]*datapb.FieldBinlog + statslog *datapb.FieldBinlog + bm25statslog []*datapb.FieldBinlog +} - bm25FieldIds := compaction.GetBM25FieldIDs(st.req.GetSchema()) - writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, statsBatchSize, st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID(), bm25FieldIds) +var _ storage.RecordWriter = (*segmentRecordWriter)(nil) + +func (srw *segmentRecordWriter) Close() error { + if !srw.sw.FlushAndIsEmpty() { + if err := srw.upload(); err != nil { + return err + } + if err := srw.waitLastUpload(); err != nil { + return err + } + } + + statslog, err := srw.statSerializeWrite() if err != nil { - log.Ctx(ctx).Warn("sort segment wrong, unable to init segment writer", - zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err + log.Ctx(srw.ctx).Warn("stats wrong, failed to serialize write segment stats", + zap.Int64("remaining row count", srw.numRows), zap.Error(err)) + return err } + srw.statslog = statslog + srw.logID++ - var ( - flushBatchCount int // binlog batch count + if len(srw.bm25FieldIds) > 0 { + binlogNums, bm25StatsLogs, err := srw.bm25SerializeWrite() + if err != nil { + log.Ctx(srw.ctx).Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err)) + return err + } + srw.logID += binlogNums + srw.bm25statslog = bm25StatsLogs + } - allBinlogs = make(map[typeutil.UniqueID]*datapb.FieldBinlog) // All binlog meta of a segment - uploadFutures = make([]*conc.Future[any], 0) + return nil +} - downloadCost time.Duration - serWriteTimeCost time.Duration - sortTimeCost time.Duration - ) +func (srw *segmentRecordWriter) GetWrittenUncompressed() uint64 { + return srw.sw.WrittenMemorySize() +} - downloadStart := time.Now() - values, err := st.downloadData(ctx, numRows, writer.GetPkID(), bm25FieldIds) +func (srw *segmentRecordWriter) Write(r storage.Record) error { + err := srw.sw.WriteRecord(r) if err != nil { - log.Ctx(ctx).Warn("download data failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err + return err } - downloadCost = time.Since(downloadStart) - sortStart := time.Now() - sort.Slice(values, func(i, j int) bool { - return values[i].PK.LT(values[j].PK) - }) - sortTimeCost += time.Since(sortStart) + if srw.sw.IsFullWithBinlogMaxSize(srw.binlogMaxSize) { + return srw.upload() + } + return nil +} - for i, v := range values { - err := writer.Write(v) - if err != nil { - log.Ctx(ctx).Warn("write value wrong, failed to writer row", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err - } +func (srw *segmentRecordWriter) upload() error { + if err := srw.waitLastUpload(); err != nil { + return err + } + binlogNum, kvs, partialBinlogs, err := serializeWrite(srw.ctx, srw.rootPath, srw.logID, srw.sw) + if err != nil { + return err + } - if (i+1)%statsBatchSize == 0 && writer.IsFullWithBinlogMaxSize(st.req.GetBinlogMaxSize()) { - serWriteStart := time.Now() - binlogNum, kvs, partialBinlogs, err := serializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.req.GetStartLogID()+st.logIDOffset, writer) - if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to serialize writer", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err - } - serWriteTimeCost += time.Since(serWriteStart) - - uploadFutures = append(uploadFutures, st.binlogIO.AsyncUpload(ctx, kvs)...) 
- mergeFieldBinlogs(allBinlogs, partialBinlogs) - - flushBatchCount++ - st.logIDOffset += binlogNum - if st.req.GetStartLogID()+st.logIDOffset >= st.req.GetEndLogID() { - log.Ctx(ctx).Warn("binlog files too much, log is not enough", zap.Int64("taskID", st.req.GetTaskID()), - zap.Int64("binlog num", binlogNum), zap.Int64("startLogID", st.req.GetStartLogID()), - zap.Int64("endLogID", st.req.GetEndLogID()), zap.Int64("logIDOffset", st.logIDOffset)) - return nil, fmt.Errorf("binlog files too much, log is not enough") + srw.lastUploads = srw.binlogIO.AsyncUpload(srw.ctx, kvs) + if srw.binlogs == nil { + srw.binlogs = make(map[typeutil.UniqueID]*datapb.FieldBinlog) + } + mergeFieldBinlogs(srw.binlogs, partialBinlogs) + + srw.logID += binlogNum + if srw.logID > srw.maxLogID { + return fmt.Errorf("log id exausted") + } + return nil +} + +func (srw *segmentRecordWriter) waitLastUpload() error { + if len(srw.lastUploads) > 0 { + for _, future := range srw.lastUploads { + if _, err := future.Await(); err != nil { + return err } } } + return nil +} - if !writer.FlushAndIsEmpty() { - serWriteStart := time.Now() - binlogNum, kvs, partialBinlogs, err := serializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.req.GetStartLogID()+st.logIDOffset, writer) - if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to serialize writer", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err +func (srw *segmentRecordWriter) statSerializeWrite() (*datapb.FieldBinlog, error) { + ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(srw.ctx, "statslog serializeWrite") + defer span.End() + sblob, err := srw.sw.Finish() + if err != nil { + return nil, err + } + + key, _ := binlog.BuildLogPathWithRootPath(srw.rootPath, storage.StatsBinlog, + srw.sw.GetCollectionID(), srw.sw.GetPartitionID(), srw.sw.GetSegmentID(), srw.sw.GetPkID(), srw.logID) + kvs := map[string][]byte{key: sblob.GetValue()} + statFieldLog := &datapb.FieldBinlog{ + FieldID: srw.sw.GetPkID(), + Binlogs: []*datapb.Binlog{ + { + LogSize: int64(len(sblob.GetValue())), + MemorySize: int64(len(sblob.GetValue())), + LogPath: key, + EntriesNum: srw.numRows, + }, + }, + } + if err := srw.binlogIO.Upload(ctx, kvs); err != nil { + log.Ctx(ctx).Warn("failed to upload insert log", zap.Error(err)) + return nil, err + } + + return statFieldLog, nil +} + +func (srw *segmentRecordWriter) bm25SerializeWrite() (int64, []*datapb.FieldBinlog, error) { + ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(srw.ctx, "bm25log serializeWrite") + defer span.End() + writer := srw.sw + stats, err := writer.GetBm25StatsBlob() + if err != nil { + return 0, nil, err + } + + kvs := make(map[string][]byte) + binlogs := []*datapb.FieldBinlog{} + cnt := int64(0) + for fieldID, blob := range stats { + key, _ := binlog.BuildLogPathWithRootPath(srw.rootPath, storage.BM25Binlog, + writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, srw.logID) + kvs[key] = blob.GetValue() + fieldLog := &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []*datapb.Binlog{ + { + LogSize: int64(len(blob.GetValue())), + MemorySize: int64(len(blob.GetValue())), + LogPath: key, + EntriesNum: srw.numRows, + }, + }, } - serWriteTimeCost += time.Since(serWriteStart) - st.logIDOffset += binlogNum - uploadFutures = append(uploadFutures, st.binlogIO.AsyncUpload(ctx, kvs)...) - mergeFieldBinlogs(allBinlogs, partialBinlogs) - flushBatchCount++ + binlogs = append(binlogs, fieldLog) + srw.logID++ + cnt++ } - err = conc.AwaitAll(uploadFutures...) 
+ if err := srw.binlogIO.Upload(ctx, kvs); err != nil { + log.Ctx(ctx).Warn("failed to upload bm25 log", zap.Error(err)) + return 0, nil, err + } + + return cnt, binlogs, nil +} + +func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) { + numRows := st.req.GetNumRows() + + bm25FieldIds := compaction.GetBM25FieldIDs(st.req.GetSchema()) + pkField, err := typeutil.GetPrimaryFieldSchema(st.req.GetSchema()) + if err != nil { + return nil, err + } + pkFieldID := pkField.FieldID + writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, statsBatchSize, + st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID(), bm25FieldIds) if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to upload kvs", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) + log.Ctx(ctx).Warn("sort segment wrong, unable to init segment writer", + zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) return nil, err } + srw := &segmentRecordWriter{ + sw: writer, + binlogMaxSize: st.req.GetBinlogMaxSize(), + rootPath: st.req.GetStorageConfig().GetRootPath(), + logID: st.req.StartLogID, + maxLogID: st.req.EndLogID, + binlogIO: st.binlogIO, + ctx: ctx, + numRows: st.req.NumRows, + bm25FieldIds: bm25FieldIds, + } - serWriteStart := time.Now() - binlogNums, sPath, err := statSerializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.binlogIO, st.req.GetStartLogID()+st.logIDOffset, writer, numRows) + log := log.Ctx(ctx).With( + zap.String("clusterID", st.req.GetClusterID()), + zap.Int64("taskID", st.req.GetTaskID()), + zap.Int64("collectionID", st.req.GetCollectionID()), + zap.Int64("partitionID", st.req.GetPartitionID()), + zap.Int64("segmentID", st.req.GetSegmentID()), + zap.Int64s("bm25Fields", bm25FieldIds), + ) + + deletePKs, err := st.loadDeltalogs(ctx, st.deltaLogs) if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to serialize write segment stats", zap.Int64("taskID", st.req.GetTaskID()), - zap.Int64("remaining row count", numRows), zap.Error(err)) + log.Warn("load deletePKs failed", zap.Error(err)) return nil, err } - serWriteTimeCost += time.Since(serWriteStart) - st.logIDOffset += binlogNums + var ( + remainingRowCount int64 // the number of remaining entities + expiredRowCount int64 // the number of expired entities + ) + + var isValueValid func(r storage.Record, ri, i int) bool + switch pkField.DataType { + case schemapb.DataType_Int64: + isValueValid = func(r storage.Record, ri, i int) bool { + v := r.Column(pkFieldID).(*array.Int64).Value(i) + ts, ok := deletePKs[v] + if ok && uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) < ts { + return false + } + return !st.isExpiredEntity(ts) + } + case schemapb.DataType_VarChar: + isValueValid = func(r storage.Record, ri, i int) bool { + v := r.Column(pkFieldID).(*array.String).Value(i) + ts, ok := deletePKs[v] + if ok && uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) < ts { + return false + } + return !st.isExpiredEntity(ts) + } + } + + downloadTimeCost := time.Duration(0) + + rrs := make([]storage.RecordReader, len(st.insertLogs)) - var bm25StatsLogs []*datapb.FieldBinlog - if len(bm25FieldIds) > 0 { - binlogNums, bm25StatsLogs, err = bm25SerializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.binlogIO, st.req.GetStartLogID()+st.logIDOffset, writer, numRows) + for i, paths := range st.insertLogs { + log := log.With(zap.Strings("paths", paths)) + downloadStart := time.Now() + allValues, err := st.binlogIO.Download(ctx, paths) if err != nil { - 
log.Ctx(ctx).Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err)) + log.Warn("download wrong, fail to download insertLogs", zap.Error(err)) return nil, err } - st.logIDOffset += binlogNums + downloadTimeCost += time.Since(downloadStart) - if err := binlog.CompressFieldBinlogs(bm25StatsLogs); err != nil { + blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob { + return &storage.Blob{Key: paths[i], Value: v} + }) + + rr, err := storage.NewCompositeBinlogRecordReader(blobs) + if err != nil { + log.Warn("downloadData wrong, failed to new insert binlogs reader", zap.Error(err)) return nil, err } + rrs[i] = rr } - totalElapse := st.tr.RecordSpan() + log.Info("download data success", + zap.Int64("old rows", numRows), + zap.Int64("remainingRowCount", remainingRowCount), + zap.Int64("expiredRowCount", expiredRowCount), + zap.Duration("download binlogs elapse", downloadTimeCost), + ) + + numValidRows, err := storage.Sort(rrs, writer.GetPkID(), srw, isValueValid) + if err != nil { + log.Warn("sort failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) + return nil, err + } + if err := srw.Close(); err != nil { + return nil, err + } - insertLogs := lo.Values(allBinlogs) + insertLogs := lo.Values(srw.binlogs) if err := binlog.CompressFieldBinlogs(insertLogs); err != nil { return nil, err } - statsLogs := []*datapb.FieldBinlog{sPath} + statsLogs := []*datapb.FieldBinlog{srw.statslog} if err := binlog.CompressFieldBinlogs(statsLogs); err != nil { return nil, err } + bm25StatsLogs := srw.bm25statslog + if err := binlog.CompressFieldBinlogs(bm25StatsLogs); err != nil { + return nil, err + } + st.node.storePKSortStatsResult(st.req.GetClusterID(), st.req.GetTaskID(), st.req.GetCollectionID(), st.req.GetPartitionID(), st.req.GetTargetSegmentID(), st.req.GetInsertChannel(), - int64(len(values)), insertLogs, statsLogs, bm25StatsLogs) + int64(numValidRows), insertLogs, statsLogs, bm25StatsLogs) - log.Ctx(ctx).Info("sort segment end", + log.Info("sort segment end", zap.String("clusterID", st.req.GetClusterID()), zap.Int64("taskID", st.req.GetTaskID()), zap.Int64("collectionID", st.req.GetCollectionID()), @@ -296,12 +472,7 @@ func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, er zap.String("subTaskType", st.req.GetSubJobType().String()), zap.Int64("target segmentID", st.req.GetTargetSegmentID()), zap.Int64("old rows", numRows), - zap.Int("valid rows", len(values)), - zap.Int("binlog batch count", flushBatchCount), - zap.Duration("download elapse", downloadCost), - zap.Duration("sort elapse", sortTimeCost), - zap.Duration("serWrite elapse", serWriteTimeCost), - zap.Duration("total elapse", totalElapse)) + zap.Int("valid rows", numValidRows)) return insertLogs, nil } @@ -313,7 +484,7 @@ func (st *statsTask) Execute(ctx context.Context) error { insertLogs := st.req.GetInsertLogs() var err error if st.req.GetSubJobType() == indexpb.StatsSubJob_Sort { - insertLogs, err = st.sortSegment(ctx) + insertLogs, err = st.sort(ctx) if err != nil { return err } @@ -350,99 +521,6 @@ func (st *statsTask) Reset() { st.node = nil } -func (st *statsTask) downloadData(ctx context.Context, numRows int64, PKFieldID int64, bm25FieldIds []int64) ([]*storage.Value, error) { - log := log.Ctx(ctx).With( - zap.String("clusterID", st.req.GetClusterID()), - zap.Int64("taskID", st.req.GetTaskID()), - zap.Int64("collectionID", st.req.GetCollectionID()), - zap.Int64("partitionID", st.req.GetPartitionID()), - zap.Int64("segmentID", st.req.GetSegmentID()), - 
zap.Int64s("bm25Fields", bm25FieldIds), - ) - - deletePKs, err := st.loadDeltalogs(ctx, st.deltaLogs) - if err != nil { - log.Warn("load deletePKs failed", zap.Error(err)) - return nil, err - } - - var ( - remainingRowCount int64 // the number of remaining entities - expiredRowCount int64 // the number of expired entities - ) - - isValueDeleted := func(v *storage.Value) bool { - ts, ok := deletePKs[v.PK.GetValue()] - // insert task and delete task has the same ts when upsert - // here should be < instead of <= - // to avoid the upsert data to be deleted after compact - if ok && uint64(v.Timestamp) < ts { - return true - } - return false - } - - downloadTimeCost := time.Duration(0) - - values := make([]*storage.Value, 0, numRows) - for _, paths := range st.insertLogs { - log := log.With(zap.Strings("paths", paths)) - downloadStart := time.Now() - allValues, err := st.binlogIO.Download(ctx, paths) - if err != nil { - log.Warn("download wrong, fail to download insertLogs", zap.Error(err)) - return nil, err - } - downloadTimeCost += time.Since(downloadStart) - - blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob { - return &storage.Blob{Key: paths[i], Value: v} - }) - - iter, err := storage.NewBinlogDeserializeReader(blobs, PKFieldID) - if err != nil { - log.Warn("downloadData wrong, failed to new insert binlogs reader", zap.Error(err)) - return nil, err - } - - for { - err := iter.Next() - if err != nil { - if err == sio.EOF { - break - } else { - log.Warn("downloadData wrong, failed to iter through data", zap.Error(err)) - iter.Close() - return nil, err - } - } - - v := iter.Value() - if isValueDeleted(v) { - continue - } - - // Filtering expired entity - if st.isExpiredEntity(typeutil.Timestamp(v.Timestamp)) { - expiredRowCount++ - continue - } - - values = append(values, iter.Value()) - remainingRowCount++ - } - iter.Close() - } - - log.Info("download data success", - zap.Int64("old rows", numRows), - zap.Int64("remainingRowCount", remainingRowCount), - zap.Int64("expiredRowCount", expiredRowCount), - zap.Duration("download binlogs elapse", downloadTimeCost), - ) - return values, nil -} - func (st *statsTask) loadDeltalogs(ctx context.Context, dpaths []string) (map[interface{}]typeutil.Timestamp, error) { st.tr.RecordSpan() ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "loadDeltalogs") @@ -545,74 +623,6 @@ func serializeWrite(ctx context.Context, rootPath string, startID int64, writer return } -func statSerializeWrite(ctx context.Context, rootPath string, io io.BinlogIO, startID int64, writer *compaction.SegmentWriter, finalRowCount int64) (int64, *datapb.FieldBinlog, error) { - ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "statslog serializeWrite") - defer span.End() - sblob, err := writer.Finish() - if err != nil { - return 0, nil, err - } - - binlogNum := int64(1) - key, _ := binlog.BuildLogPathWithRootPath(rootPath, storage.StatsBinlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), writer.GetPkID(), startID) - kvs := map[string][]byte{key: sblob.GetValue()} - statFieldLog := &datapb.FieldBinlog{ - FieldID: writer.GetPkID(), - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(sblob.GetValue())), - MemorySize: int64(len(sblob.GetValue())), - LogPath: key, - EntriesNum: finalRowCount, - }, - }, - } - if err := io.Upload(ctx, kvs); err != nil { - log.Ctx(ctx).Warn("failed to upload insert log", zap.Error(err)) - return binlogNum, nil, err - } - - return binlogNum, statFieldLog, nil -} - -func bm25SerializeWrite(ctx 
context.Context, rootPath string, io io.BinlogIO, startID int64, writer *compaction.SegmentWriter, finalRowCount int64) (int64, []*datapb.FieldBinlog, error) { - ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "bm25log serializeWrite") - defer span.End() - stats, err := writer.GetBm25StatsBlob() - if err != nil { - return 0, nil, err - } - - kvs := make(map[string][]byte) - binlogs := []*datapb.FieldBinlog{} - cnt := int64(0) - for fieldID, blob := range stats { - key, _ := binlog.BuildLogPathWithRootPath(rootPath, storage.BM25Binlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, startID+cnt) - kvs[key] = blob.GetValue() - fieldLog := &datapb.FieldBinlog{ - FieldID: fieldID, - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(blob.GetValue())), - MemorySize: int64(len(blob.GetValue())), - LogPath: key, - EntriesNum: finalRowCount, - }, - }, - } - - binlogs = append(binlogs, fieldLog) - cnt++ - } - - if err := io.Upload(ctx, kvs); err != nil { - log.Ctx(ctx).Warn("failed to upload bm25 log", zap.Error(err)) - return 0, nil, err - } - - return cnt, binlogs, nil -} - func ParseStorageConfig(s *indexpb.StorageConfig) (*indexcgopb.StorageConfig, error) { bs, err := proto.Marshal(s) if err != nil { diff --git a/internal/indexnode/task_stats_test.go b/internal/indexnode/task_stats_test.go index 6243301578bf4..cff3b4389b7b0 100644 --- a/internal/indexnode/task_stats_test.go +++ b/internal/indexnode/task_stats_test.go @@ -130,10 +130,13 @@ func (s *TaskStatsSuite) TestSortSegmentWithBM25() { InsertLogs: lo.Values(fBinlogs), Schema: s.schema, NumRows: 1, + StartLogID: 0, + EndLogID: 5, + BinlogMaxSize: 64 * 1024 * 1024, }, node, s.mockBinlogIO) err = task.PreExecute(ctx) s.Require().NoError(err) - binlog, err := task.sortSegment(ctx) + binlog, err := task.sort(ctx) s.Require().NoError(err) s.Equal(5, len(binlog)) @@ -174,10 +177,13 @@ func (s *TaskStatsSuite) TestSortSegmentWithBM25() { InsertLogs: lo.Values(fBinlogs), Schema: s.schema, NumRows: 1, + StartLogID: 0, + EndLogID: 5, + BinlogMaxSize: 64 * 1024 * 1024, }, node, s.mockBinlogIO) err = task.PreExecute(ctx) s.Require().NoError(err) - _, err = task.sortSegment(ctx) + _, err = task.sort(ctx) s.Error(err) }) } diff --git a/internal/storage/binlog_iterator_test.go b/internal/storage/binlog_iterator_test.go index ebe4f478e3067..5365fc3662886 100644 --- a/internal/storage/binlog_iterator_test.go +++ b/internal/storage/binlog_iterator_test.go @@ -66,6 +66,10 @@ func generateTestSchema() *schemapb.CollectionSchema { } func generateTestData(num int) ([]*Blob, error) { + return generateTestDataWithSeed(1, num) +} + +func generateTestDataWithSeed(seed, num int) ([]*Blob, error) { insertCodec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ID: 1, Schema: generateTestSchema()}) var ( @@ -92,7 +96,7 @@ func generateTestData(num int) ([]*Blob, error) { field106 [][]byte ) - for i := 1; i <= num; i++ { + for i := seed; i < seed+num; i++ { field0 = append(field0, int64(i)) field1 = append(field1, int64(i)) field10 = append(field10, true) diff --git a/internal/storage/serde.go b/internal/storage/serde.go index ae46ca8bfbb6b..424998b580bd8 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -41,19 +41,20 @@ type Record interface { Column(i FieldID) arrow.Array Len() int Release() + Retain() Slice(start, end int) Record } type RecordReader interface { Next() error Record() Record - Close() + Close() error } type RecordWriter interface { Write(r Record) error GetWrittenUncompressed() 
uint64 - Close() + Close() error } type ( @@ -86,6 +87,12 @@ func (r *compositeRecord) Release() { } } +func (r *compositeRecord) Retain() { + for _, rec := range r.recs { + rec.Retain() + } +} + func (r *compositeRecord) Schema() map[FieldID]schemapb.DataType { return r.schema } @@ -543,28 +550,20 @@ func (deser *DeserializeReader[T]) Next() error { return nil } -func (deser *DeserializeReader[T]) NextRecord() (Record, error) { - if len(deser.values) != 0 { - return nil, errors.New("deserialize result is not empty") - } - - if err := deser.rr.Next(); err != nil { - return nil, err - } - return deser.rr.Record(), nil -} - func (deser *DeserializeReader[T]) Value() T { return deser.values[deser.pos] } -func (deser *DeserializeReader[T]) Close() { +func (deser *DeserializeReader[T]) Close() error { if deser.rec != nil { deser.rec.Release() } if deser.rr != nil { - deser.rr.Close() + if err := deser.rr.Close(); err != nil { + return err + } } + return nil } func NewDeserializeReader[T any](rr RecordReader, deserializer Deserializer[T]) *DeserializeReader[T] { @@ -607,6 +606,10 @@ func (r *selectiveRecord) Release() { // do nothing. } +func (r *selectiveRecord) Retain() { + // do nothing +} + func (r *selectiveRecord) Slice(start, end int) Record { panic("not implemented") } @@ -703,14 +706,17 @@ func (crw *CompositeRecordWriter) Write(r Record) error { return nil } -func (crw *CompositeRecordWriter) Close() { +func (crw *CompositeRecordWriter) Close() error { if crw != nil { for _, w := range crw.writers { if w != nil { - w.Close() + if err := w.Close(); err != nil { + return err + } } } } + return nil } func NewCompositeRecordWriter(writers map[FieldID]RecordWriter) *CompositeRecordWriter { @@ -753,8 +759,8 @@ func (sfw *singleFieldRecordWriter) GetWrittenUncompressed() uint64 { return sfw.writtenUncompressed } -func (sfw *singleFieldRecordWriter) Close() { - sfw.fw.Close() +func (sfw *singleFieldRecordWriter) Close() error { + return sfw.fw.Close() } func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer, opts ...RecordWriterOptions) (*singleFieldRecordWriter, error) { @@ -804,8 +810,8 @@ func (mfw *multiFieldRecordWriter) GetWrittenUncompressed() uint64 { return mfw.writtenUncompressed } -func (mfw *multiFieldRecordWriter) Close() { - mfw.fw.Close() +func (mfw *multiFieldRecordWriter) Close() error { + return mfw.fw.Close() } func newMultiFieldRecordWriter(fieldIds []FieldID, fields []arrow.Field, writer io.Writer) (*multiFieldRecordWriter, error) { @@ -931,6 +937,10 @@ func (sr *simpleArrowRecord) Release() { sr.r.Release() } +func (sr *simpleArrowRecord) Retain() { + sr.r.Retain() +} + func (sr *simpleArrowRecord) ArrowSchema() *arrow.Schema { return sr.r.Schema() } diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 80f22672c0c77..526df71949366 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -39,10 +39,12 @@ import ( var _ RecordReader = (*CompositeBinlogRecordReader)(nil) +// ChunkedBlobsReader returns a chunk composed of blobs, or io.EOF if no more data +type ChunkedBlobsReader func() ([]*Blob, error) + type CompositeBinlogRecordReader struct { - blobs [][]*Blob + BlobsReader ChunkedBlobsReader - blobPos int rrs []array.RecordReader closers []func() fields []FieldID @@ -58,13 +60,24 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { } } } - crr.blobPos++ - if crr.blobPos >= len(crr.blobs[0]) { - return io.EOF + + blobs, err := crr.BlobsReader() + if err != nil { 
+ return err + } + + if crr.rrs == nil { + crr.rrs = make([]array.RecordReader, len(blobs)) + crr.closers = make([]func(), len(blobs)) + crr.fields = make([]FieldID, len(blobs)) + crr.r = compositeRecord{ + recs: make(map[FieldID]arrow.Record, len(crr.rrs)), + schema: make(map[FieldID]schemapb.DataType, len(crr.rrs)), + } } - for i, b := range crr.blobs { - reader, err := NewBinlogReader(b[crr.blobPos].Value) + for i, b := range blobs { + reader, err := NewBinlogReader(b.Value) if err != nil { return err } @@ -92,17 +105,6 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { func (crr *CompositeBinlogRecordReader) Next() error { if crr.rrs == nil { - if crr.blobs == nil || len(crr.blobs) == 0 { - return io.EOF - } - crr.rrs = make([]array.RecordReader, len(crr.blobs)) - crr.closers = make([]func(), len(crr.blobs)) - crr.blobPos = -1 - crr.fields = make([]FieldID, len(crr.rrs)) - crr.r = compositeRecord{ - recs: make(map[FieldID]arrow.Record, len(crr.rrs)), - schema: make(map[FieldID]schemapb.DataType, len(crr.rrs)), - } if err := crr.iterateNextBatch(); err != nil { return err } @@ -138,12 +140,13 @@ func (crr *CompositeBinlogRecordReader) Record() Record { return &crr.r } -func (crr *CompositeBinlogRecordReader) Close() { +func (crr *CompositeBinlogRecordReader) Close() error { for _, close := range crr.closers { if close != nil { close() } } + return nil } func parseBlobKey(blobKey string) (colId FieldID, logId UniqueID) { @@ -177,8 +180,19 @@ func NewCompositeBinlogRecordReader(blobs []*Blob) (*CompositeBinlogRecordReader }) sortedBlobs = append(sortedBlobs, blobsForField) } + chunkPos := 0 return &CompositeBinlogRecordReader{ - blobs: sortedBlobs, + BlobsReader: func() ([]*Blob, error) { + if chunkPos >= len(sortedBlobs[0]) { + return nil, io.EOF + } + blobs := make([]*Blob, len(sortedBlobs)) + for fieldPos := range blobs { + blobs[fieldPos] = sortedBlobs[fieldPos][chunkPos] + } + chunkPos++ + return blobs, nil + }, }, nil } @@ -623,10 +637,11 @@ func (crr *simpleArrowRecordReader) Record() Record { return &crr.r } -func (crr *simpleArrowRecordReader) Close() { +func (crr *simpleArrowRecordReader) Close() error { if crr.closer != nil { crr.closer() } + return nil } func newSimpleArrowRecordReader(blobs []*Blob) (*simpleArrowRecordReader, error) { diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index f31ffe8b49aff..8e56eb06751e3 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -30,6 +30,25 @@ import ( "github.com/milvus-io/milvus/pkg/common" ) +type MockRecordWriter struct { + writefn func(Record) error + closefn func() error +} + +var _ RecordWriter = (*MockRecordWriter)(nil) + +func (w *MockRecordWriter) Write(record Record) error { + return w.writefn(record) +} + +func (w *MockRecordWriter) Close() error { + return w.closefn() +} + +func (w *MockRecordWriter) GetWrittenUncompressed() uint64 { + return 0 +} + func TestSerDe(t *testing.T) { type args struct { dt schemapb.DataType diff --git a/internal/storage/sort.go b/internal/storage/sort.go new file mode 100644 index 0000000000000..af42870cf0406 --- /dev/null +++ b/internal/storage/sort.go @@ -0,0 +1,220 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "container/heap" + "io" + "sort" + + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r Record, ri, i int) bool) (numRows int, err error) { + records := make([]Record, 0) + + type index struct { + ri int + i int + } + indices := make([]index, 0) + + defer func() { + for _, r := range records { + r.Release() + } + }() + + for _, r := range rr { + for { + err := r.Next() + if err == nil { + rec := r.Record() + rec.Retain() + ri := len(records) + records = append(records, rec) + for i := 0; i < rec.Len(); i++ { + if predicate(rec, ri, i) { + indices = append(indices, index{ri, i}) + } + } + } else if err == io.EOF { + break + } else { + return 0, err + } + } + } + + if len(records) == 0 { + return 0, nil + } + + switch records[0].Schema()[pkField] { + case schemapb.DataType_Int64: + sort.Slice(indices, func(i, j int) bool { + pki := records[indices[i].ri].Column(pkField).(*array.Int64).Value(indices[i].i) + pkj := records[indices[j].ri].Column(pkField).(*array.Int64).Value(indices[j].i) + return pki < pkj + }) + case schemapb.DataType_VarChar: + sort.Slice(indices, func(i, j int) bool { + pki := records[indices[i].ri].Column(pkField).(*array.String).Value(indices[i].i) + pkj := records[indices[j].ri].Column(pkField).(*array.String).Value(indices[j].i) + return pki < pkj + }) + } + + for _, i := range indices { + rec := records[i.ri].Slice(i.i, i.i+1) + err := rw.Write(rec) + rec.Release() + if err != nil { + return 0, err + } + } + + return len(indices), nil +} + +// A PriorityQueue implements heap.Interface and holds Items. 
+type PriorityQueue[T any] struct { + items []*T + less func(x, y *T) bool +} + +var _ heap.Interface = (*PriorityQueue[any])(nil) + +func (pq PriorityQueue[T]) Len() int { return len(pq.items) } + +func (pq PriorityQueue[T]) Less(i, j int) bool { + return pq.less(pq.items[i], pq.items[j]) +} + +func (pq PriorityQueue[T]) Swap(i, j int) { + pq.items[i], pq.items[j] = pq.items[j], pq.items[i] +} + +func (pq *PriorityQueue[T]) Push(x any) { + pq.items = append(pq.items, x.(*T)) +} + +func (pq *PriorityQueue[T]) Pop() any { + old := pq.items + n := len(old) + x := old[n-1] + pq.items = old[0 : n-1] + return x +} + +func (pq *PriorityQueue[T]) Enqueue(x *T) { + heap.Push(pq, x) +} + +func (pq *PriorityQueue[T]) Dequeue() *T { + return heap.Pop(pq).(*T) +} + +func NewPriorityQueue[T any](less func(x, y *T) bool) *PriorityQueue[T] { + pq := PriorityQueue[T]{ + items: make([]*T, 0), + less: less, + } + heap.Init(&pq) + return &pq +} + +func MergeSort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r Record, ri, i int) bool) (numRows int, err error) { + type index struct { + ri int + i int + } + + advanceRecord := func(r RecordReader) (Record, error) { + err := r.Next() + if err != nil { + return nil, err + } + return r.Record(), nil + } + + recs := make([]Record, len(rr)) + for i, r := range rr { + rec, err := advanceRecord(r) + if err == io.EOF { + recs[i] = nil + continue + } + if err != nil { + return 0, err + } + recs[i] = rec + } + + var pq *PriorityQueue[index] + switch recs[0].Schema()[pkField] { + case schemapb.DataType_Int64: + pq = NewPriorityQueue[index](func(x, y *index) bool { + return rr[x.ri].Record().Column(pkField).(*array.Int64).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.Int64).Value(y.i) + }) + case schemapb.DataType_VarChar: + pq = NewPriorityQueue[index](func(x, y *index) bool { + return rr[x.ri].Record().Column(pkField).(*array.String).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.String).Value(y.i) + }) + } + + enqueueAll := func(ri int, r Record) { + for j := 0; j < r.Len(); j++ { + if predicate(r, ri, j) { + pq.Enqueue(&index{ + ri: ri, + i: j, + }) + numRows++ + } + } + } + + for i, v := range recs { + if v != nil { + enqueueAll(i, v) + } + } + + for pq.Len() > 0 { + idx := pq.Dequeue() + sr := rr[idx.ri].Record().Slice(idx.i, idx.i+1) + err := rw.Write(sr) + if err != nil { + return 0, err + } + + // If poped idx reaches end of segment, advance to next segment + if idx.i == rr[idx.ri].Record().Len()-1 { + rec, err := advanceRecord(rr[idx.ri]) + if err == io.EOF { + continue + } + if err != nil { + return 0, err + } + enqueueAll(idx.ri, rec) + } + } + return numRows, nil +} diff --git a/internal/storage/sort_test.go b/internal/storage/sort_test.go new file mode 100644 index 0000000000000..78c88465e6865 --- /dev/null +++ b/internal/storage/sort_test.go @@ -0,0 +1,127 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "testing" + + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/milvus-io/milvus/pkg/common" + "github.com/stretchr/testify/assert" +) + +func TestSort(t *testing.T) { + getReaders := func() []RecordReader { + blobs, err := generateTestDataWithSeed(10, 3) + assert.NoError(t, err) + reader10, err := NewCompositeBinlogRecordReader(blobs) + assert.NoError(t, err) + blobs, err = generateTestDataWithSeed(20, 3) + assert.NoError(t, err) + reader20, err := NewCompositeBinlogRecordReader(blobs) + assert.NoError(t, err) + rr := []RecordReader{reader20, reader10} + return rr + } + + lastPK := int64(-1) + rw := &MockRecordWriter{ + writefn: func(r Record) error { + pk := r.Column(common.RowIDField).(*array.Int64).Value(0) + assert.Greater(t, pk, lastPK) + lastPK = pk + return nil + }, + + closefn: func() error { + lastPK = int64(-1) + return nil + }, + } + + t.Run("sort", func(t *testing.T) { + gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + return true + }) + assert.NoError(t, err) + assert.Equal(t, 6, gotNumRows) + err = rw.Close() + assert.NoError(t, err) + }) + + t.Run("sort with predicate", func(t *testing.T) { + gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + pk := r.Column(common.RowIDField).(*array.Int64).Value(i) + return pk >= 20 + }) + assert.NoError(t, err) + assert.Equal(t, 3, gotNumRows) + err = rw.Close() + assert.NoError(t, err) + }) +} + +func TestMergeSort(t *testing.T) { + getReaders := func() []RecordReader { + blobs, err := generateTestDataWithSeed(10, 3) + assert.NoError(t, err) + reader10, err := NewCompositeBinlogRecordReader(blobs) + assert.NoError(t, err) + blobs, err = generateTestDataWithSeed(20, 3) + assert.NoError(t, err) + reader20, err := NewCompositeBinlogRecordReader(blobs) + assert.NoError(t, err) + rr := []RecordReader{reader20, reader10} + return rr + } + + lastPK := int64(-1) + rw := &MockRecordWriter{ + writefn: func(r Record) error { + pk := r.Column(common.RowIDField).(*array.Int64).Value(0) + assert.Greater(t, pk, lastPK) + lastPK = pk + return nil + }, + + closefn: func() error { + lastPK = int64(-1) + return nil + }, + } + + t.Run("merge sort", func(t *testing.T) { + gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + return true + }) + assert.NoError(t, err) + assert.Equal(t, 6, gotNumRows) + err = rw.Close() + assert.NoError(t, err) + }) + + t.Run("Sort with predicate", func(t *testing.T) { + gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + pk := r.Column(common.RowIDField).(*array.Int64).Value(i) + return pk >= 20 + }) + assert.NoError(t, err) + assert.Equal(t, 3, gotNumRows) + err = rw.Close() + assert.NoError(t, err) + }) +} From 59bb6d479e6874f4bdcf4acc4885a3cb4ad44bf2 Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Sun, 26 Jan 2025 13:52:13 +0800 Subject: [PATCH 2/6] reduce memory cost by removing schema in Records Signed-off-by: Ted Xu --- internal/storage/serde.go | 81 ++++++-------------------- 
internal/storage/serde_events.go | 82 ++++++++++++--------------- internal/storage/serde_events_test.go | 2 +- internal/storage/serde_test.go | 11 +--- internal/storage/sort.go | 39 +++++++------ internal/storage/sort_test.go | 45 +++++++++++++-- 6 files changed, 119 insertions(+), 141 deletions(-) diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 424998b580bd8..2273ce3b4d26f 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -18,7 +18,6 @@ package storage import ( "encoding/binary" - "fmt" "io" "math" "sync" @@ -36,8 +35,6 @@ import ( ) type Record interface { - Schema() map[FieldID]schemapb.DataType - ArrowSchema() *arrow.Schema Column(i FieldID) arrow.Array Len() int Release() @@ -64,8 +61,7 @@ type ( // compositeRecord is a record being composed of multiple records, in which each only have 1 column type compositeRecord struct { - recs map[FieldID]arrow.Record - schema map[FieldID]schemapb.DataType + recs map[FieldID]arrow.Record } var _ Record = (*compositeRecord)(nil) @@ -93,26 +89,13 @@ func (r *compositeRecord) Retain() { } } -func (r *compositeRecord) Schema() map[FieldID]schemapb.DataType { - return r.schema -} - -func (r *compositeRecord) ArrowSchema() *arrow.Schema { - var fields []arrow.Field - for _, rec := range r.recs { - fields = append(fields, rec.Schema().Field(0)) - } - return arrow.NewSchema(fields, nil) -} - func (r *compositeRecord) Slice(start, end int) Record { slices := make(map[FieldID]arrow.Record) for i, rec := range r.recs { slices[i] = rec.NewSlice(int64(start), int64(end)) } return &compositeRecord{ - recs: slices, - schema: r.schema, + recs: slices, } } @@ -577,22 +560,12 @@ var _ Record = (*selectiveRecord)(nil) // selectiveRecord is a Record that only contains a single field, reusing existing Record. 
type selectiveRecord struct { - r Record - selectedFieldId FieldID - - schema map[FieldID]schemapb.DataType -} - -func (r *selectiveRecord) Schema() map[FieldID]schemapb.DataType { - return r.schema -} - -func (r *selectiveRecord) ArrowSchema() *arrow.Schema { - return r.r.ArrowSchema() + r Record + fieldId FieldID } func (r *selectiveRecord) Column(i FieldID) arrow.Array { - if i == r.selectedFieldId { + if i == r.fieldId { return r.r.Column(i) } return nil @@ -660,17 +633,10 @@ func calculateArraySize(a arrow.Array) int { return totalSize } -func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord { - dt, ok := r.Schema()[selectedFieldId] - if !ok { - return nil - } - schema := make(map[FieldID]schemapb.DataType, 1) - schema[selectedFieldId] = dt +func newSelectiveRecord(r Record, selectedFieldId FieldID) Record { return &selectiveRecord{ - r: r, - selectedFieldId: selectedFieldId, - schema: schema, + r: r, + fieldId: selectedFieldId, } } @@ -678,26 +644,19 @@ var _ RecordWriter = (*CompositeRecordWriter)(nil) type CompositeRecordWriter struct { writers map[FieldID]RecordWriter - - writtenUncompressed uint64 } func (crw *CompositeRecordWriter) GetWrittenUncompressed() uint64 { - return crw.writtenUncompressed + s := uint64(0) + for _, w := range crw.writers { + s += w.GetWrittenUncompressed() + } + return s } func (crw *CompositeRecordWriter) Write(r Record) error { - if len(r.Schema()) != len(crw.writers) { - return fmt.Errorf("schema length mismatch %d, expected %d", len(r.Schema()), len(crw.writers)) - } - - var bytes uint64 - for fid := range r.Schema() { - arr := r.Column(fid) - bytes += uint64(calculateArraySize(arr)) - } - crw.writtenUncompressed += bytes for fieldId, w := range crw.writers { + // TODO: if field is not exist, write sr := newSelectiveRecord(r, fieldId) if err := w.Write(sr); err != nil { return err @@ -909,18 +868,13 @@ func NewSerializeRecordWriter[T any](rw RecordWriter, serializer Serializer[T], } type simpleArrowRecord struct { - r arrow.Record - schema map[FieldID]schemapb.DataType + r arrow.Record field2Col map[FieldID]int } var _ Record = (*simpleArrowRecord)(nil) -func (sr *simpleArrowRecord) Schema() map[FieldID]schemapb.DataType { - return sr.schema -} - func (sr *simpleArrowRecord) Column(i FieldID) arrow.Array { colIdx, ok := sr.field2Col[i] if !ok { @@ -947,13 +901,12 @@ func (sr *simpleArrowRecord) ArrowSchema() *arrow.Schema { func (sr *simpleArrowRecord) Slice(start, end int) Record { s := sr.r.NewSlice(int64(start), int64(end)) - return newSimpleArrowRecord(s, sr.schema, sr.field2Col) + return newSimpleArrowRecord(s, sr.field2Col) } -func newSimpleArrowRecord(r arrow.Record, schema map[FieldID]schemapb.DataType, field2Col map[FieldID]int) *simpleArrowRecord { +func newSimpleArrowRecord(r arrow.Record, field2Col map[FieldID]int) *simpleArrowRecord { return &simpleArrowRecord{ r: r, - schema: schema, field2Col: field2Col, } } diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 526df71949366..6b35df317a154 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -45,19 +45,20 @@ type ChunkedBlobsReader func() ([]*Blob, error) type CompositeBinlogRecordReader struct { BlobsReader ChunkedBlobsReader - rrs []array.RecordReader - closers []func() - fields []FieldID + brs []*BinlogReader + rrs []array.RecordReader - r compositeRecord + schema map[FieldID]schemapb.DataType + r *compositeRecord } func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { - if 
crr.closers != nil { - for _, close := range crr.closers { - if close != nil { - close() - } + if crr.brs != nil { + for _, er := range crr.brs { + er.Close() + } + for _, rr := range crr.rrs { + rr.Release() } } @@ -68,12 +69,8 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { if crr.rrs == nil { crr.rrs = make([]array.RecordReader, len(blobs)) - crr.closers = make([]func(), len(blobs)) - crr.fields = make([]FieldID, len(blobs)) - crr.r = compositeRecord{ - recs: make(map[FieldID]arrow.Record, len(crr.rrs)), - schema: make(map[FieldID]schemapb.DataType, len(crr.rrs)), - } + crr.brs = make([]*BinlogReader, len(blobs)) + crr.schema = make(map[FieldID]schemapb.DataType) } for i, b := range blobs { @@ -82,9 +79,8 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { return err } - crr.fields[i] = reader.FieldID // TODO: assert schema being the same in every blobs - crr.r.schema[reader.FieldID] = reader.PayloadDataType + crr.schema[reader.FieldID] = reader.PayloadDataType er, err := reader.NextEventReader() if err != nil { return err @@ -94,11 +90,7 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { return err } crr.rrs[i] = rr - crr.closers[i] = func() { - rr.Release() - er.Close() - reader.Close() - } + crr.brs[i] = reader } return nil } @@ -111,12 +103,15 @@ func (crr *CompositeBinlogRecordReader) Next() error { } composeRecord := func() bool { + recs := make(map[FieldID]arrow.Record, len(crr.rrs)) for i, rr := range crr.rrs { if ok := rr.Next(); !ok { return false } - // compose record - crr.r.recs[crr.fields[i]] = rr.Record() + recs[crr.brs[i].FieldID] = rr.Record() + } + crr.r = &compositeRecord{ + recs: recs, } return true } @@ -137,15 +132,19 @@ func (crr *CompositeBinlogRecordReader) Next() error { } func (crr *CompositeBinlogRecordReader) Record() Record { - return &crr.r + return crr.r } func (crr *CompositeBinlogRecordReader) Close() error { - for _, close := range crr.closers { - if close != nil { - close() + if crr.brs != nil { + for _, er := range crr.brs { + er.Close() + } + for _, rr := range crr.rrs { + rr.Release() } } + crr.r = nil return nil } @@ -183,7 +182,7 @@ func NewCompositeBinlogRecordReader(blobs []*Blob) (*CompositeBinlogRecordReader chunkPos := 0 return &CompositeBinlogRecordReader{ BlobsReader: func() ([]*Blob, error) { - if chunkPos >= len(sortedBlobs[0]) { + if len(sortedBlobs) == 0 || chunkPos >= len(sortedBlobs[0]) { return nil, io.EOF } blobs := make([]*Blob, len(sortedBlobs)) @@ -203,17 +202,18 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize } return NewDeserializeReader(reader, func(r Record, v []*Value) error { + schema := reader.schema // Note: the return value `Value` is reused. 
for i := 0; i < r.Len(); i++ { value := v[i] if value == nil { value = &Value{} - value.Value = make(map[FieldID]interface{}, len(r.Schema())) + value.Value = make(map[FieldID]interface{}, len(schema)) v[i] = value } m := value.Value.(map[FieldID]interface{}) - for j, dt := range r.Schema() { + for j, dt := range schema { if r.Column(j).IsNull(i) { m[j] = nil } else { @@ -233,7 +233,7 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize value.ID = rowID value.Timestamp = m[common.TimeStampField].(int64) - pk, err := GenPrimaryKeyByRawData(m[PKfieldID], r.Schema()[PKfieldID]) + pk, err := GenPrimaryKeyByRawData(m[PKfieldID], schema[PKfieldID]) if err != nil { return err } @@ -253,7 +253,7 @@ func newDeltalogOneFieldReader(blobs []*Blob) (*DeserializeReader[*DeleteLog], e } return NewDeserializeReader(reader, func(r Record, v []*DeleteLog) error { var fid FieldID // The only fid from delete file - for k := range r.Schema() { + for k := range reader.schema { fid = k break } @@ -405,7 +405,7 @@ func ValueSerializer(v []*Value, fieldSchema []*schemapb.FieldSchema) (Record, e field2Col[fid] = i i++ } - return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), nil + return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), field2Col), nil } func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, segmentID UniqueID, @@ -543,10 +543,7 @@ func newDeltalogSerializeWriter(eventWriter *DeltalogStreamWriter, batchSize int field2Col := map[FieldID]int{ 0: 0, } - schema := map[FieldID]schemapb.DataType{ - 0: schemapb.DataType_String, - } - return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), schema, field2Col), nil + return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), field2Col), nil }, batchSize), nil } @@ -597,12 +594,11 @@ func (crr *simpleArrowRecordReader) iterateNextBatch() error { func (crr *simpleArrowRecordReader) Next() error { if crr.rr == nil { - if crr.blobs == nil || len(crr.blobs) == 0 { + if len(crr.blobs) == 0 { return io.EOF } crr.blobPos = -1 crr.r = simpleArrowRecord{ - schema: make(map[FieldID]schemapb.DataType), field2Col: make(map[FieldID]int), } if err := crr.iterateNextBatch(); err != nil { @@ -796,11 +792,7 @@ func newDeltalogMultiFieldWriter(eventWriter *MultiFieldDeltalogStreamWriter, ba common.RowIDField: 0, common.TimeStampField: 1, } - schema := map[FieldID]schemapb.DataType{ - common.RowIDField: pkType, - common.TimeStampField: schemapb.DataType_Int64, - } - return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), schema, field2Col), nil + return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), field2Col), nil }, batchSize), nil } diff --git a/internal/storage/serde_events_test.go b/internal/storage/serde_events_test.go index d97c17fd5e1f9..bb252395bfc69 100644 --- a/internal/storage/serde_events_test.go +++ b/internal/storage/serde_events_test.go @@ -92,7 +92,7 @@ func TestBinlogStreamWriter(t *testing.T) { []arrow.Array{arr}, int64(size), ) - r := newSimpleArrowRecord(ar, map[FieldID]schemapb.DataType{1: schemapb.DataType_Bool}, map[FieldID]int{1: 0}) + r := newSimpleArrowRecord(ar, map[FieldID]int{1: 0}) defer r.Release() err = rw.Write(r) assert.NoError(t, err) diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index 8e56eb06751e3..c8848e4590a14 100644 --- 
a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -127,27 +127,18 @@ func TestArrowSchema(t *testing.T) { record := array.NewRecord(arrow.NewSchema(fields, nil), []arrow.Array{builder.NewArray()}, 1) t.Run("test composite record", func(t *testing.T) { cr := &compositeRecord{ - recs: make(map[FieldID]arrow.Record, 1), - schema: make(map[FieldID]schemapb.DataType, 1), + recs: make(map[FieldID]arrow.Record, 1), } cr.recs[0] = record - cr.schema[0] = schemapb.DataType_String - expected := arrow.NewSchema(fields, nil) - assert.Equal(t, expected, cr.ArrowSchema()) }) t.Run("test simple arrow record", func(t *testing.T) { cr := &simpleArrowRecord{ r: record, - schema: make(map[FieldID]schemapb.DataType, 1), field2Col: make(map[FieldID]int, 1), } - cr.schema[0] = schemapb.DataType_String expected := arrow.NewSchema(fields, nil) assert.Equal(t, expected, cr.ArrowSchema()) - - sr := newSelectiveRecord(cr, 0) - assert.Equal(t, expected, sr.ArrowSchema()) }) } diff --git a/internal/storage/sort.go b/internal/storage/sort.go index af42870cf0406..eea32df2fe8eb 100644 --- a/internal/storage/sort.go +++ b/internal/storage/sort.go @@ -22,7 +22,6 @@ import ( "sort" "github.com/apache/arrow/go/v12/arrow/array" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r Record, ri, i int) bool) (numRows int, err error) { @@ -32,7 +31,7 @@ func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r ri int i int } - indices := make([]index, 0) + indices := make([]*index, 0) defer func() { for _, r := range records { @@ -50,7 +49,13 @@ func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r records = append(records, rec) for i := 0; i < rec.Len(); i++ { if predicate(rec, ri, i) { - indices = append(indices, index{ri, i}) + numRows++ + // if len(indices) > 0 && indices[len(indices)-1].ri == ri && indices[len(indices)-1].i+1 == i { + // indices[len(indices)-1].i = i + // } else { + // indices = append(indices, &index{ri, i}) + // } + indices = append(indices, &index{ri, i}) } } } else if err == io.EOF { @@ -65,14 +70,14 @@ func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r return 0, nil } - switch records[0].Schema()[pkField] { - case schemapb.DataType_Int64: + switch records[0].Column(pkField).(type) { + case *array.Int64: sort.Slice(indices, func(i, j int) bool { pki := records[indices[i].ri].Column(pkField).(*array.Int64).Value(indices[i].i) pkj := records[indices[j].ri].Column(pkField).(*array.Int64).Value(indices[j].i) return pki < pkj }) - case schemapb.DataType_VarChar: + case *array.String: sort.Slice(indices, func(i, j int) bool { pki := records[indices[i].ri].Column(pkField).(*array.String).Value(indices[i].i) pkj := records[indices[j].ri].Column(pkField).(*array.String).Value(indices[j].i) @@ -80,16 +85,16 @@ func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r }) } - for _, i := range indices { + writeOne := func(i *index) error { rec := records[i.ri].Slice(i.i, i.i+1) - err := rw.Write(rec) - rec.Release() - if err != nil { - return 0, err - } + defer rec.Release() + return rw.Write(rec) + } + for _, i := range indices { + writeOne(i) } - return len(indices), nil + return numRows, nil } // A PriorityQueue implements heap.Interface and holds Items. 
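Illustrative sketch (not part of the patch): the predicate that Sort and MergeSort now accept takes three arguments — the current record r, the index ri of the reader it came from, and the row index i inside that record — and returns true to keep the row. The snippet below shows a caller filtering deleted rows the same way the stats task does; the function name sortValid, the deletedPKs map and the int64-PK assumption are made up for the example, while storage.Sort, storage.Record, common.TimeStampField and the Arrow accessors are taken from the diff. The TTL/expiration check used elsewhere is omitted. Note that Sort materialises every record before sorting, whereas MergeSort takes the same predicate but consumes the readers incrementally through the priority queue shown below.

package example

import (
	"github.com/apache/arrow/go/v12/arrow/array"

	"github.com/milvus-io/milvus/internal/storage"
	"github.com/milvus-io/milvus/pkg/common"
)

// sortValid sorts all rows from readers by primary key and writes them to w,
// dropping any row whose PK has a delete record newer than the row itself.
// deletedPKs maps an int64 primary key to the timestamp of its delete record.
func sortValid(readers []storage.RecordReader, w storage.RecordWriter,
	pkFieldID storage.FieldID, deletedPKs map[int64]uint64,
) (int, error) {
	return storage.Sort(readers, pkFieldID, w, func(r storage.Record, ri, i int) bool {
		// ri says which reader produced r; the compaction path uses it to pick
		// a per-segment entity filter. It is not needed in this sketch.
		pk := r.Column(pkFieldID).(*array.Int64).Value(i)
		ts := uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i))
		delTs, deleted := deletedPKs[pk]
		return !deleted || ts >= delTs // true means keep the row
	})
}
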
@@ -118,6 +123,7 @@ func (pq *PriorityQueue[T]) Pop() any { old := pq.items n := len(old) x := old[n-1] + old[n-1] = nil pq.items = old[0 : n-1] return x } @@ -167,12 +173,12 @@ func MergeSort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate fu } var pq *PriorityQueue[index] - switch recs[0].Schema()[pkField] { - case schemapb.DataType_Int64: + switch recs[0].Column(pkField).(type) { + case *array.Int64: pq = NewPriorityQueue[index](func(x, y *index) bool { return rr[x.ri].Record().Column(pkField).(*array.Int64).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.Int64).Value(y.i) }) - case schemapb.DataType_VarChar: + case *array.String: pq = NewPriorityQueue[index](func(x, y *index) bool { return rr[x.ri].Record().Column(pkField).(*array.String).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.String).Value(y.i) }) @@ -200,6 +206,7 @@ func MergeSort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate fu idx := pq.Dequeue() sr := rr[idx.ri].Record().Slice(idx.i, idx.i+1) err := rw.Write(sr) + sr.Release() if err != nil { return 0, err } diff --git a/internal/storage/sort_test.go b/internal/storage/sort_test.go index 78c88465e6865..f16b01cca33e1 100644 --- a/internal/storage/sort_test.go +++ b/internal/storage/sort_test.go @@ -20,8 +20,9 @@ import ( "testing" "github.com/apache/arrow/go/v12/arrow/array" - "github.com/milvus-io/milvus/pkg/common" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/common" ) func TestSort(t *testing.T) { @@ -54,7 +55,7 @@ func TestSort(t *testing.T) { } t.Run("sort", func(t *testing.T) { - gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool { return true }) assert.NoError(t, err) @@ -64,7 +65,7 @@ func TestSort(t *testing.T) { }) t.Run("sort with predicate", func(t *testing.T) { - gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool { pk := r.Column(common.RowIDField).(*array.Int64).Value(i) return pk >= 20 }) @@ -105,7 +106,7 @@ func TestMergeSort(t *testing.T) { } t.Run("merge sort", func(t *testing.T) { - gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool { return true }) assert.NoError(t, err) @@ -115,7 +116,7 @@ func TestMergeSort(t *testing.T) { }) t.Run("Sort with predicate", func(t *testing.T) { - gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, i int) bool { + gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool { pk := r.Column(common.RowIDField).(*array.Int64).Value(i) return pk >= 20 }) @@ -125,3 +126,37 @@ func TestMergeSort(t *testing.T) { assert.NoError(t, err) }) } + +// Benchmark sort +func BenchmarkSort(b *testing.B) { + batch := 100000 + blobs, err := generateTestDataWithSeed(batch, batch) + assert.NoError(b, err) + reader10, err := NewCompositeBinlogRecordReader(blobs) + assert.NoError(b, err) + blobs, err = generateTestDataWithSeed(batch*2+1, batch) + assert.NoError(b, err) + reader20, err := NewCompositeBinlogRecordReader(blobs) + assert.NoError(b, err) + rr := []RecordReader{reader20, reader10} + + rw := &MockRecordWriter{ + writefn: func(r Record) error { + return nil + }, + + closefn: func() error { + 
return nil + }, + } + + b.ResetTimer() + + b.Run("sort", func(b *testing.B) { + for i := 0; i < b.N; i++ { + Sort(rr, common.RowIDField, rw, func(r Record, ri, i int) bool { + return true + }) + } + }) +} From 2d841ad40416bf4e2ef835e29b1d890053e1ee8d Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Sun, 26 Jan 2025 15:47:34 +0800 Subject: [PATCH 3/6] reduce memory allocations Signed-off-by: Ted Xu --- internal/storage/serde.go | 15 +++++++----- internal/storage/serde_events.go | 10 +++++--- internal/storage/serde_test.go | 22 ----------------- internal/storage/sort.go | 42 +++++++++++++++++++++++--------- internal/storage/sort_test.go | 4 +-- 5 files changed, 48 insertions(+), 45 deletions(-) diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 2273ce3b4d26f..7ecb14d5f5531 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -61,18 +61,19 @@ type ( // compositeRecord is a record being composed of multiple records, in which each only have 1 column type compositeRecord struct { - recs map[FieldID]arrow.Record + index map[FieldID]int16 + recs []arrow.Array } var _ Record = (*compositeRecord)(nil) func (r *compositeRecord) Column(i FieldID) arrow.Array { - return r.recs[i].Column(0) + return r.recs[r.index[i]] } func (r *compositeRecord) Len() int { for _, rec := range r.recs { - return rec.Column(0).Len() + return rec.Len() } return 0 } @@ -90,12 +91,14 @@ func (r *compositeRecord) Retain() { } func (r *compositeRecord) Slice(start, end int) Record { - slices := make(map[FieldID]arrow.Record) + slices := make([]arrow.Array, len(r.index)) for i, rec := range r.recs { - slices[i] = rec.NewSlice(int64(start), int64(end)) + d := array.NewSliceData(rec.Data(), int64(start), int64(end)) + slices[i] = array.MakeFromData(d) } return &compositeRecord{ - recs: slices, + index: r.index, + recs: slices, } } diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 6b35df317a154..e8784481a65cf 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -49,6 +49,7 @@ type CompositeBinlogRecordReader struct { rrs []array.RecordReader schema map[FieldID]schemapb.DataType + index map[FieldID]int16 r *compositeRecord } @@ -71,6 +72,7 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { crr.rrs = make([]array.RecordReader, len(blobs)) crr.brs = make([]*BinlogReader, len(blobs)) crr.schema = make(map[FieldID]schemapb.DataType) + crr.index = make(map[FieldID]int16, len(blobs)) } for i, b := range blobs { @@ -90,6 +92,7 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { return err } crr.rrs[i] = rr + crr.index[reader.FieldID] = int16(i) crr.brs[i] = reader } return nil @@ -103,15 +106,16 @@ func (crr *CompositeBinlogRecordReader) Next() error { } composeRecord := func() bool { - recs := make(map[FieldID]arrow.Record, len(crr.rrs)) + recs := make([]arrow.Array, len(crr.rrs)) for i, rr := range crr.rrs { if ok := rr.Next(); !ok { return false } - recs[crr.brs[i].FieldID] = rr.Record() + recs[i] = rr.Record().Column(0) } crr.r = &compositeRecord{ - recs: recs, + index: crr.index, + recs: recs, } return true } diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index c8848e4590a14..df23ab229e25b 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -120,28 +120,6 @@ func TestSerDe(t *testing.T) { } } -func TestArrowSchema(t *testing.T) { - fields := []arrow.Field{{Name: "1", Type: arrow.BinaryTypes.String, Nullable: true}} - builder := 
array.NewBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String) - builder.AppendValueFromString("1") - record := array.NewRecord(arrow.NewSchema(fields, nil), []arrow.Array{builder.NewArray()}, 1) - t.Run("test composite record", func(t *testing.T) { - cr := &compositeRecord{ - recs: make(map[FieldID]arrow.Record, 1), - } - cr.recs[0] = record - }) - - t.Run("test simple arrow record", func(t *testing.T) { - cr := &simpleArrowRecord{ - r: record, - field2Col: make(map[FieldID]int, 1), - } - expected := arrow.NewSchema(fields, nil) - assert.Equal(t, expected, cr.ArrowSchema()) - }) -} - func BenchmarkDeserializeReader(b *testing.B) { len := 1000000 blobs, err := generateTestData(len) diff --git a/internal/storage/sort.go b/internal/storage/sort.go index eea32df2fe8eb..b31fa65adc318 100644 --- a/internal/storage/sort.go +++ b/internal/storage/sort.go @@ -50,12 +50,11 @@ func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r for i := 0; i < rec.Len(); i++ { if predicate(rec, ri, i) { numRows++ - // if len(indices) > 0 && indices[len(indices)-1].ri == ri && indices[len(indices)-1].i+1 == i { - // indices[len(indices)-1].i = i - // } else { - // indices = append(indices, &index{ri, i}) - // } - indices = append(indices, &index{ri, i}) + if len(indices) > 0 && indices[len(indices)-1].ri == ri && indices[len(indices)-1].i+1 == i { + indices[len(indices)-1].i = i + } else { + indices = append(indices, &index{ri, i}) + } } } } else if err == io.EOF { @@ -202,17 +201,36 @@ func MergeSort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate fu } } + ri, istart, iend := -1, -1, -1 for pq.Len() > 0 { idx := pq.Dequeue() - sr := rr[idx.ri].Record().Slice(idx.i, idx.i+1) - err := rw.Write(sr) - sr.Release() - if err != nil { - return 0, err + if ri == idx.ri { + // record end of cache, do nothing + iend = idx.i + 1 + } else { + if ri != -1 { + // record changed, write old one and reset + sr := rr[ri].Record().Slice(istart, iend) + err := rw.Write(sr) + sr.Release() + if err != nil { + return 0, err + } + } + ri = idx.ri + istart = idx.i + iend = idx.i + 1 } - // If poped idx reaches end of segment, advance to next segment + // If poped idx reaches end of segment, invalidate cache and advance to next segment if idx.i == rr[idx.ri].Record().Len()-1 { + sr := rr[ri].Record().Slice(istart, iend) + err := rw.Write(sr) + sr.Release() + if err != nil { + return 0, err + } + ri, istart, iend = -1, -1, -1 rec, err := advanceRecord(rr[idx.ri]) if err == io.EOF { continue diff --git a/internal/storage/sort_test.go b/internal/storage/sort_test.go index f16b01cca33e1..bee5c4f2dcd0a 100644 --- a/internal/storage/sort_test.go +++ b/internal/storage/sort_test.go @@ -115,7 +115,7 @@ func TestMergeSort(t *testing.T) { assert.NoError(t, err) }) - t.Run("Sort with predicate", func(t *testing.T) { + t.Run("merge sort with predicate", func(t *testing.T) { gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool { pk := r.Column(common.RowIDField).(*array.Int64).Value(i) return pk >= 20 @@ -129,7 +129,7 @@ func TestMergeSort(t *testing.T) { // Benchmark sort func BenchmarkSort(b *testing.B) { - batch := 100000 + batch := 500000 blobs, err := generateTestDataWithSeed(batch, batch) assert.NoError(b, err) reader10, err := NewCompositeBinlogRecordReader(blobs) From d7c2433a709b27a785a339e782a39919a5f3424f Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Sun, 26 Jan 2025 16:17:34 +0800 Subject: [PATCH 4/6] fix build break Signed-off-by: Ted Xu --- 
internal/indexnode/task_stats_test.go | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/internal/indexnode/task_stats_test.go b/internal/indexnode/task_stats_test.go index cff3b4389b7b0..e439313c05b9c 100644 --- a/internal/indexnode/task_stats_test.go +++ b/internal/indexnode/task_stats_test.go @@ -81,26 +81,6 @@ func (s *TaskStatsSuite) GenSegmentWriterWithBM25(magic int64) { s.segWriter = segWriter } -func (s *TaskStatsSuite) Testbm25SerializeWriteError() { - s.Run("normal case", func() { - s.schema = genCollectionSchemaWithBM25() - s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil).Once() - s.GenSegmentWriterWithBM25(0) - cnt, binlogs, err := bm25SerializeWrite(context.Background(), "root_path", s.mockBinlogIO, 0, s.segWriter, 1) - s.Require().NoError(err) - s.Equal(int64(1), cnt) - s.Equal(1, len(binlogs)) - }) - - s.Run("upload failed", func() { - s.schema = genCollectionSchemaWithBM25() - s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once() - s.GenSegmentWriterWithBM25(0) - _, _, err := bm25SerializeWrite(context.Background(), "root_path", s.mockBinlogIO, 0, s.segWriter, 1) - s.Error(err) - }) -} - func (s *TaskStatsSuite) TestSortSegmentWithBM25() { s.Run("normal case", func() { s.schema = genCollectionSchemaWithBM25() From be8cc56f9635997b4f6ddba4b510c6742cc6f545 Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Tue, 28 Jan 2025 10:06:46 +0800 Subject: [PATCH 5/6] fix UT fail Signed-off-by: Ted Xu --- internal/datanode/compaction/merge_sort.go | 4 ++-- internal/datanode/compaction/segment_record_reader.go | 2 +- internal/storage/binlog_reader.go | 3 +++ internal/storage/serde_events.go | 10 ++++++++-- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/internal/datanode/compaction/merge_sort.go b/internal/datanode/compaction/merge_sort.go index 48139b74cc231..05f1fe9130fd4 100644 --- a/internal/datanode/compaction/merge_sort.go +++ b/internal/datanode/compaction/merge_sort.go @@ -114,13 +114,13 @@ func mergeSortMultipleSegments(ctx context.Context, predicate = func(r storage.Record, ri, i int) bool { pk := r.Column(pkField.FieldID).(*array.Int64).Value(i) ts := r.Column(common.TimeStampField).(*array.Int64).Value(i) - return segmentFilters[ri].Filtered(pk, uint64(ts)) + return !segmentFilters[ri].Filtered(pk, uint64(ts)) } case schemapb.DataType_VarChar: predicate = func(r storage.Record, ri, i int) bool { pk := r.Column(pkField.FieldID).(*array.String).Value(i) ts := r.Column(common.TimeStampField).(*array.Int64).Value(i) - return segmentFilters[ri].Filtered(pk, uint64(ts)) + return !segmentFilters[ri].Filtered(pk, uint64(ts)) } default: log.Warn("compaction only support int64 and varchar pk field") diff --git a/internal/datanode/compaction/segment_record_reader.go b/internal/datanode/compaction/segment_record_reader.go index 3e0897d0ceac3..7140d9f3eff46 100644 --- a/internal/datanode/compaction/segment_record_reader.go +++ b/internal/datanode/compaction/segment_record_reader.go @@ -21,10 +21,10 @@ func NewSegmentRecordReader(ctx context.Context, binlogPaths [][]string, binlogI if err != nil { return nil, err } - pos++ blobs := lo.Map(bytesArr, func(v []byte, i int) *storage.Blob { return &storage.Blob{Key: binlogPaths[pos][i], Value: v} }) + pos++ return blobs, nil }, } diff --git a/internal/storage/binlog_reader.go b/internal/storage/binlog_reader.go index ad364c3d751a1..56adf2990e359 100644 --- a/internal/storage/binlog_reader.go +++ b/internal/storage/binlog_reader.go @@ -106,6 
+106,9 @@ func ReadDescriptorEvent(buffer io.Reader) (*descriptorEvent, error) { // Close closes the BinlogReader object. // It mainly calls the Close method of the internal events, reclaims resources, and marks itself as closed. func (reader *BinlogReader) Close() { + if reader == nil { + return + } if reader.isClose { return } diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index e8784481a65cf..f611c65ba3187 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -142,10 +142,16 @@ func (crr *CompositeBinlogRecordReader) Record() Record { func (crr *CompositeBinlogRecordReader) Close() error { if crr.brs != nil { for _, er := range crr.brs { - er.Close() + if er != nil { + er.Close() + } } + } + if crr.rrs != nil { for _, rr := range crr.rrs { - rr.Release() + if rr != nil { + rr.Release() + } } } crr.r = nil From 724d838e62531e10c4dc3502189a45e140528fbd Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Wed, 5 Feb 2025 12:20:03 +0800 Subject: [PATCH 6/6] fix UT fail Signed-off-by: Ted Xu --- internal/indexnode/task_stats.go | 19 +++++++------------ internal/storage/sort.go | 8 ++------ 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/internal/indexnode/task_stats.go b/internal/indexnode/task_stats.go index 65d0c63f29a2b..d29fc59eaea30 100644 --- a/internal/indexnode/task_stats.go +++ b/internal/indexnode/task_stats.go @@ -371,18 +371,14 @@ func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) { return nil, err } - var ( - remainingRowCount int64 // the number of remaining entities - expiredRowCount int64 // the number of expired entities - ) - var isValueValid func(r storage.Record, ri, i int) bool switch pkField.DataType { case schemapb.DataType_Int64: isValueValid = func(r storage.Record, ri, i int) bool { v := r.Column(pkFieldID).(*array.Int64).Value(i) - ts, ok := deletePKs[v] - if ok && uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) < ts { + deleteTs, ok := deletePKs[v] + ts := uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) + if ok && ts < deleteTs { return false } return !st.isExpiredEntity(ts) @@ -390,8 +386,9 @@ func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) { case schemapb.DataType_VarChar: isValueValid = func(r storage.Record, ri, i int) bool { v := r.Column(pkFieldID).(*array.String).Value(i) - ts, ok := deletePKs[v] - if ok && uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) < ts { + deleteTs, ok := deletePKs[v] + ts := uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) + if ok && ts < deleteTs { return false } return !st.isExpiredEntity(ts) @@ -425,9 +422,7 @@ func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) { } log.Info("download data success", - zap.Int64("old rows", numRows), - zap.Int64("remainingRowCount", remainingRowCount), - zap.Int64("expiredRowCount", expiredRowCount), + zap.Int64("numRows", numRows), zap.Duration("download binlogs elapse", downloadTimeCost), ) diff --git a/internal/storage/sort.go b/internal/storage/sort.go index b31fa65adc318..71ab6cb17ceea 100644 --- a/internal/storage/sort.go +++ b/internal/storage/sort.go @@ -49,12 +49,7 @@ func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r records = append(records, rec) for i := 0; i < rec.Len(); i++ { if predicate(rec, ri, i) { - numRows++ - if len(indices) > 0 && indices[len(indices)-1].ri == ri && indices[len(indices)-1].i+1 == i { - indices[len(indices)-1].i = i 
- } else { - indices = append(indices, &index{ri, i}) - } + indices = append(indices, &index{ri, i}) } } } else if err == io.EOF { @@ -90,6 +85,7 @@ func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r return rw.Write(rec) } for _, i := range indices { + numRows++ writeOne(i) }
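
A note on the write path: the "reduce memory allocations" patch batches MergeSort's output — consecutive rows dequeued from the same reader's current record are written as one slice instead of one Write call per row, and the pending run is also flushed whenever a reader's record is exhausted, before that reader is advanced. The final patch removes the analogous pre-sort run-merging from Sort: there the indices are re-ordered by the global sort afterwards, so a merged entry would make writeOne emit only its last row, and numRows is now counted in the write loop so it matches what was actually written. The toy program below only illustrates the run-coalescing idea; the index type, coalesce and flush are stand-ins for the real rr[ri].Record().Slice(istart, iend) / rw.Write sequence.

package main

import "fmt"

type index struct{ ri, i int }

// coalesce walks a PK-ordered stream of (reader, row) indices and groups
// consecutive rows from the same reader into [start, end) runs, calling
// flush once per run instead of once per row.
func coalesce(order []index, flush func(ri, start, end int)) {
	ri, start, end := -1, -1, -1
	for _, idx := range order {
		if idx.ri == ri && idx.i == end {
			end = idx.i + 1 // same reader, next row: extend the current run
			continue
		}
		if ri != -1 {
			flush(ri, start, end) // reader changed: flush the previous run
		}
		ri, start, end = idx.ri, idx.i, idx.i+1
	}
	if ri != -1 {
		flush(ri, start, end) // flush the trailing run
	}
}

func main() {
	order := []index{{0, 0}, {0, 1}, {0, 2}, {1, 0}, {0, 3}}
	coalesce(order, func(ri, start, end int) {
		fmt.Printf("write reader %d rows [%d, %d)\n", ri, start, end)
	})
	// Prints:
	//   write reader 0 rows [0, 3)
	//   write reader 1 rows [0, 1)
	//   write reader 0 rows [3, 4)
}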