Skip to content

Commit 84b88b3

Browse files
committed
GODRIVER-3472: Address code review improvements
1 parent fa20f11 commit 84b88b3

File tree

6 files changed

+341
-109
lines changed

6 files changed

+341
-109
lines changed

bson/primitive_codecs_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package bson
88

99
import (
1010
"bytes"
11+
"encoding/binary"
1112
"encoding/json"
1213
"errors"
1314
"fmt"
@@ -1116,3 +1117,50 @@ func compareDecimal128(d1, d2 Decimal128) bool {
11161117

11171118
return true
11181119
}
1120+
1121+
func TestSliceCodec(t *testing.T) {
1122+
t.Run("[]byte is treated as binary data", func(t *testing.T) {
1123+
type testStruct struct {
1124+
B []byte `bson:"b"`
1125+
}
1126+
1127+
testData := testStruct{B: []byte{0x01, 0x02, 0x03}}
1128+
data, err := Marshal(testData)
1129+
assert.Nil(t, err, "Marshal error: %v", err)
1130+
var doc D
1131+
err = Unmarshal(data, &doc)
1132+
assert.Nil(t, err, "Unmarshal error: %v", err)
1133+
1134+
offset := 4 + 1 + 2
1135+
length := int32(binary.LittleEndian.Uint32(data[offset:]))
1136+
offset += 4 // Skip length
1137+
subtype := data[offset]
1138+
offset++ // Skip subtype
1139+
dataBytes := data[offset : offset+int(length)]
1140+
1141+
assert.Equal(t, byte(0x00), subtype, "Expected binary subtype 0x00")
1142+
assert.Equal(t, []byte{0x01, 0x02, 0x03}, dataBytes, "Binary data mismatch")
1143+
})
1144+
1145+
t.Run("[]int8 is not treated as binary data", func(t *testing.T) {
1146+
type testStruct struct {
1147+
I []int8 `bson:"i"`
1148+
}
1149+
testData := testStruct{I: []int8{1, 2, 3}}
1150+
data, err := Marshal(testData)
1151+
assert.Nil(t, err, "Marshal error: %v", err)
1152+
1153+
offset := 4 // Skip document length
1154+
assert.Equal(t, byte(0x04), data[offset], "Expected array type (0x04), got: 0x%02x", data[offset])
1155+
1156+
var result struct {
1157+
I []int32 `bson:"i"`
1158+
}
1159+
err = Unmarshal(data, &result)
1160+
assert.Nil(t, err, "Unmarshal result error: %v", err)
1161+
assert.Equal(t, 3, len(result.I), "Expected array length 3")
1162+
assert.Equal(t, int32(1), result.I[0], "Array element 0 mismatch")
1163+
assert.Equal(t, int32(2), result.I[1], "Array element 1 mismatch")
1164+
assert.Equal(t, int32(3), result.I[2], "Array element 2 mismatch")
1165+
})
1166+
}

bson/slice_codec.go

Lines changed: 79 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
package bson
88

99
import (
10+
"encoding/binary"
1011
"errors"
1112
"fmt"
13+
"math"
1214
"reflect"
1315
)
1416

@@ -25,7 +27,7 @@ type sliceCodec struct {
2527
func (sc *sliceCodec) decodeVectorBinary(vr ValueReader, val reflect.Value) error {
2628
elemType := val.Type().Elem()
2729

28-
if elemType != TInt8 && elemType != TFloat32 {
30+
if elemType != tInt8 && elemType != tFloat32 {
2931
return errNotAVectorBinary
3032
}
3133

@@ -39,14 +41,14 @@ func (sc *sliceCodec) decodeVectorBinary(vr ValueReader, val reflect.Value) erro
3941
}
4042

4143
switch elemType {
42-
case TInt8:
43-
int8Slice, err := DecodeVectorInt8(data)
44+
case tInt8:
45+
int8Slice, err := decodeVectorInt8(data)
4446
if err != nil {
4547
return err
4648
}
4749
val.Set(reflect.ValueOf(int8Slice))
48-
case TFloat32:
49-
float32Slice, err := DecodeVectorFloat32(data)
50+
case tFloat32:
51+
float32Slice, err := decodeVectorFloat32(data)
5052
if err != nil {
5153
return err
5254
}
@@ -66,8 +68,9 @@ func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.
6668
return vw.WriteNull()
6769
}
6870

69-
// Treat []byte as binary data, but skip for []int8 since it's a different type
70-
// even though byte is an alias for uint8 which has the same underlying type as int8
71+
// Treat []byte as binary data, but skip for []int8 since it's a different type.
72+
// Even though byte is an alias for uint8 which has the same underlying type as int8,
73+
// we want to maintain the semantic difference between []byte (binary data) and []int8 (array of integers).
7174
if val.Type().Elem() == tByte && val.Type() != reflect.TypeOf([]int8{}) {
7275
byteSlice := make([]byte, val.Len())
7376
reflect.Copy(reflect.ValueOf(byteSlice), val)
@@ -137,12 +140,6 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.
137140
return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
138141
}
139142

140-
if vr.Type() == TypeBinary {
141-
if err := sc.decodeVectorBinary(vr, val); err != errNotAVectorBinary {
142-
return err
143-
}
144-
}
145-
146143
switch vrType := vr.Type(); vrType {
147144
case TypeArray:
148145
case TypeNull:
@@ -156,6 +153,16 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.
156153
return fmt.Errorf("cannot decode document into %s", val.Type())
157154
}
158155
case TypeBinary:
156+
// First try to decode as a vector binary
157+
err := sc.decodeVectorBinary(vr, val)
158+
if err == nil {
159+
return nil // Successfully decoded as vector
160+
}
161+
if err != errNotAVectorBinary {
162+
return err // Return any actual errors
163+
}
164+
165+
// If not a vector binary, handle as regular binary data
159166
if val.Type().Elem() != tByte {
160167
return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", vrType)
161168
}
@@ -215,3 +222,62 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.
215222

216223
return nil
217224
}
225+
226+
// decodeVectorInt8 decodes a BSON Vector binary value (subtype 9) into a []int8 slice.
227+
// The binary data should be in the format: [<vector type> <padding> <data>]
228+
// For int8 vectors, the vector type is Int8Vector (0x03).
229+
func decodeVectorInt8(data []byte) ([]int8, error) {
230+
if len(data) < 2 {
231+
return nil, fmt.Errorf("insufficient bytes to decode vector: expected at least 2 bytes")
232+
}
233+
234+
vectorType := data[0]
235+
if vectorType != Int8Vector {
236+
return nil, fmt.Errorf("invalid vector type: expected int8 vector (0x%02x), got 0x%02x", Int8Vector, vectorType)
237+
}
238+
239+
if padding := data[1]; padding != 0 {
240+
return nil, fmt.Errorf("invalid vector: padding byte must be 0")
241+
}
242+
243+
values := make([]int8, 0, len(data)-2)
244+
for i := 2; i < len(data); i++ {
245+
values = append(values, int8(data[i]))
246+
}
247+
248+
return values, nil
249+
}
250+
251+
// decodeVectorFloat32 decodes a BSON Vector binary value (subtype 9) into a []float32 slice.
252+
// The binary data should be in the format: [<vector type> <padding> <data>]
253+
// For float32 vectors, the vector type is Float32Vector (0x27) and data must be a multiple of 4 bytes.
254+
func decodeVectorFloat32(data []byte) ([]float32, error) {
255+
if len(data) < 2 {
256+
return nil, fmt.Errorf("insufficient bytes to decode vector: expected at least 2 bytes")
257+
}
258+
259+
vectorType := data[0]
260+
if vectorType != Float32Vector {
261+
return nil, fmt.Errorf("invalid vector type: expected float32 vector (0x%02x), got 0x%02x", Float32Vector, vectorType)
262+
}
263+
264+
if padding := data[1]; padding != 0 {
265+
return nil, fmt.Errorf("invalid vector: padding byte must be 0")
266+
}
267+
268+
floatData := data[2:]
269+
if len(floatData)%4 != 0 {
270+
return nil, fmt.Errorf("invalid float32 vector: data length must be a multiple of 4")
271+
}
272+
273+
values := make([]float32, 0, len(floatData)/4)
274+
for i := 0; i < len(floatData); i += 4 {
275+
if i+4 > len(floatData) {
276+
return nil, fmt.Errorf("invalid float32 vector: truncated data")
277+
}
278+
bits := binary.LittleEndian.Uint32(floatData[i : i+4])
279+
values = append(values, math.Float32frombits(bits))
280+
}
281+
282+
return values, nil
283+
}

bson/types.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ const (
7777
)
7878

7979
var tBool = reflect.TypeOf(false)
80+
var tFloat32 = reflect.TypeOf(float32(0))
8081
var tFloat64 = reflect.TypeOf(float64(0))
82+
var tInt8 = reflect.TypeOf(int8(0))
8183
var tInt32 = reflect.TypeOf(int32(0))
8284
var tInt64 = reflect.TypeOf(int64(0))
8385
var tString = reflect.TypeOf("")

bson/vector.go

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"errors"
1212
"fmt"
1313
"math"
14-
"reflect"
1514
)
1615

1716
// BSON binary vector types as described in https://bsonspec.org/spec.html.
@@ -29,13 +28,6 @@ var (
2928
errNotAVectorBinary = errors.New("not a vector binary")
3029
)
3130

32-
var (
33-
// TInt8 is the reflect.Type for int8
34-
TInt8 = reflect.TypeOf(int8(0))
35-
// TFloat32 is the reflect.Type for float32
36-
TFloat32 = reflect.TypeOf(float32(0))
37-
)
38-
3931
type vectorTypeError struct {
4032
Method string
4133
Type byte
@@ -275,60 +267,3 @@ func newBitVector(b []byte) (Vector, error) {
275267
}
276268
return NewPackedBitVector(b[1:], b[0])
277269
}
278-
279-
// DecodeVectorInt8 decodes a BSON Vector binary value (subtype 9) into a []int8 slice.
280-
// The binary data should be in the format: [<vector type> <padding> <data>]
281-
// For int8 vectors, the vector type is 0x01.
282-
func DecodeVectorInt8(data []byte) ([]int8, error) {
283-
if len(data) < 2 {
284-
return nil, errors.New("insufficient bytes to decode vector: expected at least 2 bytes")
285-
}
286-
287-
vectorType := data[0]
288-
if vectorType != 0x01 { // Int8Vector
289-
return nil, errors.New("invalid vector type: expected int8 vector (0x01)")
290-
}
291-
292-
if padding := data[1]; padding != 0 {
293-
return nil, errors.New("invalid vector: padding byte must be 0")
294-
}
295-
values := make([]int8, 0, len(data)-2)
296-
for i := 2; i < len(data); i++ {
297-
values = append(values, int8(data[i]))
298-
}
299-
300-
return values, nil
301-
}
302-
303-
// DecodeVectorFloat32 decodes a BSON Vector binary value (subtype 9) into a []float32 slice.
304-
// The binary data should be in the format: [<vector type> <padding> <data>]
305-
// For float32 vectors, the vector type is 0x02 and data must be a multiple of 4 bytes.
306-
func DecodeVectorFloat32(data []byte) ([]float32, error) {
307-
if len(data) < 2 {
308-
return nil, errors.New("insufficient bytes to decode vector: expected at least 2 bytes")
309-
}
310-
311-
vectorType := data[0]
312-
if vectorType != 0x02 { // Float32Vector
313-
return nil, errors.New("invalid vector type: expected float32 vector (0x02)")
314-
}
315-
316-
if padding := data[1]; padding != 0 {
317-
return nil, errors.New("invalid vector: padding byte must be 0")
318-
}
319-
floatData := data[2:]
320-
if len(floatData)%4 != 0 {
321-
return nil, errors.New("invalid float32 vector: data length must be a multiple of 4")
322-
}
323-
324-
values := make([]float32, 0, len(floatData)/4)
325-
for i := 0; i < len(floatData); i += 4 {
326-
if i+4 > len(floatData) {
327-
return nil, errors.New("invalid float32 vector: truncated data")
328-
}
329-
bits := binary.LittleEndian.Uint32(floatData[i : i+4])
330-
values = append(values, math.Float32frombits(bits))
331-
}
332-
333-
return values, nil
334-
}

0 commit comments

Comments
 (0)