Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3470 Correct BSON unmarshaling logic for null values (#1924) #1955

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,18 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}

// If BSON value is null and the go value is a pointer, then don't call
// UnmarshalBSONValue. Even if the Go pointer is already initialized (i.e.,
// non-nil), encountering null in BSON will result in the pointer being
// directly set to nil here. Since the pointer is being replaced with nil,
// there is no opportunity (or reason) for the custom UnmarshalBSONValue logic
// to be called.
if vr.Type() == bsontype.Null && val.Kind() == reflect.Ptr {
val.Set(reflect.Zero(val.Type()))

return vr.ReadNull()
}

if val.Kind() == reflect.Ptr && val.IsNil() {
if !val.CanSet() {
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
Expand Down
3 changes: 3 additions & 0 deletions bson/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type ValueUnmarshaler interface {
// Unmarshal parses the BSON-encoded data and stores the result in the value
// pointed to by val. If val is nil or not a pointer, Unmarshal returns
// InvalidUnmarshalError.
//
// When unmarshaling BSON, if the BSON value is null and the Go value is a
// pointer, the pointer is set to nil without calling UnmarshalBSONValue.
func Unmarshal(data []byte, val interface{}) error {
return UnmarshalWithRegistry(DefaultRegistry, data, val)
}
Expand Down
24 changes: 24 additions & 0 deletions bson/unmarshal_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/require"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

Expand Down Expand Up @@ -93,6 +94,29 @@ func TestUnmarshalValue(t *testing.T) {
})
}

func TestInitializedPointerDataWithBSONNull(t *testing.T) {
// Set up the test case with initialized pointers.
tc := unmarshalBehaviorTestCase{
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{},
BSONPtrTracker: &unmarshalBSONCallTracker{},
}

// Create BSON data where the '*_ptr_tracker' fields are explicitly set to
// null.
bytes := docToBytes(D{
{Key: "bv_ptr_tracker", Value: nil},
{Key: "b_ptr_tracker", Value: nil},
})

// Unmarshal the BSON data into the test case struct. This should set the
// pointer fields to nil due to the BSON null value.
err := Unmarshal(bytes, &tc)
require.NoError(t, err)

assert.Nil(t, tc.BSONValuePtrTracker)
assert.Nil(t, tc.BSONPtrTracker)
}

// tests covering GODRIVER-2779
func BenchmarkSliceCodecUnmarshal(b *testing.B) {
benchmarks := []struct {
Expand Down
101 changes: 101 additions & 0 deletions bson/unmarshaling_cases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)

type unmarshalingTestCase struct {
Expand Down Expand Up @@ -114,6 +115,26 @@ func unmarshalingTestCases() []unmarshalingTestCase {
},
data: docToBytes(D{{"fooBar", int32(10)}}),
},
{
name: "nil pointer and non-pointer type with literal null BSON",
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
want: &unmarshalBehaviorTestCase{
BSONValueTracker: unmarshalBSONValueCallTracker{
called: true,
},
BSONValuePtrTracker: nil,
BSONTracker: unmarshalBSONCallTracker{
called: true,
},
BSONPtrTracker: nil,
},
data: docToBytes(D{
{Key: "bv_tracker", Value: nil},
{Key: "bv_ptr_tracker", Value: nil},
{Key: "b_tracker", Value: nil},
{Key: "b_ptr_tracker", Value: nil},
}),
},
// GODRIVER-2252
// Test that a struct of pointer types with UnmarshalBSON functions defined marshal and
// unmarshal to the same Go values when the pointer values are "nil".
Expand Down Expand Up @@ -174,6 +195,50 @@ func unmarshalingTestCases() []unmarshalingTestCase {
want: &valNonPtrStruct,
data: docToBytes(valNonPtrStruct),
},
{
name: "nil pointer and non-pointer type with BSON minkey",
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
want: &unmarshalBehaviorTestCase{
BSONValueTracker: unmarshalBSONValueCallTracker{
called: true,
},
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
called: true,
},
BSONTracker: unmarshalBSONCallTracker{
called: true,
},
BSONPtrTracker: nil,
},
data: docToBytes(D{
{Key: "bv_tracker", Value: primitive.MinKey{}},
{Key: "bv_ptr_tracker", Value: primitive.MinKey{}},
{Key: "b_tracker", Value: primitive.MinKey{}},
{Key: "b_ptr_tracker", Value: primitive.MinKey{}},
}),
},
{
name: "nil pointer and non-pointer type with BSON maxkey",
sType: reflect.TypeOf(unmarshalBehaviorTestCase{}),
want: &unmarshalBehaviorTestCase{
BSONValueTracker: unmarshalBSONValueCallTracker{
called: true,
},
BSONValuePtrTracker: &unmarshalBSONValueCallTracker{
called: true,
},
BSONTracker: unmarshalBSONCallTracker{
called: true,
},
BSONPtrTracker: nil,
},
data: docToBytes(D{
{Key: "bv_tracker", Value: primitive.MaxKey{}},
{Key: "bv_ptr_tracker", Value: primitive.MaxKey{}},
{Key: "b_tracker", Value: primitive.MaxKey{}},
{Key: "b_ptr_tracker", Value: primitive.MaxKey{}},
}),
},
}
}

Expand Down Expand Up @@ -250,3 +315,39 @@ func (ms *myString) UnmarshalBSON(bytes []byte) error {
*ms = myString(s)
return nil
}

// unmarshalBSONValueCallTracker is a test struct that tracks whether the
// UnmarshalBSONValue method has been called.
type unmarshalBSONValueCallTracker struct {
called bool // called is set to true when UnmarshalBSONValue is invoked.
}

var _ ValueUnmarshaler = &unmarshalBSONValueCallTracker{}

// unmarshalBSONCallTracker is a test struct that tracks whether the
// UnmarshalBSON method has been called.
type unmarshalBSONCallTracker struct {
called bool // called is set to true when UnmarshalBSON is invoked.
}

// Ensure unmarshalBSONCallTracker implements the Unmarshaler interface.
var _ Unmarshaler = &unmarshalBSONCallTracker{}

// unmarshalBehaviorTestCase holds instances of call trackers for testing BSON
// unmarshaling behavior.
type unmarshalBehaviorTestCase struct {
BSONValueTracker unmarshalBSONValueCallTracker `bson:"bv_tracker"` // BSON value unmarshaling by value.
BSONValuePtrTracker *unmarshalBSONValueCallTracker `bson:"bv_ptr_tracker"` // BSON value unmarshaling by pointer.
BSONTracker unmarshalBSONCallTracker `bson:"b_tracker"` // BSON unmarshaling by value.
BSONPtrTracker *unmarshalBSONCallTracker `bson:"b_ptr_tracker"` // BSON unmarshaling by pointer.
}

func (tracker *unmarshalBSONValueCallTracker) UnmarshalBSONValue(bsontype.Type, []byte) error {
tracker.called = true
return nil
}

func (tracker *unmarshalBSONCallTracker) UnmarshalBSON([]byte) error {
tracker.called = true
return nil
}
Loading