Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions connectors/cosmos/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ func convertChangeStreamEventToUpdate(change bson.M) (*adiomv1.Update, error) {
Id: []*adiomv1.BsonValue{{
Data: idVal,
Type: uint32(idType),
Name: "_id",
}},
Type: adiomv1.UpdateType_UPDATE_TYPE_DELETE,
}, nil
Expand All @@ -299,6 +300,7 @@ func convertChangeStreamEventToUpdate(change bson.M) (*adiomv1.Update, error) {
Id: []*adiomv1.BsonValue{{
Data: idVal,
Type: uint32(idType),
Name: "_id",
}},
Type: adiomv1.UpdateType_UPDATE_TYPE_UPDATE,
Data: fullDocumentRaw,
Expand Down Expand Up @@ -333,6 +335,7 @@ func checkForDeletes(ctx context.Context, client *mongo.Client, witnessClient *m
Id: []*adiomv1.BsonValue{{
Data: idVal,
Type: uint32(idType),
Name: "_id",
}},
Type: adiomv1.UpdateType_UPDATE_TYPE_DELETE,
})
Expand Down
153 changes: 63 additions & 90 deletions connectors/mongo/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ import (
"time"

"connectrpc.com/connect"
"github.com/adiom-data/dsync/connectors/util"
adiomv1 "github.com/adiom-data/dsync/gen/adiom/v1"
"github.com/adiom-data/dsync/gen/adiom/v1/adiomv1connect"
"github.com/adiom-data/dsync/protocol/iface"
"github.com/cespare/xxhash"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
Expand All @@ -41,6 +41,7 @@ type ConnectorSettings struct {
MaxPageSize int
PerNamespaceStreams bool
SkipBatchOverwrite bool
FullDocumentKey bool

Query string // query filter, as a v2 Extended JSON string, e.g., '{\"x\":{\"$gt\":1}}'"
}
Expand Down Expand Up @@ -611,75 +612,61 @@ func createChangeStreamNamespaceFilterFromNamespaces(namespaces []iface.Namespac
return bson.D{{"$or", filters}}
}

func convertChangeStreamEventToUpdate(change bson.M) (*adiomv1.Update, error) {
func convertChangeStreamEventToUpdate(change MongoUpdate, fullDocumentKey bool) (*adiomv1.Update, error) {
// slog.Debug(fmt.Sprintf("Converting change stream event %v", change))

optype := change["operationType"].(string)
optype := change.OperationType
var update *adiomv1.Update

switch optype {
case "insert":
// get the id of the document that was inserted
id := change["documentKey"].(bson.M)["_id"]
// convert id to raw bson
idType, idVal, err := bson.MarshalValue(id)
var id []*adiomv1.BsonValue
for _, k := range change.DocumentKey {
if !fullDocumentKey && k.Key != "_id" {
continue
}
idType, idVal, err := bson.MarshalValue(k.Value)
if err != nil {
return nil, fmt.Errorf("failed to marshal _id: %v", err)
return nil, fmt.Errorf("failed to marshal %v: %v", k.Key, err)
}
fullDocument := change["fullDocument"].(bson.M)
id = append(id, &adiomv1.BsonValue{
Data: idVal,
Type: uint32(idType),
Name: k.Key,
})
}

switch optype {
case "insert":
fullDocument := change.FullDocument
// convert fulldocument to BSON.Raw
fullDocumentRaw, err := bson.Marshal(fullDocument)
if err != nil {
return nil, fmt.Errorf("failed to marshal full document: %v", err)
}
update = &adiomv1.Update{
Id: []*adiomv1.BsonValue{{
Data: idVal,
Type: uint32(idType),
}},
Id: id,
Type: adiomv1.UpdateType_UPDATE_TYPE_INSERT,
Data: fullDocumentRaw,
}
case "update", "replace":
// get the id of the document that was changed
id := change["documentKey"].(bson.M)["_id"]
// convert id to raw bson
idType, idVal, err := bson.MarshalValue(id)
if err != nil {
return nil, fmt.Errorf("failed to marshal _id: %v", err)
}
// get the full state of the document after the change
if change["fullDocument"] == nil {
if change.FullDocument == nil {
//TODO (AK, 6/2024): find a better way to report that we need to ignore this event
return nil, nil // no full document, nothing to do (probably got deleted before we got to the event in the change stream)
}
fullDocument := change["fullDocument"].(bson.M)
fullDocument := change.FullDocument
// convert fulldocument to BSON.Raw
fullDocumentRaw, err := bson.Marshal(fullDocument)
if err != nil {
return nil, fmt.Errorf("failed to marshal full document: %v", err)
}
update = &adiomv1.Update{
Id: []*adiomv1.BsonValue{{
Data: idVal,
Type: uint32(idType),
}},
Id: id,
Type: adiomv1.UpdateType_UPDATE_TYPE_UPDATE,
Data: fullDocumentRaw,
}
case "delete":
// get the id of the document that was deleted
id := change["documentKey"].(bson.M)["_id"]
// convert id to raw bson
idType, idVal, err := bson.MarshalValue(id)
if err != nil {
return nil, fmt.Errorf("failed to marshal _id: %v", err)
}
update = &adiomv1.Update{
Id: []*adiomv1.BsonValue{{
Data: idVal,
Type: uint32(idType),
}},
Id: id,
Type: adiomv1.UpdateType_UPDATE_TYPE_DELETE,
}
default:
Expand All @@ -696,6 +683,14 @@ func toTimestampPB(t primitive.Timestamp) *timestamppb.Timestamp {
return timestamppb.New(time.Unix(int64(t.T), 0))
}

type MongoUpdate struct {
NS bson.M `bson:"ns"`
ClusterTime *primitive.Timestamp `bson:"clusterTime"`
DocumentKey bson.D `bson:"documentKey"`
FullDocument bson.M `bson:"fullDocument"`
OperationType string `bson:"operationType"`
}

// StreamUpdates implements adiomv1connect.ConnectorServiceHandler.
func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.StreamUpdatesRequest], s *connect.ServerStream[adiomv1.StreamUpdatesResponse]) error {
var watcher Watchable
Expand Down Expand Up @@ -739,13 +734,13 @@ func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.Str
var lastTime primitive.Timestamp

for changeStream.Next(ctx) {
var change bson.M
var change MongoUpdate
if err := changeStream.Decode(&change); err != nil {
slog.Error(fmt.Sprintf("Failed to decode change stream event: %v", err))
continue
}

update, err := convertChangeStreamEventToUpdate(change)
update, err := convertChangeStreamEventToUpdate(change, c.settings.FullDocumentKey)
if err != nil {
slog.Error(fmt.Sprintf("Failed to convert change stream event to data message: %v", err))
continue
Expand All @@ -754,8 +749,8 @@ func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.Str
continue
}

db := change["ns"].(bson.M)["db"].(string)
col := change["ns"].(bson.M)["coll"].(string)
db := change.NS["db"].(string)
col := change.NS["coll"].(string)
Comment on lines +752 to +753
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Unsafe type assertions on change.NS will panic on unexpected data.

change.NS["db"].(string) and change.NS["coll"].(string) are bare type assertions. If the key is missing or the value isn't a string, this panics and kills the stream goroutine. While MongoDB guarantees these for standard CRUD events, defensive assertions are cheap insurance.

🛡️ Suggested defensive approach
-		db := change.NS["db"].(string)
-		col := change.NS["coll"].(string)
+		db, _ := change.NS["db"].(string)
+		col, _ := change.NS["coll"].(string)
+		if db == "" || col == "" {
+			slog.Error(fmt.Sprintf("Skipping change event with invalid namespace: %v", change.NS))
+			continue
+		}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
db := change.NS["db"].(string)
col := change.NS["coll"].(string)
db, _ := change.NS["db"].(string)
col, _ := change.NS["coll"].(string)
if db == "" || col == "" {
slog.Error(fmt.Sprintf("Skipping change event with invalid namespace: %v", change.NS))
continue
}
🤖 Prompt for AI Agents
In `@connectors/mongo/conn.go` around lines 752 - 753, The code uses unsafe type
assertions change.NS["db"].(string) and change.NS["coll"].(string) which can
panic; update the handler around change.NS to perform safe type checks (value,
ok := change.NS["db"]; db, ok2 := value.(string)) and similarly for "coll",
logging an error and skipping the change or returning early if either key is
missing or not a string. Locate this in the change stream processing where
variables db and col are set and replace the bare assertions with the safe-ok
pattern, ensuring you handle the fallback path (log the unexpected shape and
continue) instead of letting the goroutine crash.

newNamespace := fmt.Sprintf("%v.%v", db, col)

if currentNamespace == "" || currentNamespace != newNamespace {
Expand All @@ -779,8 +774,8 @@ func (c *conn) StreamUpdates(ctx context.Context, r *connect.Request[adiomv1.Str

updates = append(updates, update)
lastResumeToken = changeStream.ResumeToken()
if lt, ok := change["clusterTime"].(primitive.Timestamp); ok {
lastTime = lt
if change.ClusterTime != nil {
lastTime = *change.ClusterTime
}

if changeStream.RemainingBatchLength() == 0 {
Expand Down Expand Up @@ -914,63 +909,41 @@ func (c *conn) WriteData(ctx context.Context, r *connect.Request[adiomv1.WriteDa
return connect.NewResponse(&adiomv1.WriteDataResponse{}), nil
}

type dataIdIndex struct {
dataId []byte
index int
}

// returns the new item or existing item, and whether or not a new item was added
func addToIdIndexMap2(m map[int][]*dataIdIndex, update *adiomv1.Update) (*dataIdIndex, bool) {
hasher := xxhash.New()
_, _ = hasher.Write(update.GetId()[0].GetData())
h := int(hasher.Sum64())
items, found := m[h]
if found {
for _, item := range items {
if slices.Equal(item.dataId, update.GetId()[0].GetData()) {
return item, false
}
}
}
item := &dataIdIndex{update.GetId()[0].GetData(), -1}
m[h] = append(items, item)
return item, true
}

// WriteUpdates implements adiomv1connect.ConnectorServiceHandler.
func (c *conn) WriteUpdates(ctx context.Context, r *connect.Request[adiomv1.WriteUpdatesRequest]) (*connect.Response[adiomv1.WriteUpdatesResponse], error) {
col, _, ok := GetCol(c.client, r.Msg.GetNamespace())
if !ok {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("namespace should be fully qualified"))
}
updates := util.KeepLastUpdate(r.Msg.GetUpdates())
var models []mongo.WriteModel
// keeps track of the index in models for a particular document because we want all ids to be unique in the batch
hashToDataIdIndex := map[int][]*dataIdIndex{}

for _, update := range r.Msg.GetUpdates() {
idType := bsontype.Type(update.GetId()[0].GetType())

for _, update := range updates {
var idFilter bson.D
if len(update.GetId()) == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("err update with unexpected empty id"))
}
for _, idPart := range update.GetId() {
key := idPart.GetName()
// For backwards compatibility- we used to pass no key
if key == "" {
key = "_id"
}
if !c.settings.FullDocumentKey && key != "_id" {
continue
}
typ := bsontype.Type(idPart.GetType())
idFilter = append(idFilter, bson.E{Key: key, Value: bson.RawValue{Type: typ, Value: idPart.GetData()}})
}
if len(idFilter) == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("err with _id not found- enable full-document-key or ensure an id part has the name _id: %v", update.GetId()))
}
switch update.GetType() {
case adiomv1.UpdateType_UPDATE_TYPE_INSERT, adiomv1.UpdateType_UPDATE_TYPE_UPDATE:
dii, isNew := addToIdIndexMap2(hashToDataIdIndex, update)
idFilter := bson.D{{Key: "_id", Value: bson.RawValue{Type: idType, Value: update.GetId()[0].GetData()}}}
model := mongo.NewReplaceOneModel().SetFilter(idFilter).SetReplacement(bson.Raw(update.GetData())).SetUpsert(true)
if isNew {
dii.index = len(models)
models = append(models, model)
} else {
models[dii.index] = model
}
models = append(models, model)
case adiomv1.UpdateType_UPDATE_TYPE_DELETE:
dii, isNew := addToIdIndexMap2(hashToDataIdIndex, update)
idFilter := bson.D{{Key: "_id", Value: bson.RawValue{Type: idType, Value: update.GetId()[0].GetData()}}}
model := mongo.NewDeleteOneModel().SetFilter(idFilter)
if isNew {
dii.index = len(models)
models = append(models, model)
} else {
models[dii.index] = model
}
models = append(models, model)
}
}

Expand Down
5 changes: 5 additions & 0 deletions internal/app/options/connectorflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,11 @@ func MongoFlags(settings *mongo.ConnectorSettings) []cli.Flag {
Name: "skip-batch-overwrite",
Destination: &settings.SkipBatchOverwrite,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "full-document-key",
Usage: "uses the full document key instead of just _id (except for batch overwrites)",
Destination: &settings.FullDocumentKey,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "server-timeout",
Required: false,
Expand Down