diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go index 159297ef0a..8702d6d39e 100644 --- a/bson/bsoncodec/default_value_decoders.go +++ b/bson/bsoncodec/default_value_decoders.go @@ -1521,7 +1521,13 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - if vr.Type() == bsontype.Null { + // If BSON value is null and the go value is a pointer, then don't call + // UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e., + // non-nil), encountering null in BSON will result in the pointer being + // directly set to nil here. Since the pointer is being replaced with nil, + // there is no opportunity (or reason) for the custom UnmarshalBSONValue logic + // to be called. + if vr.Type() == bsontype.Null && val.Kind() == reflect.Ptr { val.Set(reflect.Zero(val.Type())) return vr.ReadNull() diff --git a/bson/unmarshal.go b/bson/unmarshal.go index 66da17ee01..d749ba373b 100644 --- a/bson/unmarshal.go +++ b/bson/unmarshal.go @@ -41,6 +41,9 @@ type ValueUnmarshaler interface { // Unmarshal parses the BSON-encoded data and stores the result in the value // pointed to by val. If val is nil or not a pointer, Unmarshal returns // InvalidUnmarshalError. +// +// When unmarshaling BSON, if the BSON value is null and the Go value is a +// pointer, the pointer is set to nil without calling UnmarshalBSONValue. func Unmarshal(data []byte, val interface{}) error { return UnmarshalWithRegistry(DefaultRegistry, data, val) } diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index ef91da1659..3455deeaaa 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -93,6 +94,29 @@ func TestUnmarshalValue(t *testing.T) { }) } +func TestInitializedPointerDataWithBSONNull(t *testing.T) { + // Set up the test case with initialized pointers. + tc := unmarshalBehaviorTestCase{ + BSONValuePtrTracker: &unmarshalBSONValueCallTracker{}, + BSONPtrTracker: &unmarshalBSONCallTracker{}, + } + + // Create BSON data where the '*_ptr_tracker' fields are explicitly set to + // null. + bytes := docToBytes(D{ + {Key: "bv_ptr_tracker", Value: nil}, + {Key: "b_ptr_tracker", Value: nil}, + }) + + // Unmarshal the BSON data into the test case struct. This should set the + // pointer fields to nil due to the BSON null value. + err := Unmarshal(bytes, &tc) + require.NoError(t, err) + + assert.Nil(t, tc.BSONValuePtrTracker) + assert.Nil(t, tc.BSONPtrTracker) +} + // tests covering GODRIVER-2779 func BenchmarkSliceCodecUnmarshal(b *testing.B) { benchmarks := []struct { diff --git a/bson/unmarshaling_cases_test.go b/bson/unmarshaling_cases_test.go index dd38369bff..358088fe84 100644 --- a/bson/unmarshaling_cases_test.go +++ b/bson/unmarshaling_cases_test.go @@ -11,6 +11,7 @@ import ( "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" ) type unmarshalingTestCase struct { @@ -114,6 +115,26 @@ func unmarshalingTestCases() []unmarshalingTestCase { }, data: docToBytes(D{{"fooBar", int32(10)}}), }, + { + name: "nil pointer and non-pointer type with literal null BSON", + sType: reflect.TypeOf(unmarshalBehaviorTestCase{}), + want: &unmarshalBehaviorTestCase{ + BSONValueTracker: unmarshalBSONValueCallTracker{ + called: true, + }, + BSONValuePtrTracker: nil, + BSONTracker: unmarshalBSONCallTracker{ + called: true, + }, + BSONPtrTracker: nil, + }, + data: docToBytes(D{ + {Key: "bv_tracker", Value: nil}, + {Key: "bv_ptr_tracker", Value: nil}, + {Key: "b_tracker", Value: nil}, + {Key: "b_ptr_tracker", Value: nil}, + }), + }, // GODRIVER-2252 // Test that a struct of pointer types with UnmarshalBSON functions defined marshal and // unmarshal to the same Go values when the pointer values are "nil". @@ -174,6 +195,50 @@ func unmarshalingTestCases() []unmarshalingTestCase { want: &valNonPtrStruct, data: docToBytes(valNonPtrStruct), }, + { + name: "nil pointer and non-pointer type with BSON minkey", + sType: reflect.TypeOf(unmarshalBehaviorTestCase{}), + want: &unmarshalBehaviorTestCase{ + BSONValueTracker: unmarshalBSONValueCallTracker{ + called: true, + }, + BSONValuePtrTracker: &unmarshalBSONValueCallTracker{ + called: true, + }, + BSONTracker: unmarshalBSONCallTracker{ + called: true, + }, + BSONPtrTracker: nil, + }, + data: docToBytes(D{ + {Key: "bv_tracker", Value: primitive.MinKey{}}, + {Key: "bv_ptr_tracker", Value: primitive.MinKey{}}, + {Key: "b_tracker", Value: primitive.MinKey{}}, + {Key: "b_ptr_tracker", Value: primitive.MinKey{}}, + }), + }, + { + name: "nil pointer and non-pointer type with BSON maxkey", + sType: reflect.TypeOf(unmarshalBehaviorTestCase{}), + want: &unmarshalBehaviorTestCase{ + BSONValueTracker: unmarshalBSONValueCallTracker{ + called: true, + }, + BSONValuePtrTracker: &unmarshalBSONValueCallTracker{ + called: true, + }, + BSONTracker: unmarshalBSONCallTracker{ + called: true, + }, + BSONPtrTracker: nil, + }, + data: docToBytes(D{ + {Key: "bv_tracker", Value: primitive.MaxKey{}}, + {Key: "bv_ptr_tracker", Value: primitive.MaxKey{}}, + {Key: "b_tracker", Value: primitive.MaxKey{}}, + {Key: "b_ptr_tracker", Value: primitive.MaxKey{}}, + }), + }, } } @@ -269,3 +334,39 @@ func (ms *myString) UnmarshalBSON(bytes []byte) error { *ms = myString(s) return nil } + +// unmarshalBSONValueCallTracker is a test struct that tracks whether the +// UnmarshalBSONValue method has been called. +type unmarshalBSONValueCallTracker struct { + called bool // called is set to true when UnmarshalBSONValue is invoked. +} + +var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{} + +// unmarshalBSONCallTracker is a test struct that tracks whether the +// UnmarshalBSON method has been called. +type unmarshalBSONCallTracker struct { + called bool // called is set to true when UnmarshalBSON is invoked. +} + +// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface. +var _ Unmarshaler = &unmarshalBSONCallTracker{} + +// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON +// unmarshaling behavior. +type unmarshalBehaviorTestCase struct { + BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value. + BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer. + BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value. + BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer. +} + +func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(bsontype.Type, []byte) error { + tracker.called = true + return nil +} + +func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error { + tracker.called = true + return nil +}