diff --git a/connectors/cosmos/conn.go b/connectors/cosmos/conn.go index e6ce7956..86fd4db3 100644 --- a/connectors/cosmos/conn.go +++ b/connectors/cosmos/conn.go @@ -279,6 +279,7 @@ func convertChangeStreamEventToUpdate(change bson.M) (*adiomv1.Update, error) { Id: []*adiomv1.BsonValue{{ Data: idVal, Type: uint32(idType), + Name: "_id", }}, Type: adiomv1.UpdateType_UPDATE_TYPE_DELETE, }, nil @@ -299,6 +300,7 @@ func convertChangeStreamEventToUpdate(change bson.M) (*adiomv1.Update, error) { Id: []*adiomv1.BsonValue{{ Data: idVal, Type: uint32(idType), + Name: "_id", }}, Type: adiomv1.UpdateType_UPDATE_TYPE_UPDATE, Data: fullDocumentRaw, @@ -333,6 +335,7 @@ func checkForDeletes(ctx context.Context, client *mongo.Client, witnessClient *m Id: []*adiomv1.BsonValue{{ Data: idVal, Type: uint32(idType), + Name: "_id", }}, Type: adiomv1.UpdateType_UPDATE_TYPE_DELETE, }) diff --git a/connectors/mongo/conn.go b/connectors/mongo/conn.go index bcc52513..99435e11 100644 --- a/connectors/mongo/conn.go +++ b/connectors/mongo/conn.go @@ -16,10 +16,10 @@ import ( "time" "connectrpc.com/connect" + "github.com/adiom-data/dsync/connectors/util" adiomv1 "github.com/adiom-data/dsync/gen/adiom/v1" "github.com/adiom-data/dsync/gen/adiom/v1/adiomv1connect" "github.com/adiom-data/dsync/protocol/iface" - "github.com/cespare/xxhash" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" @@ -41,6 +41,7 @@ type ConnectorSettings struct { MaxPageSize int PerNamespaceStreams bool SkipBatchOverwrite bool + FullDocumentKey bool Query string // query filter, as a v2 Extended JSON string, e.g., '{\"x\":{\"$gt\":1}}'" } @@ -611,75 +612,61 @@ func createChangeStreamNamespaceFilterFromNamespaces(namespaces []iface.Namespac return bson.D{{"$or", filters}} } -func convertChangeStreamEventToUpdate(change bson.M) (*adiomv1.Update, error) { +func convertChangeStreamEventToUpdate(change MongoUpdate, fullDocumentKey bool) (*adiomv1.Update, error) { // slog.Debug(fmt.Sprintf("Converting change stream event %v", change)) - optype := change["operationType"].(string) + optype := change.OperationType var update *adiomv1.Update - switch optype { - case "insert": - // get the id of the document that was inserted - id := change["documentKey"].(bson.M)["_id"] - // convert id to raw bson - idType, idVal, err := bson.MarshalValue(id) + var id []*adiomv1.BsonValue + for _, k := range change.DocumentKey { + if !fullDocumentKey && k.Key != "_id" { + continue + } + idType, idVal, err := bson.MarshalValue(k.Value) if err != nil { - return nil, fmt.Errorf("failed to marshal _id: %v", err) + return nil, fmt.Errorf("failed to marshal %v: %v", k.Key, err) } - fullDocument := change["fullDocument"].(bson.M) + id = append(id, &adiomv1.BsonValue{ + Data: idVal, + Type: uint32(idType), + Name: k.Key, + }) + } + + switch optype { + case "insert": + fullDocument := change.FullDocument // convert fulldocument to BSON.Raw fullDocumentRaw, err := bson.Marshal(fullDocument) if err != nil { return nil, fmt.Errorf("failed to marshal full document: %v", err) } update = &adiomv1.Update{ - Id: []*adiomv1.BsonValue{{ - Data: idVal, - Type: uint32(idType), - }}, + Id: id, Type: adiomv1.UpdateType_UPDATE_TYPE_INSERT, Data: fullDocumentRaw, } case "update", "replace": - // get the id of the document that was changed - id := change["documentKey"].(bson.M)["_id"] - // convert id to raw bson - idType, idVal, err := bson.MarshalValue(id) - if err != nil { - return nil, fmt.Errorf("failed to marshal _id: %v", err) - } // get the full state of the document after the change - if change["fullDocument"] == nil { + if change.FullDocument == nil { //TODO (AK, 6/2024): find a better way to report that we need to ignore this event return nil, nil // no full document, nothing to do (probably got deleted before we got to the event in the change stream) } - fullDocument := change["fullDocument"].(bson.M) + fullDocument := change.FullDocument // convert fulldocument to BSON.Raw fullDocumentRaw, err := bson.Marshal(fullDocument) if err != nil { return nil, fmt.Errorf("failed to marshal full document: %v", err) } update = &adiomv1.Update{ - Id: []*adiomv1.BsonValue{{ - Data: idVal, - Type: uint32(idType), - }}, + Id: id, Type: adiomv1.UpdateType_UPDATE_TYPE_UPDATE, Data: fullDocumentRaw, } case "delete": - // get the id of the document that was deleted - id := change["documentKey"].(bson.M)["_id"] - // convert id to raw bson - idType, idVal, err := bson.MarshalValue(id) - if err != nil { - return nil, fmt.Errorf("failed to marshal _id: %v", err) - } update = &adiomv1.Update{ - Id: []*adiomv1.BsonValue{{ - Data: idVal, - Type: uint32(idType), - }}, + Id: id, Type: adiomv1.UpdateType_UPDATE_TYPE_DELETE, } default: @@ -696,6 +683,14 @@ func toTimestampPB(t primitive.Timestamp) *timestamppb.Timestamp { return timestamppb.New(time.Unix(int64(t.T), 0)) } +type MongoUpdate struct { + NS bson.M `bson:"ns"` + ClusterTime *primitive.Timestamp `bson:"clusterTime"` + DocumentKey bson.D `bson:"documentKey"` + FullDocument bson.M `bson:"fullDocument"` + OperationType string `bson:"operationType"` +} + // StreamUpdates implements adiomv1connect.ConnectorServiceHandler. func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.StreamUpdatesRequest], s *connect.ServerStream[adiomv1.StreamUpdatesResponse]) error { var watcher Watchable @@ -739,13 +734,13 @@ func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.Str var lastTime primitive.Timestamp for changeStream.Next(ctx) { - var change bson.M + var change MongoUpdate if err := changeStream.Decode(&change); err != nil { slog.Error(fmt.Sprintf("Failed to decode change stream event: %v", err)) continue } - update, err := convertChangeStreamEventToUpdate(change) + update, err := convertChangeStreamEventToUpdate(change, c.settings.FullDocumentKey) if err != nil { slog.Error(fmt.Sprintf("Failed to convert change stream event to data message: %v", err)) continue @@ -754,8 +749,8 @@ func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.Str continue } - db := change["ns"].(bson.M)["db"].(string) - col := change["ns"].(bson.M)["coll"].(string) + db := change.NS["db"].(string) + col := change.NS["coll"].(string) newNamespace := fmt.Sprintf("%v.%v", db, col) if currentNamespace == "" || currentNamespace != newNamespace { @@ -779,8 +774,8 @@ func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.Str updates = append(updates, update) lastResumeToken = changeStream.ResumeToken() - if lt, ok := change["clusterTime"].(primitive.Timestamp); ok { - lastTime = lt + if change.ClusterTime != nil { + lastTime = *change.ClusterTime } if changeStream.RemainingBatchLength() == 0 { @@ -914,63 +909,41 @@ func (c *conn) WriteData(ctx context.Context, r *connect.Request[adiomv1.WriteDa return connect.NewResponse(&adiomv1.WriteDataResponse{}), nil } -type dataIdIndex struct { - dataId []byte - index int -} - -// returns the new item or existing item, and whether or not a new item was added -func addToIdIndexMap2(m map[int][]*dataIdIndex, update *adiomv1.Update) (*dataIdIndex, bool) { - hasher := xxhash.New() - _, _ = hasher.Write(update.GetId()[0].GetData()) - h := int(hasher.Sum64()) - items, found := m[h] - if found { - for _, item := range items { - if slices.Equal(item.dataId, update.GetId()[0].GetData()) { - return item, false - } - } - } - item := &dataIdIndex{update.GetId()[0].GetData(), -1} - m[h] = append(items, item) - return item, true -} - // WriteUpdates implements adiomv1connect.ConnectorServiceHandler. func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.WriteUpdatesRequest]) (*connect.Response[adiomv1.WriteUpdatesResponse], error) { col, _, ok := GetCol(c.client, r.Msg.GetNamespace()) if !ok { return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("namespace should be fully qualified")) } + updates := util.KeepLastUpdate(r.Msg.GetUpdates()) var models []mongo.WriteModel - // keeps track of the index in models for a particular document because we want all ids to be unique in the batch - hashToDataIdIndex := map[int][]*dataIdIndex{} - - for _, update := range r.Msg.GetUpdates() { - idType := bsontype.Type(update.GetId()[0].GetType()) - + for _, update := range updates { + var idFilter bson.D + if len(update.GetId()) == 0 { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("err update with unexpected empty id")) + } + for _, idPart := range update.GetId() { + key := idPart.GetName() + // For backwards compatibility- we used to pass no key + if key == "" { + key = "_id" + } + if !c.settings.FullDocumentKey && key != "_id" { + continue + } + typ := bsontype.Type(idPart.GetType()) + idFilter = append(idFilter, bson.E{Key: key, Value: bson.RawValue{Type: typ, Value: idPart.GetData()}}) + } + if len(idFilter) == 0 { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("err with _id not found- enable full-document-key or ensure an id part has the name _id: %v", update.GetId())) + } switch update.GetType() { case adiomv1.UpdateType_UPDATE_TYPE_INSERT, adiomv1.UpdateType_UPDATE_TYPE_UPDATE: - dii, isNew := addToIdIndexMap2(hashToDataIdIndex, update) - idFilter := bson.D{{Key: "_id", Value: bson.RawValue{Type: idType, Value: update.GetId()[0].GetData()}}} model := mongo.NewReplaceOneModel().SetFilter(idFilter).SetReplacement(bson.Raw(update.GetData())).SetUpsert(true) - if isNew { - dii.index = len(models) - models = append(models, model) - } else { - models[dii.index] = model - } + models = append(models, model) case adiomv1.UpdateType_UPDATE_TYPE_DELETE: - dii, isNew := addToIdIndexMap2(hashToDataIdIndex, update) - idFilter := bson.D{{Key: "_id", Value: bson.RawValue{Type: idType, Value: update.GetId()[0].GetData()}}} model := mongo.NewDeleteOneModel().SetFilter(idFilter) - if isNew { - dii.index = len(models) - models = append(models, model) - } else { - models[dii.index] = model - } + models = append(models, model) } } diff --git a/internal/app/options/connectorflags.go b/internal/app/options/connectorflags.go index e75ac6d7..a3c719fa 100644 --- a/internal/app/options/connectorflags.go +++ b/internal/app/options/connectorflags.go @@ -760,6 +760,11 @@ func MongoFlags(settings *mongo.ConnectorSettings) []cli.Flag { Name: "skip-batch-overwrite", Destination: &settings.SkipBatchOverwrite, }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "full-document-key", + Usage: "uses the full document key instead of just _id (except for batch overwrites)", + Destination: &settings.FullDocumentKey, + }), altsrc.NewDurationFlag(&cli.DurationFlag{ Name: "server-timeout", Required: false,