Skip to content

Commit b296809

Browse files
authored
GODRIVER-3286 BSON Binary vector subtype support. (#1919)
1 parent 8eb83cb commit b296809

12 files changed

+873
-42
lines changed

bson/bson_binary_vector_spec_test.go

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
// Copyright (C) MongoDB, Inc. 2024-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bson
8+
9+
import (
10+
"encoding/hex"
11+
"encoding/json"
12+
"fmt"
13+
"math"
14+
"os"
15+
"path"
16+
"testing"
17+
18+
"go.mongodb.org/mongo-driver/v2/internal/require"
19+
)
20+
21+
// bsonBinaryVectorDir is the directory, relative to this package, that
// TestBsonBinaryVectorSpec scans for JSON test files.
const bsonBinaryVectorDir = "../testdata/bson-binary-vector/"
22+
23+
type bsonBinaryVectorTests struct {
24+
Description string `json:"description"`
25+
TestKey string `json:"test_key"`
26+
Tests []bsonBinaryVectorTestCase `json:"tests"`
27+
}
28+
29+
// bsonBinaryVectorTestCase is a single case from a binary-vector JSON test
// file.
type bsonBinaryVectorTestCase struct {
	// Description names the per-case subtest and is the key looked up in the
	// skip tables of runBsonBinaryVectorTest.
	Description string `json:"description"`
	// Valid is decoded from the "valid" JSON field.
	Valid bool `json:"valid"`
	// Vector holds the raw JSON elements; convertSlice understands float64
	// numbers and the strings "inf" and "-inf".
	Vector []any `json:"vector"`
	// DtypeHex selects the vector element type ("0x03", "0x27" or "0x10").
	DtypeHex string `json:"dtype_hex"`
	// DtypeAlias is decoded from the "dtype_alias" JSON field.
	DtypeAlias string `json:"dtype_alias"`
	// Padding is used as the bit padding of packed-bit ("0x10") vectors.
	Padding int `json:"padding"`
	// CanonicalBson is the hex-encoded BSON document expected for the case.
	CanonicalBson string `json:"canonical_bson"`
}
38+
39+
func TestBsonBinaryVectorSpec(t *testing.T) {
40+
t.Parallel()
41+
42+
jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
43+
require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err)
44+
45+
for _, file := range jsonFiles {
46+
filepath := path.Join(bsonBinaryVectorDir, file)
47+
content, err := os.ReadFile(filepath)
48+
require.NoErrorf(t, err, "reading test file %s", filepath)
49+
50+
var tests bsonBinaryVectorTests
51+
require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath)
52+
53+
t.Run(tests.Description, func(t *testing.T) {
54+
t.Parallel()
55+
56+
for _, test := range tests.Tests {
57+
test := test
58+
t.Run(test.Description, func(t *testing.T) {
59+
t.Parallel()
60+
61+
runBsonBinaryVectorTest(t, tests.TestKey, test)
62+
})
63+
}
64+
})
65+
}
66+
67+
t.Run("FLOAT32 with padding", func(t *testing.T) {
68+
t.Parallel()
69+
70+
t.Run("Unmarshaling", func(t *testing.T) {
71+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Float32Vector, 3}}}}
72+
b, err := Marshal(val)
73+
require.NoError(t, err, "marshaling test BSON")
74+
var got struct {
75+
Vector Vector
76+
}
77+
err = Unmarshal(b, &got)
78+
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
79+
})
80+
})
81+
82+
t.Run("INT8 with padding", func(t *testing.T) {
83+
t.Parallel()
84+
85+
t.Run("Unmarshaling", func(t *testing.T) {
86+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{Int8Vector, 3}}}}
87+
b, err := Marshal(val)
88+
require.NoError(t, err, "marshaling test BSON")
89+
var got struct {
90+
Vector Vector
91+
}
92+
err = Unmarshal(b, &got)
93+
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
94+
})
95+
})
96+
97+
t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
98+
t.Parallel()
99+
100+
t.Run("Marshaling", func(t *testing.T) {
101+
_, err := NewPackedBitVector(nil, 1)
102+
require.EqualError(t, err, errNonZeroVectorPadding.Error())
103+
})
104+
t.Run("Unmarshaling", func(t *testing.T) {
105+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 1}}}}
106+
b, err := Marshal(val)
107+
require.NoError(t, err, "marshaling test BSON")
108+
var got struct {
109+
Vector Vector
110+
}
111+
err = Unmarshal(b, &got)
112+
require.ErrorContains(t, err, errNonZeroVectorPadding.Error())
113+
})
114+
})
115+
116+
t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
117+
t.Parallel()
118+
119+
t.Run("Marshaling", func(t *testing.T) {
120+
_, err := NewPackedBitVector(nil, 8)
121+
require.EqualError(t, err, errVectorPaddingTooLarge.Error())
122+
})
123+
t.Run("Unmarshaling", func(t *testing.T) {
124+
val := D{{"vector", Binary{Subtype: TypeBinaryVector, Data: []byte{PackedBitVector, 8}}}}
125+
b, err := Marshal(val)
126+
require.NoError(t, err, "marshaling test BSON")
127+
var got struct {
128+
Vector Vector
129+
}
130+
err = Unmarshal(b, &got)
131+
require.ErrorContains(t, err, errVectorPaddingTooLarge.Error())
132+
})
133+
})
134+
}
135+
136+
// TODO: This test may be added into the spec tests.
137+
func TestFloat32VectorWithInsufficientData(t *testing.T) {
138+
t.Parallel()
139+
140+
val := Binary{Subtype: TypeBinaryVector}
141+
142+
for _, tc := range [][]byte{
143+
{Float32Vector, 0, 42},
144+
{Float32Vector, 0, 42, 42},
145+
{Float32Vector, 0, 42, 42, 42},
146+
147+
{Float32Vector, 0, 42, 42, 42, 42, 42},
148+
{Float32Vector, 0, 42, 42, 42, 42, 42, 42},
149+
{Float32Vector, 0, 42, 42, 42, 42, 42, 42, 42},
150+
} {
151+
t.Run(fmt.Sprintf("marshaling %d bytes", len(tc)-2), func(t *testing.T) {
152+
val.Data = tc
153+
b, err := Marshal(D{{"vector", val}})
154+
require.NoError(t, err, "marshaling test BSON")
155+
var got struct {
156+
Vector Vector
157+
}
158+
err = Unmarshal(b, &got)
159+
require.ErrorContains(t, err, errInsufficientVectorData.Error())
160+
})
161+
}
162+
}
163+
164+
// convertSlice converts decoded JSON values into a slice of T.
//
// float64 elements are converted directly. The strings "inf" and "-inf" map
// to positive and negative infinity. Every other element (any other string
// or any other dynamic type) is treated as NaN before conversion. The result
// always has len(s) elements and is never nil.
//
// Note: for the integer instantiations, converting NaN, an infinity or an
// out-of-range number gives an implementation-dependent value under the Go
// conversion rules.
func convertSlice[T int8 | float32 | byte](s []any) []T {
	out := make([]T, 0, len(s))
	for _, elem := range s {
		num := math.NaN()
		switch x := elem.(type) {
		case float64:
			num = x
		case string:
			switch x {
			case "inf":
				num = math.Inf(1)
			case "-inf":
				num = math.Inf(-1)
			}
		}
		out = append(out, T(num))
	}
	return out
}
182+
183+
func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
184+
testVector := make(map[string]Vector)
185+
switch alias := test.DtypeHex; alias {
186+
case "0x03":
187+
testVector[testKey] = Vector{
188+
dType: Int8Vector,
189+
int8Data: convertSlice[int8](test.Vector),
190+
}
191+
case "0x27":
192+
testVector[testKey] = Vector{
193+
dType: Float32Vector,
194+
float32Data: convertSlice[float32](test.Vector),
195+
}
196+
case "0x10":
197+
testVector[testKey] = Vector{
198+
dType: PackedBitVector,
199+
bitData: convertSlice[byte](test.Vector),
200+
bitPadding: uint8(test.Padding),
201+
}
202+
default:
203+
t.Fatalf("unsupported vector type: %s", alias)
204+
}
205+
206+
testBSON, err := hex.DecodeString(test.CanonicalBson)
207+
require.NoError(t, err, "decoding canonical BSON")
208+
209+
t.Run("Unmarshaling", func(t *testing.T) {
210+
skipCases := map[string]string{
211+
"FLOAT32 with padding": "run in alternative case",
212+
"Overflow Vector INT8": "compile-time restriction",
213+
"Underflow Vector INT8": "compile-time restriction",
214+
"INT8 with padding": "run in alternative case",
215+
"INT8 with float inputs": "compile-time restriction",
216+
"Overflow Vector PACKED_BIT": "compile-time restriction",
217+
"Underflow Vector PACKED_BIT": "compile-time restriction",
218+
"Vector with float values PACKED_BIT": "compile-time restriction",
219+
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
220+
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
221+
"Negative padding PACKED_BIT": "compile-time restriction",
222+
}
223+
if reason, ok := skipCases[test.Description]; ok {
224+
t.Skipf("skip test case %s: %s", test.Description, reason)
225+
}
226+
227+
t.Parallel()
228+
229+
var got map[string]Vector
230+
err := Unmarshal(testBSON, &got)
231+
require.NoError(t, err)
232+
require.Equal(t, testVector, got)
233+
})
234+
235+
t.Run("Marshaling", func(t *testing.T) {
236+
skipCases := map[string]string{
237+
"FLOAT32 with padding": "private padding field",
238+
"Overflow Vector INT8": "compile-time restriction",
239+
"Underflow Vector INT8": "compile-time restriction",
240+
"INT8 with padding": "private padding field",
241+
"INT8 with float inputs": "compile-time restriction",
242+
"Overflow Vector PACKED_BIT": "compile-time restriction",
243+
"Underflow Vector PACKED_BIT": "compile-time restriction",
244+
"Vector with float values PACKED_BIT": "compile-time restriction",
245+
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
246+
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
247+
"Negative padding PACKED_BIT": "compile-time restriction",
248+
}
249+
if reason, ok := skipCases[test.Description]; ok {
250+
t.Skipf("skip test case %s: %s", test.Description, reason)
251+
}
252+
253+
t.Parallel()
254+
255+
got, err := Marshal(testVector)
256+
require.NoError(t, err)
257+
require.Equal(t, testBSON, got)
258+
})
259+
}

