diff --git a/common/bsonutil/bsonutil.go b/common/bsonutil/bsonutil.go index af1fb144d..30cd5f3d3 100644 --- a/common/bsonutil/bsonutil.go +++ b/common/bsonutil/bsonutil.go @@ -477,14 +477,58 @@ func MtoD(m bson.M) bson.D { // but would return an error if it cannot be reversed by bson.UnmarshalExtJSON. // // It is preferred to be used in mongodump to avoid generating un-reversible ext JSON. -func MarshalExtJSONReversible(val interface{}, canonical bool, escapeHTML bool) ([]byte, error) { +func MarshalExtJSONReversible( + val interface{}, + canonical bool, + escapeHTML bool, +) ([]byte, error) { jsonBytes, err := bson.MarshalExtJSON(val, canonical, escapeHTML) if err != nil { return nil, err } + reversedVal := reflect.New(reflect.TypeOf(val)).Elem().Interface() if unmarshalErr := bson.UnmarshalExtJSON(jsonBytes, canonical, &reversedVal); unmarshalErr != nil { return nil, errors2.Wrap(unmarshalErr, "marshal is not reversible") } + + return jsonBytes, nil +} + +// MarshalExtJSONWithBSONRoundtripConsistency is a wrapper around bson.MarshalExtJSON +// which also validates that BSON objects that are marshaled to ExtJSON objects +// return a consistent BSON object when unmarshaled. +func MarshalExtJSONWithBSONRoundtripConsistency( + val interface{}, + canonical bool, + escapeHTML bool, +) ([]byte, error) { + jsonBytes, err := MarshalExtJSONReversible(val, canonical, escapeHTML) + if err != nil { + return nil, err + } + + originalBSON, err := bson.Marshal(val) + if err != nil { + return nil, fmt.Errorf("could not marshal into BSON") + } + + reversedVal := reflect.New(reflect.TypeOf(val)).Elem().Interface() + err = bson.UnmarshalExtJSON(jsonBytes, canonical, &reversedVal) + if err != nil { + return nil, err + } + + reversedBSON, err := bson.Marshal(reversedVal) + if err != nil { + return nil, fmt.Errorf("could not marshal into BSON") + } + + if !bytes.Equal(originalBSON, reversedBSON) { + return nil, fmt.Errorf( + "marshaling BSON to ExtJSON and back resulted in discrepancies", + ) + } + return jsonBytes, nil } diff --git a/common/bsonutil/bsonutil_test.go b/common/bsonutil/bsonutil_test.go index 6c958e621..c64c454ea 100644 --- a/common/bsonutil/bsonutil_test.go +++ b/common/bsonutil/bsonutil_test.go @@ -1,6 +1,7 @@ package bsonutil import ( + "math" "testing" "time" @@ -104,32 +105,32 @@ func TestMarshalExtJSONReversible(t *testing.T) { tests := []struct { val any - canonical bool reversible bool expectedJSON string }{ { bson.M{"field1": bson.M{"$date": 1257894000000}}, true, - true, `{"field1":{"$date":{"$numberLong":"1257894000000"}}}`, }, { bson.M{"field1": time.Unix(1257894000, 0)}, true, - true, `{"field1":{"$date":{"$numberLong":"1257894000000"}}}`, }, { bson.M{"field1": bson.M{"$date": "invalid"}}, - true, false, ``, }, } for _, test := range tests { - json, err := MarshalExtJSONReversible(test.val, test.canonical, false) + json, err := MarshalExtJSONReversible( + test.val, + true, /* canonical */ + false, /* escapeHTML */ + ) if !test.reversible { assert.ErrorContains(t, err, "marshal is not reversible") } else { @@ -138,3 +139,47 @@ func TestMarshalExtJSONReversible(t *testing.T) { assert.Equal(t, test.expectedJSON, string(json)) } } + +func TestMarshalExtJSONWithBSONRoundtripConsistency(t *testing.T) { + testtype.SkipUnlessTestType(t, testtype.UnitTestType) + + tests := []struct { + val any + consistentAfterRoundtripping bool + expectedJSON string + }{ + { + bson.M{"field1": bson.M{"grapes": int64(123)}}, + true, + `{"field1":{"grapes":{"$numberLong":"123"}}}`, + }, + { + bson.M{"field1": bson.M{"$date": 1257894000000}}, + false, + ``, + }, + { + bson.M{"field1": bson.M{"nanField": math.NaN()}}, + true, + `{"field1":{"nanField":{"$numberDouble":"NaN"}}}`, + }, + } + + for _, test := range tests { + json, err := MarshalExtJSONWithBSONRoundtripConsistency( + test.val, + true, /* canonical */ + false, /* escapeHTML */ + ) + if !test.consistentAfterRoundtripping { + assert.ErrorContains( + t, + err, + "marshaling BSON to ExtJSON and back resulted in discrepancies", + ) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedJSON, string(json)) + } +} diff --git a/common/db/buffered_bulk.go b/common/db/buffered_bulk.go index 3c41dc594..ebc241fd9 100644 --- a/common/db/buffered_bulk.go +++ b/common/db/buffered_bulk.go @@ -9,6 +9,7 @@ package db import ( "context" "fmt" + "strings" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -26,6 +27,7 @@ const MAX_MESSAGE_SIZE_BYTES = 48000000 type BufferedBulkInserter struct { collection *mongo.Collection writeModels []mongo.WriteModel + docs []bson.D docLimit int docCount int byteCount int @@ -88,16 +90,18 @@ func (bb *BufferedBulkInserter) ResetBulk() { bb.writeModels = bb.writeModels[:0] bb.docCount = 0 bb.byteCount = 0 + bb.docs = bb.docs[:0] } // Insert adds a document to the buffer for bulk insertion. If the buffer becomes full, the bulk write is performed, returning // any error that occurs. -func (bb *BufferedBulkInserter) Insert(doc interface{}) (*mongo.BulkWriteResult, error) { +func (bb *BufferedBulkInserter) Insert(doc bson.D) (*mongo.BulkWriteResult, error) { rawBytes, err := bson.Marshal(doc) if err != nil { return nil, fmt.Errorf("bson encoding error: %v", err) } + bb.docs = append(bb.docs, doc) return bb.InsertRaw(rawBytes) } @@ -175,9 +179,131 @@ func (bb *BufferedBulkInserter) TryFlush() (*mongo.BulkWriteResult, error) { } func (bb *BufferedBulkInserter) flush() (*mongo.BulkWriteResult, error) { - if bb.docCount == 0 { - return nil, nil + + ctx := context.Background() + + if bb.docCount == 0 { + return nil, nil + } + res, bulkWriteErr := bb.collection.BulkWrite(ctx, bb.writeModels, bb.bulkWriteOpts) + if bulkWriteErr == nil { + return res, nil + } + + bulkWriteException, ok := bulkWriteErr.(mongo.BulkWriteException) + if !ok { + return res, bulkWriteErr } + + var retryDocFilters []bson.D + + for _, we := range bulkWriteException.WriteErrors { + if we.Code == ErrDuplicateKeyCode { + var errDetails map[string]bson.Raw + bson.Unmarshal(we.WriteError.Raw, &errDetails) + var filter bson.D + bson.Unmarshal(errDetails["keyValue"], &filter) + + exists, err := checkDocumentExistence(ctx, bb.collection, filter) + if err != nil { + return nil, err + } + if !exists { + retryDocFilters = append(retryDocFilters, filter) + } else { + } + } + } + + for _, filter := range retryDocFilters { + for _, doc := range bb.docs { + var exists bool + var err error + if compareDocumentWithKeys(filter, doc) { + for range(3) { + _, err = bb.collection.InsertOne(ctx, doc) + if err == nil { + break + } + } + exists, err = checkDocumentExistence(ctx, bb.collection, filter) + if err != nil { + return nil, err + } + if exists { + break + } + } + if !exists { + return nil, fmt.Errorf("could not insert document %+v", doc) + } + } + } + + res.InsertedCount += int64(len(retryDocFilters)) + return res, bulkWriteErr +} - return bb.collection.BulkWrite(context.Background(), bb.writeModels, bb.bulkWriteOpts) + +// extractValueByPath digs into a bson.D using a dotted path to retrieve the value +func extractValueByPath(doc bson.D, path string) (interface{}, bool) { + parts := strings.Split(path, ".") + var current interface{} = doc + for _, part := range parts { + switch curr := current.(type) { + case bson.D: + found := false + for _, elem := range curr { + if elem.Key == part { + current = elem.Value + found = true + break + } + } + if !found { + return nil, false + } + default: + return nil, false + } + } + return current, true } + +// compareDocumentWithKeys checks if the key-value pairs in doc1 exist in doc2 +func compareDocumentWithKeys(doc1 bson.D, doc2 bson.D) bool { + for _, elem := range doc1 { + value, exists := extractValueByPath(doc2, elem.Key) + if !exists || value != elem.Value { + return false + } + } + return true +} + +func checkDocumentExistence(ctx context.Context, collection *mongo.Collection, document bson.D) (bool, error) { + findCmd := bson.D{ + {Key: "find", Value: collection.Name()}, + {Key: "filter", Value: document}, + {Key: "readConcern", Value: bson.D{{Key: "level", Value: "majority"}}}, + } + + db := collection.Database() + + var result bson.M + err := db.RunCommand(ctx, findCmd).Decode(&result) + if err != nil { + return false, err + } + + if cursor, ok := result["cursor"].(bson.M); ok { + if firstBatch, ok := cursor["firstBatch"].(bson.A); ok && len(firstBatch) > 0 { + return true, nil + } else { + return false, nil + } + } else { + return false, err + } + +} \ No newline at end of file diff --git a/common/db/buffered_bulk_test.go b/common/db/buffered_bulk_test.go index e1dc5d1f3..b90fd96a2 100644 --- a/common/db/buffered_bulk_test.go +++ b/common/db/buffered_bulk_test.go @@ -100,7 +100,7 @@ func TestBufferedBulkInserterInserts(t *testing.T) { errCnt := 0 for i := 0; i < 1000000; i++ { - result, err := bufBulk.Insert(bson.M{"_id": i}) + result, err := bufBulk.Insert(bson.D{{"_id", i}}) if err != nil { errCnt++ } diff --git a/common/db/db.go b/common/db/db.go index fbbfeca9c..34e6c352f 100644 --- a/common/db/db.go +++ b/common/db/db.go @@ -601,6 +601,15 @@ func CanIgnoreError(err error) bool { return ok case mongo.BulkWriteException: for _, writeErr := range mongoErr.WriteErrors { + + var decoded bson.M + err := bson.Unmarshal(writeErr.Raw, &decoded) + if err != nil { + return false + } + keyValue, _ := decoded["keyValue"].(bson.D) + fmt.Printf("TESTING THIS %+v\n", keyValue) + if _, ok := ignorableWriteErrorCodes[writeErr.Code]; !ok { return false } diff --git a/mongodump/metadata_dump.go b/mongodump/metadata_dump.go index 6bf0aa4d7..fdc0f7ab3 100644 --- a/mongodump/metadata_dump.go +++ b/mongodump/metadata_dump.go @@ -109,7 +109,7 @@ func (dump *MongoDump) dumpMetadata( } // Finally, we send the results to the writer as JSON bytes - jsonBytes, err := bsonutil.MarshalExtJSONReversible(meta, true, false) + jsonBytes, err := bsonutil.MarshalExtJSONWithBSONRoundtripConsistency(meta, true, false) if err != nil { return fmt.Errorf( "error marshaling metadata json for collection `%v`: %v", diff --git a/mongoimport/mongoimport.go b/mongoimport/mongoimport.go index b13a2a9d2..ddec6bed0 100644 --- a/mongoimport/mongoimport.go +++ b/mongoimport/mongoimport.go @@ -24,6 +24,8 @@ import ( "github.com/mongodb/mongo-tools/common/util" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" + mongoOpts "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readconcern" "gopkg.in/tomb.v2" ) @@ -454,7 +456,7 @@ func (imp *MongoImport) runInsertionWorker(readDocs chan bson.D) (err error) { if err != nil { return fmt.Errorf("error connecting to mongod: %v", err) } - collection := session.Database(imp.ToolOptions.DB).Collection(imp.ToolOptions.Collection) + collection := session.Database(imp.ToolOptions.DB).Collection(imp.ToolOptions.Collection, &mongoOpts.CollectionOptions{ReadConcern: readconcern.Majority()}) inserter := db.NewUnorderedBufferedBulkInserter(collection, imp.IngestOptions.BulkBufferSize). SetBypassDocumentValidation(imp.IngestOptions.BypassDocumentValidation). @@ -468,7 +470,7 @@ readLoop: if !alive { break readLoop } - err := imp.importDocument(inserter, document) + err := imp.importDocument(inserter, document) if db.FilterError(imp.IngestOptions.StopOnError, err) != nil { return err } @@ -481,6 +483,15 @@ readLoop: return db.FilterError(imp.IngestOptions.StopOnError, err) } +func checkMajority(ctx context.Context, collection *mongo.Collection, document bson.D) { + // readOpts := mongoOpts.FindOne(). + var result bson.M + err := collection.FindOne(context.Background(), document).Decode(&result) + if err == nil { + fmt.Printf("found one %s", result) + } +} + func (imp *MongoImport) updateCounts(result *mongo.BulkWriteResult, err error) { if result != nil { atomic.AddUint64( diff --git a/mongoimport/mongoimport_test.go b/mongoimport/mongoimport_test.go index 25b2266bb..d5194c159 100644 --- a/mongoimport/mongoimport_test.go +++ b/mongoimport/mongoimport_test.go @@ -932,7 +932,7 @@ func TestImportDocuments(t *testing.T) { } So(checkOnlyHasDocuments(imp.SessionProvider, expectedDocuments), ShouldBeNil) }) - Convey("an error should be thrown for CSV import on test data with "+ + Convey("an error should be thrown for CSV import on test data wisth "+ "duplicate _id if --stopOnError is set", func() { imp, err := NewMongoImport() So(err, ShouldBeNil) @@ -1311,6 +1311,7 @@ func TestImportDocuments(t *testing.T) { }) } + func nestedFieldsTestHelper(data string, expectedDocuments []bson.M, expectedErr error) func() { return func() { err := os.WriteFile(util.ToUniversalPath("./temp_test_data.csv"), []byte(data), 0644)