diff --git a/internal/datanode/compaction/merge_sort.go b/internal/datanode/compaction/merge_sort.go index 0758d660ec527..05f1fe9130fd4 100644 --- a/internal/datanode/compaction/merge_sort.go +++ b/internal/datanode/compaction/merge_sort.go @@ -1,20 +1,21 @@ package compaction import ( - "container/heap" "context" "fmt" - sio "io" "math" "time" + "github.com/apache/arrow/go/v12/arrow/array" "github.com/samber/lo" "go.opentelemetry.io/otel" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/flushcommon/io" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/proto/datapb" @@ -22,6 +23,24 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +type segmentWriterWrapper struct { + *MultiSegmentWriter +} + +var _ storage.RecordWriter = (*segmentWriterWrapper)(nil) + +func (w *segmentWriterWrapper) GetWrittenUncompressed() uint64 { + return 0 +} + +func (w *segmentWriterWrapper) Write(record storage.Record) error { + return w.MultiSegmentWriter.WriteRecord(record) +} + +func (w *segmentWriterWrapper) Close() error { + return nil +} + func mergeSortMultipleSegments(ctx context.Context, plan *datapb.CompactionPlan, collectionID, partitionID, maxRows int64, @@ -43,6 +62,7 @@ func mergeSortMultipleSegments(ctx context.Context, logIDAlloc := allocator.NewLocalAllocator(plan.GetBeginLogID(), math.MaxInt64) compAlloc := NewCompactionAllocator(segIDAlloc, logIDAlloc) mWriter := NewMultiSegmentWriter(binlogIO, compAlloc, plan, maxRows, partitionID, collectionID, bm25FieldIds) + writer := &segmentWriterWrapper{MultiSegmentWriter: mWriter} pkField, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema()) if err != nil { @@ -50,8 +70,7 @@ func mergeSortMultipleSegments(ctx context.Context, return nil, err } - // SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID()) - segmentReaders := make([]*SegmentDeserializeReader, len(binlogs)) + segmentReaders := make([]storage.RecordReader, len(binlogs)) segmentFilters := make([]*EntityFilter, len(binlogs)) for i, s := range binlogs { var binlogBatchCount int @@ -75,7 +94,7 @@ func mergeSortMultipleSegments(ctx context.Context, } binlogPaths[idx] = batchPaths } - segmentReaders[i] = NewSegmentDeserializeReader(ctx, binlogPaths, binlogIO, pkField.GetFieldID(), bm25FieldIds) + segmentReaders[i] = NewSegmentRecordReader(ctx, binlogPaths, binlogIO) deltalogPaths := make([]string, 0) for _, d := range s.GetDeltalogs() { for _, l := range d.GetBinlogs() { @@ -89,57 +108,26 @@ func mergeSortMultipleSegments(ctx context.Context, segmentFilters[i] = newEntityFilter(delta, collectionTtl, currentTime) } - advanceRow := func(i int) (*storage.Value, error) { - for { - v, err := segmentReaders[i].Next() - if err != nil { - return nil, err - } - - if segmentFilters[i].Filtered(v.PK.GetValue(), uint64(v.Timestamp)) { - continue - } - - return v, nil + var predicate func(r storage.Record, ri, i int) bool + switch pkField.DataType { + case schemapb.DataType_Int64: + predicate = func(r storage.Record, ri, i int) bool { + pk := r.Column(pkField.FieldID).(*array.Int64).Value(i) + ts := r.Column(common.TimeStampField).(*array.Int64).Value(i) + return !segmentFilters[ri].Filtered(pk, uint64(ts)) } - } - - pq := make(PriorityQueue, 0) - heap.Init(&pq) - - for i := range segmentReaders { - v, err := 
advanceRow(i) - if err != nil { - log.Warn("compact wrong, failed to advance row", zap.Error(err)) - return nil, err + case schemapb.DataType_VarChar: + predicate = func(r storage.Record, ri, i int) bool { + pk := r.Column(pkField.FieldID).(*array.String).Value(i) + ts := r.Column(common.TimeStampField).(*array.Int64).Value(i) + return !segmentFilters[ri].Filtered(pk, uint64(ts)) } - heap.Push(&pq, &PQItem{ - Value: v, - Index: i, - }) + default: + log.Warn("compaction only support int64 and varchar pk field") } - for pq.Len() > 0 { - smallest := heap.Pop(&pq).(*PQItem) - v := smallest.Value - - err := mWriter.Write(v) - if err != nil { - log.Warn("compact wrong, failed to writer row", zap.Error(err)) - return nil, err - } - - iv, err := advanceRow(smallest.Index) - if err != nil && err != sio.EOF { - return nil, err - } - if err == nil { - next := &PQItem{ - Value: iv, - Index: smallest.Index, - } - heap.Push(&pq, next) - } + if _, err = storage.MergeSort(segmentReaders, pkField.FieldID, writer, predicate); err != nil { + return nil, err } res, err := mWriter.Finish() diff --git a/internal/datanode/compaction/segment_reader_from_binlogs.go b/internal/datanode/compaction/segment_reader_from_binlogs.go deleted file mode 100644 index 5d5fc3636be0a..0000000000000 --- a/internal/datanode/compaction/segment_reader_from_binlogs.go +++ /dev/null @@ -1,86 +0,0 @@ -package compaction - -import ( - "context" - "io" - - "github.com/samber/lo" - "go.uber.org/zap" - - binlogIO "github.com/milvus-io/milvus/internal/flushcommon/io" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/log" -) - -type SegmentDeserializeReader struct { - ctx context.Context - binlogIO binlogIO.BinlogIO - reader *storage.DeserializeReader[*storage.Value] - - pos int - PKFieldID int64 - binlogPaths [][]string - binlogPathPos int - - bm25FieldIDs []int64 -} - -func NewSegmentDeserializeReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO, PKFieldID int64, bm25FieldIDs []int64) *SegmentDeserializeReader { - return &SegmentDeserializeReader{ - ctx: ctx, - binlogIO: binlogIO, - reader: nil, - pos: 0, - PKFieldID: PKFieldID, - binlogPaths: binlogPaths, - binlogPathPos: 0, - bm25FieldIDs: bm25FieldIDs, - } -} - -func (r *SegmentDeserializeReader) initDeserializeReader() error { - if r.binlogPathPos >= len(r.binlogPaths) { - return io.EOF - } - allValues, err := r.binlogIO.Download(r.ctx, r.binlogPaths[r.binlogPathPos]) - if err != nil { - log.Warn("compact wrong, fail to download insertLogs", zap.Error(err)) - return err - } - - blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob { - return &storage.Blob{Key: r.binlogPaths[r.binlogPathPos][i], Value: v} - }) - - r.reader, err = storage.NewBinlogDeserializeReader(blobs, r.PKFieldID) - if err != nil { - log.Warn("compact wrong, failed to new insert binlogs reader", zap.Error(err)) - return err - } - r.binlogPathPos++ - return nil -} - -func (r *SegmentDeserializeReader) Next() (*storage.Value, error) { - if r.reader == nil { - if err := r.initDeserializeReader(); err != nil { - return nil, err - } - } - if err := r.reader.Next(); err != nil { - if err == io.EOF { - r.reader.Close() - if err := r.initDeserializeReader(); err != nil { - return nil, err - } - err = r.reader.Next() - return r.reader.Value(), err - } - return nil, err - } - return r.reader.Value(), nil -} - -func (r *SegmentDeserializeReader) Close() { - r.reader.Close() -} diff --git a/internal/datanode/compaction/segment_record_reader.go 
b/internal/datanode/compaction/segment_record_reader.go new file mode 100644 index 0000000000000..7140d9f3eff46 --- /dev/null +++ b/internal/datanode/compaction/segment_record_reader.go @@ -0,0 +1,31 @@ +package compaction + +import ( + "context" + "io" + + "github.com/samber/lo" + + binlogIO "github.com/milvus-io/milvus/internal/flushcommon/io" + "github.com/milvus-io/milvus/internal/storage" +) + +func NewSegmentRecordReader(ctx context.Context, binlogPaths [][]string, binlogIO binlogIO.BinlogIO) storage.RecordReader { + pos := 0 + return &storage.CompositeBinlogRecordReader{ + BlobsReader: func() ([]*storage.Blob, error) { + if pos >= len(binlogPaths) { + return nil, io.EOF + } + bytesArr, err := binlogIO.Download(ctx, binlogPaths[pos]) + if err != nil { + return nil, err + } + blobs := lo.Map(bytesArr, func(v []byte, i int) *storage.Blob { + return &storage.Blob{Key: binlogPaths[pos][i], Value: v} + }) + pos++ + return blobs, nil + }, + } +} diff --git a/internal/indexnode/task_stats.go b/internal/indexnode/task_stats.go index 68c5409b32f57..d29fc59eaea30 100644 --- a/internal/indexnode/task_stats.go +++ b/internal/indexnode/task_stats.go @@ -19,22 +19,23 @@ package indexnode import ( "context" "fmt" - sio "io" - "sort" "strconv" "time" + "github.com/apache/arrow/go/v12/arrow/array" "github.com/samber/lo" "go.opentelemetry.io/otel" "go.uber.org/zap" "google.golang.org/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/compaction" iter "github.com/milvus-io/milvus/internal/datanode/iterators" "github.com/milvus-io/milvus/internal/flushcommon/io" "github.com/milvus-io/milvus/internal/metastore/kv/binlog" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/indexcgowrapper" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/proto/datapb" "github.com/milvus-io/milvus/pkg/proto/indexcgopb" @@ -155,139 +156,309 @@ func (st *statsTask) PreExecute(ctx context.Context) error { return nil } -func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, error) { - numRows := st.req.GetNumRows() +// segmentRecordWriter is a wrapper of SegmentWriter to implement RecordWriter interface +type segmentRecordWriter struct { + sw *compaction.SegmentWriter + binlogMaxSize uint64 + rootPath string + logID int64 + maxLogID int64 + binlogIO io.BinlogIO + ctx context.Context + numRows int64 + bm25FieldIds []int64 + + lastUploads []*conc.Future[any] + binlogs map[typeutil.UniqueID]*datapb.FieldBinlog + statslog *datapb.FieldBinlog + bm25statslog []*datapb.FieldBinlog +} - bm25FieldIds := compaction.GetBM25FieldIDs(st.req.GetSchema()) - writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, statsBatchSize, st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID(), bm25FieldIds) +var _ storage.RecordWriter = (*segmentRecordWriter)(nil) + +func (srw *segmentRecordWriter) Close() error { + if !srw.sw.FlushAndIsEmpty() { + if err := srw.upload(); err != nil { + return err + } + if err := srw.waitLastUpload(); err != nil { + return err + } + } + + statslog, err := srw.statSerializeWrite() if err != nil { - log.Ctx(ctx).Warn("sort segment wrong, unable to init segment writer", - zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err + log.Ctx(srw.ctx).Warn("stats wrong, failed to serialize write segment stats", + zap.Int64("remaining row count", srw.numRows), zap.Error(err)) + 
return err } + srw.statslog = statslog + srw.logID++ - var ( - flushBatchCount int // binlog batch count + if len(srw.bm25FieldIds) > 0 { + binlogNums, bm25StatsLogs, err := srw.bm25SerializeWrite() + if err != nil { + log.Ctx(srw.ctx).Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err)) + return err + } + srw.logID += binlogNums + srw.bm25statslog = bm25StatsLogs + } - allBinlogs = make(map[typeutil.UniqueID]*datapb.FieldBinlog) // All binlog meta of a segment - uploadFutures = make([]*conc.Future[any], 0) + return nil +} - downloadCost time.Duration - serWriteTimeCost time.Duration - sortTimeCost time.Duration - ) +func (srw *segmentRecordWriter) GetWrittenUncompressed() uint64 { + return srw.sw.WrittenMemorySize() +} - downloadStart := time.Now() - values, err := st.downloadData(ctx, numRows, writer.GetPkID(), bm25FieldIds) +func (srw *segmentRecordWriter) Write(r storage.Record) error { + err := srw.sw.WriteRecord(r) if err != nil { - log.Ctx(ctx).Warn("download data failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err + return err } - downloadCost = time.Since(downloadStart) - sortStart := time.Now() - sort.Slice(values, func(i, j int) bool { - return values[i].PK.LT(values[j].PK) - }) - sortTimeCost += time.Since(sortStart) + if srw.sw.IsFullWithBinlogMaxSize(srw.binlogMaxSize) { + return srw.upload() + } + return nil +} - for i, v := range values { - err := writer.Write(v) - if err != nil { - log.Ctx(ctx).Warn("write value wrong, failed to writer row", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err - } +func (srw *segmentRecordWriter) upload() error { + if err := srw.waitLastUpload(); err != nil { + return err + } + binlogNum, kvs, partialBinlogs, err := serializeWrite(srw.ctx, srw.rootPath, srw.logID, srw.sw) + if err != nil { + return err + } - if (i+1)%statsBatchSize == 0 && writer.IsFullWithBinlogMaxSize(st.req.GetBinlogMaxSize()) { - serWriteStart := time.Now() - binlogNum, kvs, partialBinlogs, err := serializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.req.GetStartLogID()+st.logIDOffset, writer) - if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to serialize writer", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err - } - serWriteTimeCost += time.Since(serWriteStart) - - uploadFutures = append(uploadFutures, st.binlogIO.AsyncUpload(ctx, kvs)...) 
- mergeFieldBinlogs(allBinlogs, partialBinlogs) - - flushBatchCount++ - st.logIDOffset += binlogNum - if st.req.GetStartLogID()+st.logIDOffset >= st.req.GetEndLogID() { - log.Ctx(ctx).Warn("binlog files too much, log is not enough", zap.Int64("taskID", st.req.GetTaskID()), - zap.Int64("binlog num", binlogNum), zap.Int64("startLogID", st.req.GetStartLogID()), - zap.Int64("endLogID", st.req.GetEndLogID()), zap.Int64("logIDOffset", st.logIDOffset)) - return nil, fmt.Errorf("binlog files too much, log is not enough") + srw.lastUploads = srw.binlogIO.AsyncUpload(srw.ctx, kvs) + if srw.binlogs == nil { + srw.binlogs = make(map[typeutil.UniqueID]*datapb.FieldBinlog) + } + mergeFieldBinlogs(srw.binlogs, partialBinlogs) + + srw.logID += binlogNum + if srw.logID > srw.maxLogID { + return fmt.Errorf("log id exausted") + } + return nil +} + +func (srw *segmentRecordWriter) waitLastUpload() error { + if len(srw.lastUploads) > 0 { + for _, future := range srw.lastUploads { + if _, err := future.Await(); err != nil { + return err } } } + return nil +} - if !writer.FlushAndIsEmpty() { - serWriteStart := time.Now() - binlogNum, kvs, partialBinlogs, err := serializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.req.GetStartLogID()+st.logIDOffset, writer) - if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to serialize writer", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) - return nil, err +func (srw *segmentRecordWriter) statSerializeWrite() (*datapb.FieldBinlog, error) { + ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(srw.ctx, "statslog serializeWrite") + defer span.End() + sblob, err := srw.sw.Finish() + if err != nil { + return nil, err + } + + key, _ := binlog.BuildLogPathWithRootPath(srw.rootPath, storage.StatsBinlog, + srw.sw.GetCollectionID(), srw.sw.GetPartitionID(), srw.sw.GetSegmentID(), srw.sw.GetPkID(), srw.logID) + kvs := map[string][]byte{key: sblob.GetValue()} + statFieldLog := &datapb.FieldBinlog{ + FieldID: srw.sw.GetPkID(), + Binlogs: []*datapb.Binlog{ + { + LogSize: int64(len(sblob.GetValue())), + MemorySize: int64(len(sblob.GetValue())), + LogPath: key, + EntriesNum: srw.numRows, + }, + }, + } + if err := srw.binlogIO.Upload(ctx, kvs); err != nil { + log.Ctx(ctx).Warn("failed to upload insert log", zap.Error(err)) + return nil, err + } + + return statFieldLog, nil +} + +func (srw *segmentRecordWriter) bm25SerializeWrite() (int64, []*datapb.FieldBinlog, error) { + ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(srw.ctx, "bm25log serializeWrite") + defer span.End() + writer := srw.sw + stats, err := writer.GetBm25StatsBlob() + if err != nil { + return 0, nil, err + } + + kvs := make(map[string][]byte) + binlogs := []*datapb.FieldBinlog{} + cnt := int64(0) + for fieldID, blob := range stats { + key, _ := binlog.BuildLogPathWithRootPath(srw.rootPath, storage.BM25Binlog, + writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, srw.logID) + kvs[key] = blob.GetValue() + fieldLog := &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []*datapb.Binlog{ + { + LogSize: int64(len(blob.GetValue())), + MemorySize: int64(len(blob.GetValue())), + LogPath: key, + EntriesNum: srw.numRows, + }, + }, } - serWriteTimeCost += time.Since(serWriteStart) - st.logIDOffset += binlogNum - uploadFutures = append(uploadFutures, st.binlogIO.AsyncUpload(ctx, kvs)...) 
- mergeFieldBinlogs(allBinlogs, partialBinlogs) - flushBatchCount++ + binlogs = append(binlogs, fieldLog) + srw.logID++ + cnt++ + } + + if err := srw.binlogIO.Upload(ctx, kvs); err != nil { + log.Ctx(ctx).Warn("failed to upload bm25 log", zap.Error(err)) + return 0, nil, err } - err = conc.AwaitAll(uploadFutures...) + return cnt, binlogs, nil +} + +func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) { + numRows := st.req.GetNumRows() + + bm25FieldIds := compaction.GetBM25FieldIDs(st.req.GetSchema()) + pkField, err := typeutil.GetPrimaryFieldSchema(st.req.GetSchema()) + if err != nil { + return nil, err + } + pkFieldID := pkField.FieldID + writer, err := compaction.NewSegmentWriter(st.req.GetSchema(), numRows, statsBatchSize, + st.req.GetTargetSegmentID(), st.req.GetPartitionID(), st.req.GetCollectionID(), bm25FieldIds) if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to upload kvs", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) + log.Ctx(ctx).Warn("sort segment wrong, unable to init segment writer", + zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) return nil, err } + srw := &segmentRecordWriter{ + sw: writer, + binlogMaxSize: st.req.GetBinlogMaxSize(), + rootPath: st.req.GetStorageConfig().GetRootPath(), + logID: st.req.StartLogID, + maxLogID: st.req.EndLogID, + binlogIO: st.binlogIO, + ctx: ctx, + numRows: st.req.NumRows, + bm25FieldIds: bm25FieldIds, + } + + log := log.Ctx(ctx).With( + zap.String("clusterID", st.req.GetClusterID()), + zap.Int64("taskID", st.req.GetTaskID()), + zap.Int64("collectionID", st.req.GetCollectionID()), + zap.Int64("partitionID", st.req.GetPartitionID()), + zap.Int64("segmentID", st.req.GetSegmentID()), + zap.Int64s("bm25Fields", bm25FieldIds), + ) - serWriteStart := time.Now() - binlogNums, sPath, err := statSerializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.binlogIO, st.req.GetStartLogID()+st.logIDOffset, writer, numRows) + deletePKs, err := st.loadDeltalogs(ctx, st.deltaLogs) if err != nil { - log.Ctx(ctx).Warn("stats wrong, failed to serialize write segment stats", zap.Int64("taskID", st.req.GetTaskID()), - zap.Int64("remaining row count", numRows), zap.Error(err)) + log.Warn("load deletePKs failed", zap.Error(err)) return nil, err } - serWriteTimeCost += time.Since(serWriteStart) - st.logIDOffset += binlogNums + var isValueValid func(r storage.Record, ri, i int) bool + switch pkField.DataType { + case schemapb.DataType_Int64: + isValueValid = func(r storage.Record, ri, i int) bool { + v := r.Column(pkFieldID).(*array.Int64).Value(i) + deleteTs, ok := deletePKs[v] + ts := uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) + if ok && ts < deleteTs { + return false + } + return !st.isExpiredEntity(ts) + } + case schemapb.DataType_VarChar: + isValueValid = func(r storage.Record, ri, i int) bool { + v := r.Column(pkFieldID).(*array.String).Value(i) + deleteTs, ok := deletePKs[v] + ts := uint64(r.Column(common.TimeStampField).(*array.Int64).Value(i)) + if ok && ts < deleteTs { + return false + } + return !st.isExpiredEntity(ts) + } + } + + downloadTimeCost := time.Duration(0) + + rrs := make([]storage.RecordReader, len(st.insertLogs)) - var bm25StatsLogs []*datapb.FieldBinlog - if len(bm25FieldIds) > 0 { - binlogNums, bm25StatsLogs, err = bm25SerializeWrite(ctx, st.req.GetStorageConfig().GetRootPath(), st.binlogIO, st.req.GetStartLogID()+st.logIDOffset, writer, numRows) + for i, paths := range st.insertLogs { + log := log.With(zap.Strings("paths", paths)) + downloadStart := 
time.Now() + allValues, err := st.binlogIO.Download(ctx, paths) if err != nil { - log.Ctx(ctx).Warn("compact wrong, failed to serialize write segment bm25 stats", zap.Error(err)) + log.Warn("download wrong, fail to download insertLogs", zap.Error(err)) return nil, err } - st.logIDOffset += binlogNums + downloadTimeCost += time.Since(downloadStart) + + blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob { + return &storage.Blob{Key: paths[i], Value: v} + }) - if err := binlog.CompressFieldBinlogs(bm25StatsLogs); err != nil { + rr, err := storage.NewCompositeBinlogRecordReader(blobs) + if err != nil { + log.Warn("downloadData wrong, failed to new insert binlogs reader", zap.Error(err)) return nil, err } + rrs[i] = rr } - totalElapse := st.tr.RecordSpan() + log.Info("download data success", + zap.Int64("numRows", numRows), + zap.Duration("download binlogs elapse", downloadTimeCost), + ) + + numValidRows, err := storage.Sort(rrs, writer.GetPkID(), srw, isValueValid) + if err != nil { + log.Warn("sort failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err)) + return nil, err + } + if err := srw.Close(); err != nil { + return nil, err + } - insertLogs := lo.Values(allBinlogs) + insertLogs := lo.Values(srw.binlogs) if err := binlog.CompressFieldBinlogs(insertLogs); err != nil { return nil, err } - statsLogs := []*datapb.FieldBinlog{sPath} + statsLogs := []*datapb.FieldBinlog{srw.statslog} if err := binlog.CompressFieldBinlogs(statsLogs); err != nil { return nil, err } + bm25StatsLogs := srw.bm25statslog + if err := binlog.CompressFieldBinlogs(bm25StatsLogs); err != nil { + return nil, err + } + st.node.storePKSortStatsResult(st.req.GetClusterID(), st.req.GetTaskID(), st.req.GetCollectionID(), st.req.GetPartitionID(), st.req.GetTargetSegmentID(), st.req.GetInsertChannel(), - int64(len(values)), insertLogs, statsLogs, bm25StatsLogs) + int64(numValidRows), insertLogs, statsLogs, bm25StatsLogs) - log.Ctx(ctx).Info("sort segment end", + log.Info("sort segment end", zap.String("clusterID", st.req.GetClusterID()), zap.Int64("taskID", st.req.GetTaskID()), zap.Int64("collectionID", st.req.GetCollectionID()), @@ -296,12 +467,7 @@ func (st *statsTask) sortSegment(ctx context.Context) ([]*datapb.FieldBinlog, er zap.String("subTaskType", st.req.GetSubJobType().String()), zap.Int64("target segmentID", st.req.GetTargetSegmentID()), zap.Int64("old rows", numRows), - zap.Int("valid rows", len(values)), - zap.Int("binlog batch count", flushBatchCount), - zap.Duration("download elapse", downloadCost), - zap.Duration("sort elapse", sortTimeCost), - zap.Duration("serWrite elapse", serWriteTimeCost), - zap.Duration("total elapse", totalElapse)) + zap.Int("valid rows", numValidRows)) return insertLogs, nil } @@ -313,7 +479,7 @@ func (st *statsTask) Execute(ctx context.Context) error { insertLogs := st.req.GetInsertLogs() var err error if st.req.GetSubJobType() == indexpb.StatsSubJob_Sort { - insertLogs, err = st.sortSegment(ctx) + insertLogs, err = st.sort(ctx) if err != nil { return err } @@ -350,99 +516,6 @@ func (st *statsTask) Reset() { st.node = nil } -func (st *statsTask) downloadData(ctx context.Context, numRows int64, PKFieldID int64, bm25FieldIds []int64) ([]*storage.Value, error) { - log := log.Ctx(ctx).With( - zap.String("clusterID", st.req.GetClusterID()), - zap.Int64("taskID", st.req.GetTaskID()), - zap.Int64("collectionID", st.req.GetCollectionID()), - zap.Int64("partitionID", st.req.GetPartitionID()), - zap.Int64("segmentID", st.req.GetSegmentID()), - zap.Int64s("bm25Fields", 
bm25FieldIds), - ) - - deletePKs, err := st.loadDeltalogs(ctx, st.deltaLogs) - if err != nil { - log.Warn("load deletePKs failed", zap.Error(err)) - return nil, err - } - - var ( - remainingRowCount int64 // the number of remaining entities - expiredRowCount int64 // the number of expired entities - ) - - isValueDeleted := func(v *storage.Value) bool { - ts, ok := deletePKs[v.PK.GetValue()] - // insert task and delete task has the same ts when upsert - // here should be < instead of <= - // to avoid the upsert data to be deleted after compact - if ok && uint64(v.Timestamp) < ts { - return true - } - return false - } - - downloadTimeCost := time.Duration(0) - - values := make([]*storage.Value, 0, numRows) - for _, paths := range st.insertLogs { - log := log.With(zap.Strings("paths", paths)) - downloadStart := time.Now() - allValues, err := st.binlogIO.Download(ctx, paths) - if err != nil { - log.Warn("download wrong, fail to download insertLogs", zap.Error(err)) - return nil, err - } - downloadTimeCost += time.Since(downloadStart) - - blobs := lo.Map(allValues, func(v []byte, i int) *storage.Blob { - return &storage.Blob{Key: paths[i], Value: v} - }) - - iter, err := storage.NewBinlogDeserializeReader(blobs, PKFieldID) - if err != nil { - log.Warn("downloadData wrong, failed to new insert binlogs reader", zap.Error(err)) - return nil, err - } - - for { - err := iter.Next() - if err != nil { - if err == sio.EOF { - break - } else { - log.Warn("downloadData wrong, failed to iter through data", zap.Error(err)) - iter.Close() - return nil, err - } - } - - v := iter.Value() - if isValueDeleted(v) { - continue - } - - // Filtering expired entity - if st.isExpiredEntity(typeutil.Timestamp(v.Timestamp)) { - expiredRowCount++ - continue - } - - values = append(values, iter.Value()) - remainingRowCount++ - } - iter.Close() - } - - log.Info("download data success", - zap.Int64("old rows", numRows), - zap.Int64("remainingRowCount", remainingRowCount), - zap.Int64("expiredRowCount", expiredRowCount), - zap.Duration("download binlogs elapse", downloadTimeCost), - ) - return values, nil -} - func (st *statsTask) loadDeltalogs(ctx context.Context, dpaths []string) (map[interface{}]typeutil.Timestamp, error) { st.tr.RecordSpan() ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "loadDeltalogs") @@ -545,74 +618,6 @@ func serializeWrite(ctx context.Context, rootPath string, startID int64, writer return } -func statSerializeWrite(ctx context.Context, rootPath string, io io.BinlogIO, startID int64, writer *compaction.SegmentWriter, finalRowCount int64) (int64, *datapb.FieldBinlog, error) { - ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "statslog serializeWrite") - defer span.End() - sblob, err := writer.Finish() - if err != nil { - return 0, nil, err - } - - binlogNum := int64(1) - key, _ := binlog.BuildLogPathWithRootPath(rootPath, storage.StatsBinlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), writer.GetPkID(), startID) - kvs := map[string][]byte{key: sblob.GetValue()} - statFieldLog := &datapb.FieldBinlog{ - FieldID: writer.GetPkID(), - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(sblob.GetValue())), - MemorySize: int64(len(sblob.GetValue())), - LogPath: key, - EntriesNum: finalRowCount, - }, - }, - } - if err := io.Upload(ctx, kvs); err != nil { - log.Ctx(ctx).Warn("failed to upload insert log", zap.Error(err)) - return binlogNum, nil, err - } - - return binlogNum, statFieldLog, nil -} - -func bm25SerializeWrite(ctx context.Context, rootPath 
string, io io.BinlogIO, startID int64, writer *compaction.SegmentWriter, finalRowCount int64) (int64, []*datapb.FieldBinlog, error) { - ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "bm25log serializeWrite") - defer span.End() - stats, err := writer.GetBm25StatsBlob() - if err != nil { - return 0, nil, err - } - - kvs := make(map[string][]byte) - binlogs := []*datapb.FieldBinlog{} - cnt := int64(0) - for fieldID, blob := range stats { - key, _ := binlog.BuildLogPathWithRootPath(rootPath, storage.BM25Binlog, writer.GetCollectionID(), writer.GetPartitionID(), writer.GetSegmentID(), fieldID, startID+cnt) - kvs[key] = blob.GetValue() - fieldLog := &datapb.FieldBinlog{ - FieldID: fieldID, - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(blob.GetValue())), - MemorySize: int64(len(blob.GetValue())), - LogPath: key, - EntriesNum: finalRowCount, - }, - }, - } - - binlogs = append(binlogs, fieldLog) - cnt++ - } - - if err := io.Upload(ctx, kvs); err != nil { - log.Ctx(ctx).Warn("failed to upload bm25 log", zap.Error(err)) - return 0, nil, err - } - - return cnt, binlogs, nil -} - func ParseStorageConfig(s *indexpb.StorageConfig) (*indexcgopb.StorageConfig, error) { bs, err := proto.Marshal(s) if err != nil { diff --git a/internal/indexnode/task_stats_test.go b/internal/indexnode/task_stats_test.go index 6243301578bf4..e439313c05b9c 100644 --- a/internal/indexnode/task_stats_test.go +++ b/internal/indexnode/task_stats_test.go @@ -81,26 +81,6 @@ func (s *TaskStatsSuite) GenSegmentWriterWithBM25(magic int64) { s.segWriter = segWriter } -func (s *TaskStatsSuite) Testbm25SerializeWriteError() { - s.Run("normal case", func() { - s.schema = genCollectionSchemaWithBM25() - s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(nil).Once() - s.GenSegmentWriterWithBM25(0) - cnt, binlogs, err := bm25SerializeWrite(context.Background(), "root_path", s.mockBinlogIO, 0, s.segWriter, 1) - s.Require().NoError(err) - s.Equal(int64(1), cnt) - s.Equal(1, len(binlogs)) - }) - - s.Run("upload failed", func() { - s.schema = genCollectionSchemaWithBM25() - s.mockBinlogIO.EXPECT().Upload(mock.Anything, mock.Anything).Return(fmt.Errorf("mock error")).Once() - s.GenSegmentWriterWithBM25(0) - _, _, err := bm25SerializeWrite(context.Background(), "root_path", s.mockBinlogIO, 0, s.segWriter, 1) - s.Error(err) - }) -} - func (s *TaskStatsSuite) TestSortSegmentWithBM25() { s.Run("normal case", func() { s.schema = genCollectionSchemaWithBM25() @@ -130,10 +110,13 @@ func (s *TaskStatsSuite) TestSortSegmentWithBM25() { InsertLogs: lo.Values(fBinlogs), Schema: s.schema, NumRows: 1, + StartLogID: 0, + EndLogID: 5, + BinlogMaxSize: 64 * 1024 * 1024, }, node, s.mockBinlogIO) err = task.PreExecute(ctx) s.Require().NoError(err) - binlog, err := task.sortSegment(ctx) + binlog, err := task.sort(ctx) s.Require().NoError(err) s.Equal(5, len(binlog)) @@ -174,10 +157,13 @@ func (s *TaskStatsSuite) TestSortSegmentWithBM25() { InsertLogs: lo.Values(fBinlogs), Schema: s.schema, NumRows: 1, + StartLogID: 0, + EndLogID: 5, + BinlogMaxSize: 64 * 1024 * 1024, }, node, s.mockBinlogIO) err = task.PreExecute(ctx) s.Require().NoError(err) - _, err = task.sortSegment(ctx) + _, err = task.sort(ctx) s.Error(err) }) } diff --git a/internal/storage/binlog_iterator_test.go b/internal/storage/binlog_iterator_test.go index ebe4f478e3067..5365fc3662886 100644 --- a/internal/storage/binlog_iterator_test.go +++ b/internal/storage/binlog_iterator_test.go @@ -66,6 +66,10 @@ func generateTestSchema() *schemapb.CollectionSchema { } func 
generateTestData(num int) ([]*Blob, error) { + return generateTestDataWithSeed(1, num) +} + +func generateTestDataWithSeed(seed, num int) ([]*Blob, error) { insertCodec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ID: 1, Schema: generateTestSchema()}) var ( @@ -92,7 +96,7 @@ func generateTestData(num int) ([]*Blob, error) { field106 [][]byte ) - for i := 1; i <= num; i++ { + for i := seed; i < seed+num; i++ { field0 = append(field0, int64(i)) field1 = append(field1, int64(i)) field10 = append(field10, true) diff --git a/internal/storage/binlog_reader.go b/internal/storage/binlog_reader.go index ad364c3d751a1..56adf2990e359 100644 --- a/internal/storage/binlog_reader.go +++ b/internal/storage/binlog_reader.go @@ -106,6 +106,9 @@ func ReadDescriptorEvent(buffer io.Reader) (*descriptorEvent, error) { // Close closes the BinlogReader object. // It mainly calls the Close method of the internal events, reclaims resources, and marks itself as closed. func (reader *BinlogReader) Close() { + if reader == nil { + return + } if reader.isClose { return } diff --git a/internal/storage/serde.go b/internal/storage/serde.go index ae46ca8bfbb6b..7ecb14d5f5531 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -18,7 +18,6 @@ package storage import ( "encoding/binary" - "fmt" "io" "math" "sync" @@ -36,24 +35,23 @@ import ( ) type Record interface { - Schema() map[FieldID]schemapb.DataType - ArrowSchema() *arrow.Schema Column(i FieldID) arrow.Array Len() int Release() + Retain() Slice(start, end int) Record } type RecordReader interface { Next() error Record() Record - Close() + Close() error } type RecordWriter interface { Write(r Record) error GetWrittenUncompressed() uint64 - Close() + Close() error } type ( @@ -63,19 +61,19 @@ type ( // compositeRecord is a record being composed of multiple records, in which each only have 1 column type compositeRecord struct { - recs map[FieldID]arrow.Record - schema map[FieldID]schemapb.DataType + index map[FieldID]int16 + recs []arrow.Array } var _ Record = (*compositeRecord)(nil) func (r *compositeRecord) Column(i FieldID) arrow.Array { - return r.recs[i].Column(0) + return r.recs[r.index[i]] } func (r *compositeRecord) Len() int { for _, rec := range r.recs { - return rec.Column(0).Len() + return rec.Len() } return 0 } @@ -86,26 +84,21 @@ func (r *compositeRecord) Release() { } } -func (r *compositeRecord) Schema() map[FieldID]schemapb.DataType { - return r.schema -} - -func (r *compositeRecord) ArrowSchema() *arrow.Schema { - var fields []arrow.Field +func (r *compositeRecord) Retain() { for _, rec := range r.recs { - fields = append(fields, rec.Schema().Field(0)) + rec.Retain() } - return arrow.NewSchema(fields, nil) } func (r *compositeRecord) Slice(start, end int) Record { - slices := make(map[FieldID]arrow.Record) + slices := make([]arrow.Array, len(r.index)) for i, rec := range r.recs { - slices[i] = rec.NewSlice(int64(start), int64(end)) + d := array.NewSliceData(rec.Data(), int64(start), int64(end)) + slices[i] = array.MakeFromData(d) } return &compositeRecord{ - recs: slices, - schema: r.schema, + index: r.index, + recs: slices, } } @@ -543,28 +536,20 @@ func (deser *DeserializeReader[T]) Next() error { return nil } -func (deser *DeserializeReader[T]) NextRecord() (Record, error) { - if len(deser.values) != 0 { - return nil, errors.New("deserialize result is not empty") - } - - if err := deser.rr.Next(); err != nil { - return nil, err - } - return deser.rr.Record(), nil -} - func (deser *DeserializeReader[T]) Value() T { return 
deser.values[deser.pos] } -func (deser *DeserializeReader[T]) Close() { +func (deser *DeserializeReader[T]) Close() error { if deser.rec != nil { deser.rec.Release() } if deser.rr != nil { - deser.rr.Close() + if err := deser.rr.Close(); err != nil { + return err + } } + return nil } func NewDeserializeReader[T any](rr RecordReader, deserializer Deserializer[T]) *DeserializeReader[T] { @@ -578,22 +563,12 @@ var _ Record = (*selectiveRecord)(nil) // selectiveRecord is a Record that only contains a single field, reusing existing Record. type selectiveRecord struct { - r Record - selectedFieldId FieldID - - schema map[FieldID]schemapb.DataType -} - -func (r *selectiveRecord) Schema() map[FieldID]schemapb.DataType { - return r.schema -} - -func (r *selectiveRecord) ArrowSchema() *arrow.Schema { - return r.r.ArrowSchema() + r Record + fieldId FieldID } func (r *selectiveRecord) Column(i FieldID) arrow.Array { - if i == r.selectedFieldId { + if i == r.fieldId { return r.r.Column(i) } return nil @@ -607,6 +582,10 @@ func (r *selectiveRecord) Release() { // do nothing. } +func (r *selectiveRecord) Retain() { + // do nothing +} + func (r *selectiveRecord) Slice(start, end int) Record { panic("not implemented") } @@ -657,17 +636,10 @@ func calculateArraySize(a arrow.Array) int { return totalSize } -func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord { - dt, ok := r.Schema()[selectedFieldId] - if !ok { - return nil - } - schema := make(map[FieldID]schemapb.DataType, 1) - schema[selectedFieldId] = dt +func newSelectiveRecord(r Record, selectedFieldId FieldID) Record { return &selectiveRecord{ - r: r, - selectedFieldId: selectedFieldId, - schema: schema, + r: r, + fieldId: selectedFieldId, } } @@ -675,26 +647,19 @@ var _ RecordWriter = (*CompositeRecordWriter)(nil) type CompositeRecordWriter struct { writers map[FieldID]RecordWriter - - writtenUncompressed uint64 } func (crw *CompositeRecordWriter) GetWrittenUncompressed() uint64 { - return crw.writtenUncompressed + s := uint64(0) + for _, w := range crw.writers { + s += w.GetWrittenUncompressed() + } + return s } func (crw *CompositeRecordWriter) Write(r Record) error { - if len(r.Schema()) != len(crw.writers) { - return fmt.Errorf("schema length mismatch %d, expected %d", len(r.Schema()), len(crw.writers)) - } - - var bytes uint64 - for fid := range r.Schema() { - arr := r.Column(fid) - bytes += uint64(calculateArraySize(arr)) - } - crw.writtenUncompressed += bytes for fieldId, w := range crw.writers { + // TODO: if field is not exist, write sr := newSelectiveRecord(r, fieldId) if err := w.Write(sr); err != nil { return err @@ -703,14 +668,17 @@ func (crw *CompositeRecordWriter) Write(r Record) error { return nil } -func (crw *CompositeRecordWriter) Close() { +func (crw *CompositeRecordWriter) Close() error { if crw != nil { for _, w := range crw.writers { if w != nil { - w.Close() + if err := w.Close(); err != nil { + return err + } } } } + return nil } func NewCompositeRecordWriter(writers map[FieldID]RecordWriter) *CompositeRecordWriter { @@ -753,8 +721,8 @@ func (sfw *singleFieldRecordWriter) GetWrittenUncompressed() uint64 { return sfw.writtenUncompressed } -func (sfw *singleFieldRecordWriter) Close() { - sfw.fw.Close() +func (sfw *singleFieldRecordWriter) Close() error { + return sfw.fw.Close() } func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer, opts ...RecordWriterOptions) (*singleFieldRecordWriter, error) { @@ -804,8 +772,8 @@ func (mfw *multiFieldRecordWriter) 
GetWrittenUncompressed() uint64 { return mfw.writtenUncompressed } -func (mfw *multiFieldRecordWriter) Close() { - mfw.fw.Close() +func (mfw *multiFieldRecordWriter) Close() error { + return mfw.fw.Close() } func newMultiFieldRecordWriter(fieldIds []FieldID, fields []arrow.Field, writer io.Writer) (*multiFieldRecordWriter, error) { @@ -903,18 +871,13 @@ func NewSerializeRecordWriter[T any](rw RecordWriter, serializer Serializer[T], } type simpleArrowRecord struct { - r arrow.Record - schema map[FieldID]schemapb.DataType + r arrow.Record field2Col map[FieldID]int } var _ Record = (*simpleArrowRecord)(nil) -func (sr *simpleArrowRecord) Schema() map[FieldID]schemapb.DataType { - return sr.schema -} - func (sr *simpleArrowRecord) Column(i FieldID) arrow.Array { colIdx, ok := sr.field2Col[i] if !ok { @@ -931,19 +894,22 @@ func (sr *simpleArrowRecord) Release() { sr.r.Release() } +func (sr *simpleArrowRecord) Retain() { + sr.r.Retain() +} + func (sr *simpleArrowRecord) ArrowSchema() *arrow.Schema { return sr.r.Schema() } func (sr *simpleArrowRecord) Slice(start, end int) Record { s := sr.r.NewSlice(int64(start), int64(end)) - return newSimpleArrowRecord(s, sr.schema, sr.field2Col) + return newSimpleArrowRecord(s, sr.field2Col) } -func newSimpleArrowRecord(r arrow.Record, schema map[FieldID]schemapb.DataType, field2Col map[FieldID]int) *simpleArrowRecord { +func newSimpleArrowRecord(r arrow.Record, field2Col map[FieldID]int) *simpleArrowRecord { return &simpleArrowRecord{ r: r, - schema: schema, field2Col: field2Col, } } diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 80f22672c0c77..f611c65ba3187 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -39,39 +39,50 @@ import ( var _ RecordReader = (*CompositeBinlogRecordReader)(nil) +// ChunkedBlobsReader returns a chunk composed of blobs, or io.EOF if no more data +type ChunkedBlobsReader func() ([]*Blob, error) + type CompositeBinlogRecordReader struct { - blobs [][]*Blob + BlobsReader ChunkedBlobsReader - blobPos int - rrs []array.RecordReader - closers []func() - fields []FieldID + brs []*BinlogReader + rrs []array.RecordReader - r compositeRecord + schema map[FieldID]schemapb.DataType + index map[FieldID]int16 + r *compositeRecord } func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { - if crr.closers != nil { - for _, close := range crr.closers { - if close != nil { - close() - } + if crr.brs != nil { + for _, er := range crr.brs { + er.Close() + } + for _, rr := range crr.rrs { + rr.Release() } } - crr.blobPos++ - if crr.blobPos >= len(crr.blobs[0]) { - return io.EOF + + blobs, err := crr.BlobsReader() + if err != nil { + return err + } + + if crr.rrs == nil { + crr.rrs = make([]array.RecordReader, len(blobs)) + crr.brs = make([]*BinlogReader, len(blobs)) + crr.schema = make(map[FieldID]schemapb.DataType) + crr.index = make(map[FieldID]int16, len(blobs)) } - for i, b := range crr.blobs { - reader, err := NewBinlogReader(b[crr.blobPos].Value) + for i, b := range blobs { + reader, err := NewBinlogReader(b.Value) if err != nil { return err } - crr.fields[i] = reader.FieldID // TODO: assert schema being the same in every blobs - crr.r.schema[reader.FieldID] = reader.PayloadDataType + crr.schema[reader.FieldID] = reader.PayloadDataType er, err := reader.NextEventReader() if err != nil { return err @@ -81,40 +92,30 @@ func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { return err } crr.rrs[i] = rr - crr.closers[i] = func() { - rr.Release() - 
er.Close() - reader.Close() - } + crr.index[reader.FieldID] = int16(i) + crr.brs[i] = reader } return nil } func (crr *CompositeBinlogRecordReader) Next() error { if crr.rrs == nil { - if crr.blobs == nil || len(crr.blobs) == 0 { - return io.EOF - } - crr.rrs = make([]array.RecordReader, len(crr.blobs)) - crr.closers = make([]func(), len(crr.blobs)) - crr.blobPos = -1 - crr.fields = make([]FieldID, len(crr.rrs)) - crr.r = compositeRecord{ - recs: make(map[FieldID]arrow.Record, len(crr.rrs)), - schema: make(map[FieldID]schemapb.DataType, len(crr.rrs)), - } if err := crr.iterateNextBatch(); err != nil { return err } } composeRecord := func() bool { + recs := make([]arrow.Array, len(crr.rrs)) for i, rr := range crr.rrs { if ok := rr.Next(); !ok { return false } - // compose record - crr.r.recs[crr.fields[i]] = rr.Record() + recs[i] = rr.Record().Column(0) + } + crr.r = &compositeRecord{ + index: crr.index, + recs: recs, } return true } @@ -135,15 +136,26 @@ func (crr *CompositeBinlogRecordReader) Next() error { } func (crr *CompositeBinlogRecordReader) Record() Record { - return &crr.r + return crr.r } -func (crr *CompositeBinlogRecordReader) Close() { - for _, close := range crr.closers { - if close != nil { - close() +func (crr *CompositeBinlogRecordReader) Close() error { + if crr.brs != nil { + for _, er := range crr.brs { + if er != nil { + er.Close() + } } } + if crr.rrs != nil { + for _, rr := range crr.rrs { + if rr != nil { + rr.Release() + } + } + } + crr.r = nil + return nil } func parseBlobKey(blobKey string) (colId FieldID, logId UniqueID) { @@ -177,8 +189,19 @@ func NewCompositeBinlogRecordReader(blobs []*Blob) (*CompositeBinlogRecordReader }) sortedBlobs = append(sortedBlobs, blobsForField) } + chunkPos := 0 return &CompositeBinlogRecordReader{ - blobs: sortedBlobs, + BlobsReader: func() ([]*Blob, error) { + if len(sortedBlobs) == 0 || chunkPos >= len(sortedBlobs[0]) { + return nil, io.EOF + } + blobs := make([]*Blob, len(sortedBlobs)) + for fieldPos := range blobs { + blobs[fieldPos] = sortedBlobs[fieldPos][chunkPos] + } + chunkPos++ + return blobs, nil + }, }, nil } @@ -189,17 +212,18 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize } return NewDeserializeReader(reader, func(r Record, v []*Value) error { + schema := reader.schema // Note: the return value `Value` is reused. 
for i := 0; i < r.Len(); i++ { value := v[i] if value == nil { value = &Value{} - value.Value = make(map[FieldID]interface{}, len(r.Schema())) + value.Value = make(map[FieldID]interface{}, len(schema)) v[i] = value } m := value.Value.(map[FieldID]interface{}) - for j, dt := range r.Schema() { + for j, dt := range schema { if r.Column(j).IsNull(i) { m[j] = nil } else { @@ -219,7 +243,7 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize value.ID = rowID value.Timestamp = m[common.TimeStampField].(int64) - pk, err := GenPrimaryKeyByRawData(m[PKfieldID], r.Schema()[PKfieldID]) + pk, err := GenPrimaryKeyByRawData(m[PKfieldID], schema[PKfieldID]) if err != nil { return err } @@ -239,7 +263,7 @@ func newDeltalogOneFieldReader(blobs []*Blob) (*DeserializeReader[*DeleteLog], e } return NewDeserializeReader(reader, func(r Record, v []*DeleteLog) error { var fid FieldID // The only fid from delete file - for k := range r.Schema() { + for k := range reader.schema { fid = k break } @@ -391,7 +415,7 @@ func ValueSerializer(v []*Value, fieldSchema []*schemapb.FieldSchema) (Record, e field2Col[fid] = i i++ } - return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), nil + return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), field2Col), nil } func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, segmentID UniqueID, @@ -529,10 +553,7 @@ func newDeltalogSerializeWriter(eventWriter *DeltalogStreamWriter, batchSize int field2Col := map[FieldID]int{ 0: 0, } - schema := map[FieldID]schemapb.DataType{ - 0: schemapb.DataType_String, - } - return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), schema, field2Col), nil + return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), field2Col), nil }, batchSize), nil } @@ -583,12 +604,11 @@ func (crr *simpleArrowRecordReader) iterateNextBatch() error { func (crr *simpleArrowRecordReader) Next() error { if crr.rr == nil { - if crr.blobs == nil || len(crr.blobs) == 0 { + if len(crr.blobs) == 0 { return io.EOF } crr.blobPos = -1 crr.r = simpleArrowRecord{ - schema: make(map[FieldID]schemapb.DataType), field2Col: make(map[FieldID]int), } if err := crr.iterateNextBatch(); err != nil { @@ -623,10 +643,11 @@ func (crr *simpleArrowRecordReader) Record() Record { return &crr.r } -func (crr *simpleArrowRecordReader) Close() { +func (crr *simpleArrowRecordReader) Close() error { if crr.closer != nil { crr.closer() } + return nil } func newSimpleArrowRecordReader(blobs []*Blob) (*simpleArrowRecordReader, error) { @@ -781,11 +802,7 @@ func newDeltalogMultiFieldWriter(eventWriter *MultiFieldDeltalogStreamWriter, ba common.RowIDField: 0, common.TimeStampField: 1, } - schema := map[FieldID]schemapb.DataType{ - common.RowIDField: pkType, - common.TimeStampField: schemapb.DataType_Int64, - } - return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), schema, field2Col), nil + return newSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), field2Col), nil }, batchSize), nil } diff --git a/internal/storage/serde_events_test.go b/internal/storage/serde_events_test.go index d97c17fd5e1f9..bb252395bfc69 100644 --- a/internal/storage/serde_events_test.go +++ b/internal/storage/serde_events_test.go @@ -92,7 +92,7 @@ func TestBinlogStreamWriter(t *testing.T) { []arrow.Array{arr}, int64(size), ) - r := newSimpleArrowRecord(ar, 
map[FieldID]schemapb.DataType{1: schemapb.DataType_Bool}, map[FieldID]int{1: 0}) + r := newSimpleArrowRecord(ar, map[FieldID]int{1: 0}) defer r.Release() err = rw.Write(r) assert.NoError(t, err) diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index f31ffe8b49aff..df23ab229e25b 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -30,6 +30,25 @@ import ( "github.com/milvus-io/milvus/pkg/common" ) +type MockRecordWriter struct { + writefn func(Record) error + closefn func() error +} + +var _ RecordWriter = (*MockRecordWriter)(nil) + +func (w *MockRecordWriter) Write(record Record) error { + return w.writefn(record) +} + +func (w *MockRecordWriter) Close() error { + return w.closefn() +} + +func (w *MockRecordWriter) GetWrittenUncompressed() uint64 { + return 0 +} + func TestSerDe(t *testing.T) { type args struct { dt schemapb.DataType @@ -101,37 +120,6 @@ func TestSerDe(t *testing.T) { } } -func TestArrowSchema(t *testing.T) { - fields := []arrow.Field{{Name: "1", Type: arrow.BinaryTypes.String, Nullable: true}} - builder := array.NewBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String) - builder.AppendValueFromString("1") - record := array.NewRecord(arrow.NewSchema(fields, nil), []arrow.Array{builder.NewArray()}, 1) - t.Run("test composite record", func(t *testing.T) { - cr := &compositeRecord{ - recs: make(map[FieldID]arrow.Record, 1), - schema: make(map[FieldID]schemapb.DataType, 1), - } - cr.recs[0] = record - cr.schema[0] = schemapb.DataType_String - expected := arrow.NewSchema(fields, nil) - assert.Equal(t, expected, cr.ArrowSchema()) - }) - - t.Run("test simple arrow record", func(t *testing.T) { - cr := &simpleArrowRecord{ - r: record, - schema: make(map[FieldID]schemapb.DataType, 1), - field2Col: make(map[FieldID]int, 1), - } - cr.schema[0] = schemapb.DataType_String - expected := arrow.NewSchema(fields, nil) - assert.Equal(t, expected, cr.ArrowSchema()) - - sr := newSelectiveRecord(cr, 0) - assert.Equal(t, expected, sr.ArrowSchema()) - }) -} - func BenchmarkDeserializeReader(b *testing.B) { len := 1000000 blobs, err := generateTestData(len) diff --git a/internal/storage/sort.go b/internal/storage/sort.go new file mode 100644 index 0000000000000..71ab6cb17ceea --- /dev/null +++ b/internal/storage/sort.go @@ -0,0 +1,241 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package storage
+
+import (
+	"container/heap"
+	"io"
+	"sort"
+
+	"github.com/apache/arrow/go/v12/arrow/array"
+)
+
+func Sort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r Record, ri, i int) bool) (numRows int, err error) {
+	records := make([]Record, 0)
+
+	type index struct {
+		ri int
+		i  int
+	}
+	indices := make([]*index, 0)
+
+	defer func() {
+		for _, r := range records {
+			r.Release()
+		}
+	}()
+
+	for _, r := range rr {
+		for {
+			err := r.Next()
+			if err == nil {
+				rec := r.Record()
+				rec.Retain()
+				ri := len(records)
+				records = append(records, rec)
+				for i := 0; i < rec.Len(); i++ {
+					if predicate(rec, ri, i) {
+						indices = append(indices, &index{ri, i})
+					}
+				}
+			} else if err == io.EOF {
+				break
+			} else {
+				return 0, err
+			}
+		}
+	}
+
+	if len(records) == 0 {
+		return 0, nil
+	}
+
+	switch records[0].Column(pkField).(type) {
+	case *array.Int64:
+		sort.Slice(indices, func(i, j int) bool {
+			pki := records[indices[i].ri].Column(pkField).(*array.Int64).Value(indices[i].i)
+			pkj := records[indices[j].ri].Column(pkField).(*array.Int64).Value(indices[j].i)
+			return pki < pkj
+		})
+	case *array.String:
+		sort.Slice(indices, func(i, j int) bool {
+			pki := records[indices[i].ri].Column(pkField).(*array.String).Value(indices[i].i)
+			pkj := records[indices[j].ri].Column(pkField).(*array.String).Value(indices[j].i)
+			return pki < pkj
+		})
+	}
+
+	writeOne := func(i *index) error {
+		rec := records[i.ri].Slice(i.i, i.i+1)
+		defer rec.Release()
+		return rw.Write(rec)
+	}
+	for _, i := range indices {
+		if err := writeOne(i); err != nil {
+			return 0, err
+		}
+		numRows++
+	}
+
+	return numRows, nil
+}
+
+// A PriorityQueue implements heap.Interface and holds Items.
+type PriorityQueue[T any] struct {
+	items []*T
+	less  func(x, y *T) bool
+}
+
+var _ heap.Interface = (*PriorityQueue[any])(nil)
+
+func (pq PriorityQueue[T]) Len() int { return len(pq.items) }
+
+func (pq PriorityQueue[T]) Less(i, j int) bool {
+	return pq.less(pq.items[i], pq.items[j])
+}
+
+func (pq PriorityQueue[T]) Swap(i, j int) {
+	pq.items[i], pq.items[j] = pq.items[j], pq.items[i]
+}
+
+func (pq *PriorityQueue[T]) Push(x any) {
+	pq.items = append(pq.items, x.(*T))
+}
+
+func (pq *PriorityQueue[T]) Pop() any {
+	old := pq.items
+	n := len(old)
+	x := old[n-1]
+	old[n-1] = nil
+	pq.items = old[0 : n-1]
+	return x
+}
+
+func (pq *PriorityQueue[T]) Enqueue(x *T) {
+	heap.Push(pq, x)
+}
+
+func (pq *PriorityQueue[T]) Dequeue() *T {
+	return heap.Pop(pq).(*T)
+}
+
+func NewPriorityQueue[T any](less func(x, y *T) bool) *PriorityQueue[T] {
+	pq := PriorityQueue[T]{
+		items: make([]*T, 0),
+		less:  less,
+	}
+	heap.Init(&pq)
+	return &pq
+}
+
+func MergeSort(rr []RecordReader, pkField FieldID, rw RecordWriter, predicate func(r Record, ri, i int) bool) (numRows int, err error) {
+	type index struct {
+		ri int
+		i  int
+	}
+
+	advanceRecord := func(r RecordReader) (Record, error) {
+		err := r.Next()
+		if err != nil {
+			return nil, err
+		}
+		return r.Record(), nil
+	}
+
+	recs := make([]Record, len(rr))
+	for i, r := range rr {
+		rec, err := advanceRecord(r)
+		if err == io.EOF {
+			recs[i] = nil
+			continue
+		}
+		if err != nil {
+			return 0, err
+		}
+		recs[i] = rec
+	}
+
+	var pq *PriorityQueue[index]
+	switch recs[0].Column(pkField).(type) {
+	case *array.Int64:
+		pq = NewPriorityQueue[index](func(x, y *index) bool {
+			return rr[x.ri].Record().Column(pkField).(*array.Int64).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.Int64).Value(y.i)
+		})
+	case *array.String:
+		pq = NewPriorityQueue[index](func(x, y *index) bool {
+			return rr[x.ri].Record().Column(pkField).(*array.String).Value(x.i) < rr[y.ri].Record().Column(pkField).(*array.String).Value(y.i)
+		})
+	}
+
+	enqueueAll := func(ri int, r Record) {
+		for j := 0; j < r.Len(); j++ {
+			if predicate(r, ri, j) {
+				pq.Enqueue(&index{
+					ri: ri,
+					i:  j,
+				})
+				numRows++
+			}
+		}
+	}
+
+	for i, v := range recs {
+		if v != nil {
+			enqueueAll(i, v)
+		}
+	}
+
+	ri, istart, iend := -1, -1, -1
+	for pq.Len() > 0 {
+		idx := pq.Dequeue()
+		if ri == idx.ri {
+			// record end of cache, do nothing
+			iend = idx.i + 1
+		} else {
+			if ri != -1 {
+				// record changed, write old one and reset
+				sr := rr[ri].Record().Slice(istart, iend)
+				err := rw.Write(sr)
+				sr.Release()
+				if err != nil {
+					return 0, err
+				}
+			}
+			ri = idx.ri
+			istart = idx.i
+			iend = idx.i + 1
+		}
+
+		// If popped idx reaches end of segment, invalidate cache and advance to next segment
+		if idx.i == rr[idx.ri].Record().Len()-1 {
+			sr := rr[ri].Record().Slice(istart, iend)
+			err := rw.Write(sr)
+			sr.Release()
+			if err != nil {
+				return 0, err
+			}
+			ri, istart, iend = -1, -1, -1
+			rec, err := advanceRecord(rr[idx.ri])
+			if err == io.EOF {
+				continue
+			}
+			if err != nil {
+				return 0, err
+			}
+			enqueueAll(idx.ri, rec)
+		}
+	}
+	return numRows, nil
+}
diff --git a/internal/storage/sort_test.go b/internal/storage/sort_test.go
new file mode 100644
index 0000000000000..bee5c4f2dcd0a
--- /dev/null
+++ b/internal/storage/sort_test.go
@@ -0,0 +1,162 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+	"testing"
+
+	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/pkg/common"
+)
+
+func TestSort(t *testing.T) {
+	getReaders := func() []RecordReader {
+		blobs, err := generateTestDataWithSeed(10, 3)
+		assert.NoError(t, err)
+		reader10, err := NewCompositeBinlogRecordReader(blobs)
+		assert.NoError(t, err)
+		blobs, err = generateTestDataWithSeed(20, 3)
+		assert.NoError(t, err)
+		reader20, err := NewCompositeBinlogRecordReader(blobs)
+		assert.NoError(t, err)
+		rr := []RecordReader{reader20, reader10}
+		return rr
+	}
+
+	lastPK := int64(-1)
+	rw := &MockRecordWriter{
+		writefn: func(r Record) error {
+			pk := r.Column(common.RowIDField).(*array.Int64).Value(0)
+			assert.Greater(t, pk, lastPK)
+			lastPK = pk
+			return nil
+		},
+
+		closefn: func() error {
+			lastPK = int64(-1)
+			return nil
+		},
+	}
+
+	t.Run("sort", func(t *testing.T) {
+		gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
+			return true
+		})
+		assert.NoError(t, err)
+		assert.Equal(t, 6, gotNumRows)
+		err = rw.Close()
+		assert.NoError(t, err)
+	})
+
+	t.Run("sort with predicate", func(t *testing.T) {
+		gotNumRows, err := Sort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
+			pk := r.Column(common.RowIDField).(*array.Int64).Value(i)
+			return pk >= 20
+		})
+		assert.NoError(t, err)
+		assert.Equal(t, 3, gotNumRows)
+		err = rw.Close()
+		assert.NoError(t, err)
+	})
+}
+
+func TestMergeSort(t *testing.T) {
+	getReaders := func() []RecordReader {
+		blobs, err := generateTestDataWithSeed(10, 3)
+		assert.NoError(t, err)
+		reader10, err := NewCompositeBinlogRecordReader(blobs)
+		assert.NoError(t, err)
+		blobs, err = generateTestDataWithSeed(20, 3)
+		assert.NoError(t, err)
+		reader20, err := NewCompositeBinlogRecordReader(blobs)
+		assert.NoError(t, err)
+		rr := []RecordReader{reader20, reader10}
+		return rr
+	}
+
+	lastPK := int64(-1)
+	rw := &MockRecordWriter{
+		writefn: func(r Record) error {
+			pk := r.Column(common.RowIDField).(*array.Int64).Value(0)
+			assert.Greater(t, pk, lastPK)
+			lastPK = pk
+			return nil
+		},
+
+		closefn: func() error {
+			lastPK = int64(-1)
+			return nil
+		},
+	}
+
+	t.Run("merge sort", func(t *testing.T) {
+		gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
+			return true
+		})
+		assert.NoError(t, err)
+		assert.Equal(t, 6, gotNumRows)
+		err = rw.Close()
+		assert.NoError(t, err)
+	})
+
+	t.Run("merge sort with predicate", func(t *testing.T) {
+		gotNumRows, err := MergeSort(getReaders(), common.RowIDField, rw, func(r Record, ri, i int) bool {
+			pk := r.Column(common.RowIDField).(*array.Int64).Value(i)
+			return pk >= 20
+		})
+		assert.NoError(t, err)
+		assert.Equal(t, 3, gotNumRows)
+		err = rw.Close()
+		assert.NoError(t, err)
+	})
+}
+
+// Benchmark sort
+func BenchmarkSort(b *testing.B) {
+	batch := 500000
+	blobs, err := generateTestDataWithSeed(batch, batch)
+	assert.NoError(b, err)
+	reader10, err := NewCompositeBinlogRecordReader(blobs)
+	assert.NoError(b, err)
+	blobs, err = generateTestDataWithSeed(batch*2+1, batch)
+	assert.NoError(b, err)
+	reader20, err := NewCompositeBinlogRecordReader(blobs)
+	assert.NoError(b, err)
+	rr := []RecordReader{reader20, reader10}
+
+	rw := &MockRecordWriter{
+		writefn: func(r Record) error {
+			return nil
+		},
+
+		closefn: func() error {
+			return nil
+		},
+	}
+
+	b.ResetTimer()
+
+	b.Run("sort", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			Sort(rr, common.RowIDField, rw, func(r Record, ri, i int) bool {
+				return true
+			})
+		}
+	})
+}
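
The generic PriorityQueue added in sort.go is what MergeSort relies on to repeatedly pick the globally smallest primary key across record readers. A minimal standalone sketch of using it directly follows; the pkIndex type, the sample values, and the main-package wiring are illustrative only and not part of this change (the package lives under internal/, so the import only resolves inside the Milvus module).

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/internal/storage"
)

// pkIndex is a hypothetical item type: a primary key plus the reader it came from.
type pkIndex struct {
	pk int64
	ri int
}

func main() {
	// Order items by primary key; the smallest pk is dequeued first,
	// mirroring how MergeSort pops the next row to write.
	pq := storage.NewPriorityQueue[pkIndex](func(x, y *pkIndex) bool {
		return x.pk < y.pk
	})
	pq.Enqueue(&pkIndex{pk: 30, ri: 0})
	pq.Enqueue(&pkIndex{pk: 10, ri: 1})
	pq.Enqueue(&pkIndex{pk: 20, ri: 0})
	for pq.Len() > 0 {
		item := pq.Dequeue()
		fmt.Println(item.pk, item.ri) // pks print in ascending order: 10, 20, 30
	}
}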