bson/bson_corpus_spec_test.go

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,15 @@ func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
217217
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
218218
var doc D
219219
err := Unmarshal(b, &doc)
220-
expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType))
220+
require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType)
221221
return doc
222222
}
223223

224224
// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
225225
// canonical BSON (cB)
226226
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
227227
actual, err := Marshal(doc)
228-
expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType))
228+
require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)
229229

230230
if diff := cmp.Diff(cB, actual); diff != "" {
231231
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
@@ -261,7 +261,7 @@ func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
261261
// nativeToJSON encodes the native Document (doc) into an extended JSON string
262262
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
263263
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
264-
expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType))
264+
require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)
265265

266266
if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
267267
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
@@ -288,7 +288,7 @@ func runTest(t *testing.T, file string) {
288288
t.Run(v.Description, func(t *testing.T) {
289289
// get canonical BSON
290290
cB, err := hex.DecodeString(v.CanonicalBson)
291-
expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description))
291+
require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)
292292

293293
// get canonical extended JSON
294294
var compactEJ bytes.Buffer
@@ -341,7 +341,7 @@ func runTest(t *testing.T, file string) {
341341
/*** degenerate BSON round-trip tests (if exists) ***/
342342
if v.DegenerateBSON != nil {
343343
dB, err := hex.DecodeString(*v.DegenerateBSON)
344-
expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description))
344+
require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)
345345

