Skip to content

Commit 1784fc9

Browse files
authored
[TST] Add tests for CheckCollections, make test collection setup more flexible (#4912)
## Summary

This adds a test for the `CheckCollections` method, which would have caught the issue fixed in #4899. Along the way, I discovered that these tests need some love, so I started down that path by introducing a new `daotest` package with testing helpers for creating `Collection` values that are easier to configure. Rather than potentially having to debug every test that uses `dao.CreateTestCollection`, I defined a "shim" (`daotest.NewDefaultTestCollection`) as a stop-gap, so that I could change the signature of `dao.CreateTestCollection` without derailing the original goal, which was just to add some missing test coverage.

- Improvements & Bug fixes
  - Added test coverage for the `CheckCollections` service endpoint
- New functionality
  - Added the `daotest` package and the `NewTestCollection` helper with builder/options methods for more configurability

## Test plan

All the tests are passing.

## Documentation Changes

N/A
1 parent e6036eb commit 1784fc9

File tree

5 files changed

+234
-49
lines changed

5 files changed

+234
-49
lines changed

go/pkg/sysdb/grpc/collection_service_test.go

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/chroma-core/chroma/go/pkg/sysdb/coordinator"
1414
"github.com/chroma-core/chroma/go/pkg/sysdb/coordinator/model"
1515
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dao"
16+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dao/daotest"
1617
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbcore"
1718
s3metastore "github.com/chroma-core/chroma/go/pkg/sysdb/metastore/s3"
1819
"github.com/chroma-core/chroma/go/pkg/types"
@@ -28,6 +29,11 @@ import (
2829
"pgregory.net/rapid"
2930
)
3031

32+
// TODO(eculver): replace most suite.NoError(err) with suite.Require().NoError(err) so the test
33+
// stops running when the error is not nil instead of continuing and causing red herrings in test output
34+
35+
// TODO(eculver): replace calls to daotest.NewDefaultTestCollection with daotest.NewTestCollection
36+
3137
type CollectionServiceTestSuite struct {
3238
suite.Suite
3339
catalog *coordinator.Catalog
@@ -346,9 +352,9 @@ func (suite *CollectionServiceTestSuite) TestCreateCollection() {
346352
}
347353

348354
func (suite *CollectionServiceTestSuite) TestServer_GetCollection() {
349-
// Create a test collection
355+
// Create a test collection with a name that should not already exist in the database
350356
collectionName := "test_get_collection"
351-
collectionID, err := dao.CreateTestCollection(suite.db, collectionName, 128, suite.databaseId, nil)
357+
collectionID, err := dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
352358
suite.NoError(err)
353359

354360
// Soft delete the collection
@@ -377,6 +383,7 @@ func (suite *CollectionServiceTestSuite) TestServer_GetCollection() {
377383

378384
func (suite *CollectionServiceTestSuite) TestServer_GetCollectionByResourceName() {
379385
tenantResourceName := "test_tenant_resource_name"
386+
// Does this need to match the daotest.TestTenantID?
380387
tenantID := "test_tenant_id"
381388
databaseName := "test_database"
382389
collectionName := "test_collection"
@@ -388,7 +395,7 @@ func (suite *CollectionServiceTestSuite) TestServer_GetCollectionByResourceName(
388395
err = dao.SetTestTenantResourceName(suite.db, tenantID, tenantResourceName)
389396
suite.NoError(err)
390397

391-
collectionID, err := dao.CreateTestCollection(suite.db, collectionName, dim, databaseID, nil)
398+
collectionID, err := dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, dim, databaseID, nil))
392399
suite.NoError(err)
393400

394401
req := &coordinatorpb.GetCollectionByResourceNameRequest{
@@ -447,7 +454,7 @@ func (suite *CollectionServiceTestSuite) TestServer_FlushCollectionCompaction()
447454
log.Info("TestServer_FlushCollectionCompaction")
448455
// create test collection
449456
collectionName := "collection_service_test_flush_collection_compaction"
450-
collectionID, err := dao.CreateTestCollection(suite.db, collectionName, 128, suite.databaseId, nil)
457+
collectionID, err := dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
451458
suite.NoError(err)
452459

453460
// flush collection compaction
@@ -597,7 +604,7 @@ func (suite *CollectionServiceTestSuite) TestServer_FlushCollectionCompaction()
597604
// Send FlushCollectionCompaction for a collection that is soft deleted.
598605
// It should fail with a failed precondition error.
599606
// Create collection and soft-delete it.
600-
collectionID, err = dao.CreateTestCollection(suite.db, "test_flush_collection_compaction_soft_delete", 128, suite.databaseId, nil)
607+
collectionID, err = dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection("test_flush_collection_compaction_soft_delete", 128, suite.databaseId, nil))
601608
suite.NoError(err)
602609
suite.s.coordinator.SoftDeleteCollection(context.Background(), &model.DeleteCollection{
603610
ID: types.MustParse(collectionID),
@@ -621,9 +628,35 @@ func (suite *CollectionServiceTestSuite) TestServer_FlushCollectionCompaction()
621628
suite.NoError(err)
622629
}
623630

631+
func (suite *CollectionServiceTestSuite) TestServer_CheckCollections() {
632+
collectionName := "test_check_collections"
633+
collectionID, err := dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
634+
suite.NoError(err)
635+
636+
request := &coordinatorpb.CheckCollectionsRequest{
637+
CollectionIds: []string{collectionID},
638+
}
639+
640+
// Call the service method
641+
response, err := suite.s.CheckCollections(context.Background(), request)
642+
suite.NoError(err)
643+
644+
suite.NotNil(response.GetDeleted(), "Deleted slice should not be nil.")
645+
suite.Len(response.GetDeleted(), 1)
646+
suite.False(response.GetDeleted()[0])
647+
648+
suite.NotNil(response.GetLogPosition(), "LogPosition slice should not be nil.")
649+
suite.Len(response.GetLogPosition(), 1)
650+
suite.GreaterOrEqual(response.GetLogPosition()[0], int64(0))
651+
652+
// clean up
653+
err = dao.CleanUpTestCollection(suite.db, collectionID)
654+
suite.NoError(err)
655+
}
656+
624657
func (suite *CollectionServiceTestSuite) TestGetCollectionSize() {
625658
collectionName := "collection_service_test_get_collection_size"
626-
collectionID, err := dao.CreateTestCollection(suite.db, collectionName, 128, suite.databaseId, nil)
659+
collectionID, err := dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
627660
suite.NoError(err)
628661

629662
req := coordinatorpb.GetCollectionSizeRequest{
@@ -639,7 +672,7 @@ func (suite *CollectionServiceTestSuite) TestGetCollectionSize() {
639672

640673
func (suite *CollectionServiceTestSuite) TestCountForks() {
641674
collectionName := "collection_service_test_count_forks"
642-
collectionID, err := dao.CreateTestCollection(suite.db, collectionName, 128, suite.databaseId, nil)
675+
collectionID, err := dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
643676
suite.NoError(err)
644677

645678
req := coordinatorpb.CountForksRequest{
@@ -684,7 +717,7 @@ func (suite *CollectionServiceTestSuite) TestCountForks() {
684717

685718
func (suite *CollectionServiceTestSuite) TestFork() {
686719
collectionName := "collection_service_test_forks"
687-
collectionID, err := dao.CreateTestCollection(suite.db, collectionName, 128, suite.databaseId, nil)
720+
collectionID, err := dao.CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
688721
suite.NoError(err)
689722
targetCollectionID := types.NewUniqueID()
690723

go/pkg/sysdb/metastore/db/dao/collection_test.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66
"time"
77

8+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dao/daotest"
89
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbcore"
910
"github.com/pingcap/log"
1011
"github.com/stretchr/testify/suite"
@@ -48,7 +49,7 @@ func (suite *CollectionDbTestSuite) TearDownSuite() {
4849
func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() {
4950
collectionName := "test_collection_get_collections"
5051
dim := int32(128)
51-
collectionID, err := CreateTestCollection(suite.db, collectionName, dim, suite.databaseId, nil)
52+
collectionID, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, dim, suite.databaseId, nil))
5253
suite.NoError(err)
5354

5455
testKey := "test"
@@ -103,7 +104,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() {
103104
suite.Len(collections, 1)
104105
suite.Equal(collectionID, collections[0].Collection.ID)
105106

106-
collectionID2, err := CreateTestCollection(suite.db, "test_collection_get_collections2", 128, suite.databaseId, nil)
107+
collectionID2, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection("test_collection_get_collections2", 128, suite.databaseId, nil))
107108
suite.NoError(err)
108109

109110
// Test order by. Collections are ordered by create time so collectionID2 should be second
@@ -137,10 +138,10 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() {
137138
suite.NoError(err)
138139

139140
// Create two collections in the new database.
140-
collectionID3, err := CreateTestCollection(suite.db, "test_collection_get_collections3", 128, DbId, nil)
141+
collectionID3, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection("test_collection_get_collections3", 128, DbId, nil))
141142
suite.NoError(err)
142143

143-
collectionID4, err := CreateTestCollection(suite.db, "test_collection_get_collections4", 128, DbId, nil)
144+
collectionID4, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection("test_collection_get_collections4", 128, DbId, nil))
144145
suite.NoError(err)
145146

146147
// Test count collections
@@ -174,7 +175,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollections() {
174175

175176
func (suite *CollectionDbTestSuite) TestCollectionDb_UpdateLogPositionVersionTotalRecordsAndLogicalSize() {
176177
collectionName := "test_collection_get_collections"
177-
collectionID, _ := CreateTestCollection(suite.db, collectionName, 128, suite.databaseId, nil)
178+
collectionID, _ := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
178179
ids := []string{collectionID}
179180
// verify default values
180181
collections, err := suite.collectionDb.GetCollections(ids, nil, "", "", nil, nil, false)
@@ -228,9 +229,9 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_SoftDelete() {
228229
// Create 2 collections.
229230
collectionName1 := "test_collection_soft_delete1"
230231
collectionName2 := "test_collection_soft_delete2"
231-
collectionID1, err := CreateTestCollection(suite.db, collectionName1, 128, suite.databaseId, nil)
232+
collectionID1, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName1, 128, suite.databaseId, nil))
232233
suite.NoError(err)
233-
collectionID2, err := CreateTestCollection(suite.db, collectionName2, 128, suite.databaseId, nil)
234+
collectionID2, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName2, 128, suite.databaseId, nil))
234235
suite.NoError(err)
235236

236237
// Soft delete collection 1 by Updating the is_deleted column
@@ -265,7 +266,7 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_SoftDelete() {
265266

266267
func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollectionSize() {
267268
collectionName := "test_collection_get_collection_size"
268-
collectionID, err := CreateTestCollection(suite.db, collectionName, 128, suite.databaseId, nil)
269+
collectionID, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, 128, suite.databaseId, nil))
269270
suite.NoError(err)
270271

271272
total_records_post_compaction, err := suite.collectionDb.GetCollectionSize(collectionID)
@@ -299,17 +300,17 @@ func (suite *CollectionDbTestSuite) TestCollectionDb_GetCollectionByResourceName
299300

300301
collectionName := "test_collection"
301302
dim := int32(128)
302-
collectionID, err := CreateTestCollection(suite.db, collectionName, dim, databaseID, nil)
303+
collectionID, err := CreateTestCollection(suite.db, daotest.NewDefaultTestCollection(collectionName, dim, databaseID, nil))
303304
suite.NoError(err)
304305

305-
collection, err := suite.collectionDb.GetCollectionByResourceName(tenantResourceName, databaseName, collectionName)
306+
collectionResult, err := suite.collectionDb.GetCollectionByResourceName(tenantResourceName, databaseName, collectionName)
306307
suite.NoError(err)
307-
suite.NotNil(collection)
308-
suite.Equal(collectionID, collection.Collection.ID)
309-
suite.Equal(collectionName, *collection.Collection.Name)
310-
suite.Equal(databaseID, collection.Collection.DatabaseID)
311-
suite.Equal(tenantID, collection.TenantID)
312-
suite.Equal(databaseName, collection.DatabaseName)
308+
suite.NotNil(collectionResult)
309+
suite.Equal(collectionID, collectionResult.Collection.ID)
310+
suite.Equal(collectionName, *collectionResult.Collection.Name)
311+
suite.Equal(databaseID, collectionResult.Collection.DatabaseID)
312+
suite.Equal(tenantID, collectionResult.TenantID)
313+
suite.Equal(databaseName, collectionResult.DatabaseName)
313314

314315
nonExistentCollection, err := suite.collectionDb.GetCollectionByResourceName(tenantResourceName, databaseName, "non_existent_collection")
315316
suite.Error(err, "collection not found")
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package daotest
2+
3+
import (
4+
"time"
5+
6+
"github.com/chroma-core/chroma/go/pkg/sysdb/metastore/db/dbmodel"
7+
"github.com/chroma-core/chroma/go/pkg/types"
8+
"github.com/pingcap/log"
9+
"go.uber.org/zap"
10+
)
11+
12+
// Default field values applied to test collections. These are variables rather
// than constants because NewTestCollection takes the address of some of them.
var (
	defaultConfigurationJsonStr              = `{"a": "param", "b": "param2", "3": true}`
	defaultDimension                  int32  = 128
	defaultTotalRecordsPostCompaction uint64 = 100
	defaultSizeBytesPostCompaction    uint64 = 500000
	defaultLastCompactionTimeSecs     uint64 = 1741037006
)

// Identifiers shared by most tests. They are exported because they are used so
// frequently and can be the same across most tests.
const (
	// TestDatabaseID is the default database ID for a test collection.
	TestDatabaseID = "test_database"

	// TestTenantID is the default tenant ID for a test collection.
	TestTenantID = "test_tenant"
)
28+
29+
// NewDefaultTestCollection is a "shim" for callsites that existed before dao.CreateTestCollection was refactored to
30+
// take a dbmodel.Collection instead of a subset of its fields as arguments. It should not be used for new tests.
31+
func NewDefaultTestCollection(collectionName string, dimension int32, databaseID string, lineageFileName *string) *dbmodel.Collection {
32+
log.Info("new default test collection", zap.String("collectionName", collectionName), zap.Int32("dimension", dimension), zap.String("databaseID", databaseID))
33+
return NewTestCollection(
34+
TestTenantID,
35+
databaseID,
36+
collectionName,
37+
WithDimension(dimension),
38+
WithTotalRecordsPostCompaction(defaultTotalRecordsPostCompaction),
39+
WithSizeBytesPostCompaction(defaultSizeBytesPostCompaction),
40+
WithLastCompactionTimeSecs(defaultLastCompactionTimeSecs),
41+
WithLineageFileName(lineageFileName),
42+
)
43+
}
44+
45+
// NewTestCollection creates a new test collection with the given name, database ID, and tenant ID.
46+
// Name, databaseID, and tenantID are required, other fields have defaults but can be overridden with
47+
// option function of a similar name.
48+
// Note: collection.CreatedAt is set to the current time, but collection.UpdatedAt is not set.
49+
func NewTestCollection(tenantID, databaseID, collectionName string, options ...TestCollectionOption) *dbmodel.Collection {
50+
log.Info("new test collection", zap.String("tenantID", tenantID), zap.String("databaseID", databaseID), zap.String("collectionName", collectionName))
51+
collectionId := types.NewUniqueID().String()
52+
53+
collection := &dbmodel.Collection{
54+
ID: collectionId,
55+
Name: &collectionName,
56+
ConfigurationJsonStr: &defaultConfigurationJsonStr,
57+
Dimension: &defaultDimension,
58+
DatabaseID: databaseID,
59+
CreatedAt: time.Now(),
60+
Tenant: tenantID,
61+
}
62+
63+
for _, option := range options {
64+
option(collection)
65+
}
66+
67+
return collection
68+
}
69+
70+
type TestCollectionOption func(*dbmodel.Collection)
71+
72+
func WithConfigurationJsonStr(configurationJsonStr string) func(*dbmodel.Collection) {
73+
return func(collection *dbmodel.Collection) {
74+
collection.ConfigurationJsonStr = &configurationJsonStr
75+
}
76+
}
77+
78+
func WithDimension(dimension int32) func(*dbmodel.Collection) {
79+
return func(collection *dbmodel.Collection) {
80+
collection.Dimension = &dimension
81+
}
82+
}
83+
84+
func WithTs(ts int64) func(*dbmodel.Collection) {
85+
return func(collection *dbmodel.Collection) {
86+
collection.Ts = ts
87+
}
88+
}
89+
90+
func WithIsDeleted(isDeleted bool) func(*dbmodel.Collection) {
91+
return func(collection *dbmodel.Collection) {
92+
collection.IsDeleted = isDeleted
93+
}
94+
}
95+
96+
func WithCreatedAt(createdAt time.Time) func(*dbmodel.Collection) {
97+
return func(collection *dbmodel.Collection) {
98+
collection.CreatedAt = createdAt
99+
}
100+
}
101+
102+
func WithUpdatedAt(updatedAt time.Time) func(*dbmodel.Collection) {
103+
return func(collection *dbmodel.Collection) {
104+
collection.UpdatedAt = updatedAt
105+
}
106+
}
107+
108+
func WithLogPosition(logPosition int64) func(*dbmodel.Collection) {
109+
return func(collection *dbmodel.Collection) {
110+
collection.LogPosition = logPosition
111+
}
112+
}
113+
114+
func WithVersion(version int32) func(*dbmodel.Collection) {
115+
return func(collection *dbmodel.Collection) {
116+
collection.Version = version
117+
}
118+
}
119+
120+
func WithVersionFileName(versionFileName string) func(*dbmodel.Collection) {
121+
return func(collection *dbmodel.Collection) {
122+
collection.VersionFileName = versionFileName
123+
}
124+
}
125+
126+
func WithRootCollectionID(rootCollectionId string) func(*dbmodel.Collection) {
127+
return func(collection *dbmodel.Collection) {
128+
collection.RootCollectionId = &rootCollectionId
129+
}
130+
}
131+
132+
func WithLineageFileName(lineageFileName *string) func(*dbmodel.Collection) {
133+
return func(collection *dbmodel.Collection) {
134+
collection.LineageFileName = lineageFileName
135+
}
136+
}
137+
138+
func WithTotalRecordsPostCompaction(totalRecordsPostCompaction uint64) func(*dbmodel.Collection) {
139+
return func(collection *dbmodel.Collection) {
140+
collection.TotalRecordsPostCompaction = totalRecordsPostCompaction
141+
}
142+
}
143+
144+
func WithSizeBytesPostCompaction(sizeBytesPostCompaction uint64) func(*dbmodel.Collection) {
145+
return func(collection *dbmodel.Collection) {
146+
collection.SizeBytesPostCompaction = sizeBytesPostCompaction
147+
}
148+
}
149+
150+
func WithLastCompactionTimeSecs(lastCompactionTimeSecs uint64) func(*dbmodel.Collection) {
151+
return func(collection *dbmodel.Collection) {
152+
collection.LastCompactionTimeSecs = lastCompactionTimeSecs
153+
}
154+
}
155+
156+
func WithNumVersions(numVersions uint32) func(*dbmodel.Collection) {
157+
return func(collection *dbmodel.Collection) {
158+
collection.NumVersions = numVersions
159+
}
160+
}
161+
162+
func WithOldestVersionTs(oldestVersionTs time.Time) func(*dbmodel.Collection) {
163+
return func(collection *dbmodel.Collection) {
164+
collection.OldestVersionTs = oldestVersionTs
165+
}
166+
}

0 commit comments

Comments
 (0)