diff --git a/common/db/buffered_bulk.go b/common/db/buffered_bulk.go index 3c41dc594..a3db2b186 100644 --- a/common/db/buffered_bulk.go +++ b/common/db/buffered_bulk.go @@ -9,6 +9,7 @@ package db import ( "context" "fmt" + "strings" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -26,6 +27,7 @@ const MAX_MESSAGE_SIZE_BYTES = 48000000 type BufferedBulkInserter struct { collection *mongo.Collection writeModels []mongo.WriteModel + docs []bson.D docLimit int docCount int byteCount int @@ -88,6 +90,7 @@ func (bb *BufferedBulkInserter) ResetBulk() { bb.writeModels = bb.writeModels[:0] bb.docCount = 0 bb.byteCount = 0 + bb.docs = bb.docs[:0] } // Insert adds a document to the buffer for bulk insertion. If the buffer becomes full, the bulk write is performed, returning @@ -98,6 +101,7 @@ func (bb *BufferedBulkInserter) Insert(doc interface{}) (*mongo.BulkWriteResult, return nil, fmt.Errorf("bson encoding error: %v", err) } + bb.docs = append(bb.docs, doc.(bson.D)) return bb.InsertRaw(rawBytes) } @@ -175,9 +179,131 @@ func (bb *BufferedBulkInserter) TryFlush() (*mongo.BulkWriteResult, error) { } func (bb *BufferedBulkInserter) flush() (*mongo.BulkWriteResult, error) { - if bb.docCount == 0 { - return nil, nil + + ctx := context.Background() + + if bb.docCount == 0 { + return nil, nil + } + res, bulkWriteErr := bb.collection.BulkWrite(ctx, bb.writeModels, bb.bulkWriteOpts) + if bulkWriteErr == nil { + return res, nil + } + + bulkWriteException, ok := bulkWriteErr.(mongo.BulkWriteException) + if !ok { + return res, bulkWriteErr + } + + var retryDocFilters []bson.D + + for _, we := range bulkWriteException.WriteErrors { + if we.Code == ErrDuplicateKeyCode { + var errDetails map[string]bson.Raw + bson.Unmarshal(we.WriteError.Raw, &errDetails) + var filter bson.D + bson.Unmarshal(errDetails["keyValue"], &filter) + + exists, err := checkDocumentExistence(ctx, bb.collection, filter) + if err != nil { + return nil, err + } + if !exists { + retryDocFilters = append(retryDocFilters, filter) + } else { + } + } + } + + for _, filter := range retryDocFilters { + for _, doc := range bb.docs { + var exists bool + var err error + if compareDocumentWithKeys(filter, doc) { + for range(3) { + _, err = bb.collection.InsertOne(ctx, doc) + if err == nil { + break + } + } + exists, err = checkDocumentExistence(ctx, bb.collection, filter) + if err != nil { + return nil, err + } + if exists { + break + } + } + if !exists { + return nil, fmt.Errorf("could not insert document %+v", doc) + } + } + } + + res.InsertedCount += int64(len(retryDocFilters)) + return res, bulkWriteErr +} + + +// extractValueByPath digs into a bson.D using a dotted path to retrieve the value +func extractValueByPath(doc bson.D, path string) (interface{}, bool) { + parts := strings.Split(path, ".") + var current interface{} = doc + for _, part := range parts { + switch curr := current.(type) { + case bson.D: + found := false + for _, elem := range curr { + if elem.Key == part { + current = elem.Value + found = true + break + } + } + if !found { + return nil, false + } + default: + return nil, false + } + } + return current, true +} + +// compareDocumentWithKeys checks if the key-value pairs in doc1 exist in doc2 +func compareDocumentWithKeys(doc1 bson.D, doc2 bson.D) bool { + for _, elem := range doc1 { + value, exists := extractValueByPath(doc2, elem.Key) + if !exists || value != elem.Value { + return false + } + } + return true +} + +func checkDocumentExistence(ctx context.Context, collection *mongo.Collection, document bson.D) (bool, error) { + findCmd := bson.D{ + {Key: "find", Value: collection.Name()}, + {Key: "filter", Value: document}, + {Key: "readConcern", Value: bson.D{{Key: "level", Value: "majority"}}}, + } + + db := collection.Database() + + var result bson.M + err := db.RunCommand(ctx, findCmd).Decode(&result) + if err != nil { + return false, err + } + + if cursor, ok := result["cursor"].(bson.M); ok { + if firstBatch, ok := cursor["firstBatch"].(bson.A); ok && len(firstBatch) > 0 { + return true, nil + } else { + return false, nil + } + } else { + return false, err } - return bb.collection.BulkWrite(context.Background(), bb.writeModels, bb.bulkWriteOpts) }