346346
doc = bsonToNative(t, dB, "degenerate", v.Description)
347347

@@ -377,7 +377,7 @@ func runTest(t *testing.T, file string) {
377377
for _, d := range test.DecodeErrors {
378378
t.Run(d.Description, func(t *testing.T) {
379379
b, err := hex.DecodeString(d.Bson)
380-
expectNoError(t, err, d.Description)
380+
require.NoError(t, err, d.Description)
381381

382382
var doc D
383383
err = Unmarshal(b, &doc)
@@ -392,12 +392,12 @@ func runTest(t *testing.T, file string) {
392392
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)
393393

394394
if invalidString || invalidDBPtr {
395-
expectNoError(t, err, d.Description)
395+
require.NoError(t, err, d.Description)
396396
return
397397
}
398398
}
399399

400-
expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description))
400+
require.Errorf(t, err, "%s: expected decode error", d.Description)
401401
})
402402
}
403403
})
@@ -418,7 +418,7 @@ func runTest(t *testing.T, file string) {
418418
if strings.Contains(p.Description, "Null") {
419419
_, err = Marshal(doc)
420420
}
421-
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
421+
require.Errorf(t, err, "%s: expected parse error", p.Description)
422422
default:
423423
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
424424
t.Fail()
@@ -431,31 +431,13 @@ func runTest(t *testing.T, file string) {
431431

432432
func Test_BsonCorpus(t *testing.T) {
433433
jsonFiles, err := findJSONFilesInDir(dataDir)
434-
if err != nil {
435-
t.Fatalf("error finding JSON files in %s: %v", dataDir, err)
436-
}
434+
require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)
437435

438436
for _, file := range jsonFiles {
439437
runTest(t, file)
440438
}
441439
}
442440

443-
func expectNoError(t *testing.T, err error, desc string) {
444-
if err != nil {
445-
t.Helper()
446-
t.Errorf("%s: Unepexted error: %v", desc, err)
447-
t.FailNow()
448-
}
449-
}
450-
451-
func expectError(t *testing.T, err error, desc string) {
452-
if err == nil {
453-
t.Helper()
454-
t.Errorf("%s: Expected error", desc)
455-
t.FailNow()
456-
}
457-
}
458-
459441
func TestRelaxedUUIDValidation(t *testing.T) {
460442
testCases := []struct {
461443
description string

0 commit comments

Comments
 (0)