Skip to content

Commit 26d3b12

Browse files
author
Divjot Arora
committed
Use testing library for sessions tests.
GODRIVER-1285 Change-Id: I00422dab5ebc5968c0eb32f335aceb08fa277502
1 parent 2e2efdf commit 26d3b12

File tree

5 files changed

+494
-707
lines changed

5 files changed

+494
-707
lines changed

.errcheck-excludes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
(go.mongodb.org/mongo-driver/x/network/wiremessage.ReadWriteCloser).Close
1010
(*go.mongodb.org/mongo-driver/mongo.Cursor).Close
1111
(*go.mongodb.org/mongo-driver/mongo.ChangeStream).Close
12+
(*go.mongodb.org/mongo-driver/mongo.Client).Disconnect
1213
(net.Conn).Close
1314
encoding/pem.Encode
1415
fmt.Fprintf

mongo/database.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ func (db *Database) Drop(ctx context.Context) error {
212212

213213
sess := sessionFromContext(ctx)
214214
if sess == nil && db.client.sessionPool != nil {
215-
sess, err := session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit)
215+
var err error
216+
sess, err = session.NewClientSession(db.client.sessionPool, db.client.id, session.Implicit)
216217
if err != nil {
217218
return err
218219
}

mongo/integration/sessions_test.go

Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package integration
8+
9+
import (
10+
"bytes"
11+
"reflect"
12+
"testing"
13+
"time"
14+
15+
"go.mongodb.org/mongo-driver/bson"
16+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
17+
"go.mongodb.org/mongo-driver/mongo"
18+
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
19+
"go.mongodb.org/mongo-driver/mongo/options"
20+
"go.mongodb.org/mongo-driver/mongo/readpref"
21+
"go.mongodb.org/mongo-driver/x/bsonx"
22+
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
23+
)
24+
25+
func TestSessionPool(t *testing.T) {
26+
mt := mtest.New(t, mtest.NewOptions().MinServerVersion("3.6").CreateClient(false))
27+
defer mt.Close()
28+
29+
mt.Run("pool LIFO", func(mt *mtest.T) {
30+
aSess, err := mt.Client.StartSession()
31+
assert.Nil(mt, err, "StartSession error: %v", err)
32+
bSess, err := mt.Client.StartSession()
33+
assert.Nil(mt, err, "StartSession error: %v", err)
34+
35+
// end the sessions to return them to the pool
36+
aSess.EndSession(mtest.Background)
37+
bSess.EndSession(mtest.Background)
38+
39+
firstSess, err := mt.Client.StartSession()
40+
assert.Nil(mt, err, "StartSession error: %v", err)
41+
defer firstSess.EndSession(mtest.Background)
42+
want := getSessionID(mt, bSess)
43+
got := getSessionID(mt, firstSess)
44+
assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)
45+
46+
secondSess, err := mt.Client.StartSession()
47+
assert.Nil(mt, err, "StartSession error: %v", err)
48+
defer secondSess.EndSession(mtest.Background)
49+
want = getSessionID(mt, aSess)
50+
got = getSessionID(mt, secondSess)
51+
assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)
52+
})
53+
}
54+
55+
func TestSessions(t *testing.T) {
56+
mtOpts := mtest.NewOptions().MinServerVersion("3.6").Topologies(mtest.ReplicaSet, mtest.Sharded).
57+
CreateClient(false)
58+
mt := mtest.New(t, mtOpts)
59+
60+
clusterTimeOpts := mtest.NewOptions().ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
61+
CreateClient(false)
62+
mt.RunOpts("cluster time", clusterTimeOpts, func(mt *mtest.T) {
63+
// $clusterTime included in commands
64+
65+
serverStatus := sessionFunction{"server status", "database", "RunCommand", []interface{}{bson.D{{"serverStatus", 1}}}}
66+
insert := sessionFunction{"insert one", "collection", "InsertOne", []interface{}{bson.D{{"x", 1}}}}
67+
agg := sessionFunction{"aggregate", "collection", "Aggregate", []interface{}{mongo.Pipeline{}}}
68+
find := sessionFunction{"find", "collection", "Find", []interface{}{bson.D{}}}
69+
70+
sessionFunctions := []sessionFunction{serverStatus, insert, agg, find}
71+
for _, sf := range sessionFunctions {
72+
mt.Run(sf.name, func(mt *mtest.T) {
73+
err := sf.execute(mt, nil)
74+
assert.Nil(mt, err, "%v error: %v", sf.name, err)
75+
76+
// assert $clusterTime was sent to server
77+
started := mt.GetStartedEvent()
78+
assert.NotNil(mt, started, "expected started event, got nil")
79+
_, err = started.Command.LookupErr("$clusterTime")
80+
assert.Nil(mt, err, "$clusterTime not sent")
81+
82+
// record response cluster time
83+
succeeded := mt.GetSucceededEvent()
84+
assert.NotNil(mt, succeeded, "expected succeeded event, got nil")
85+
replyClusterTimeVal, err := succeeded.Reply.LookupErr("$clusterTime")
86+
assert.Nil(mt, err, "$clusterTime not found in response")
87+
88+
// call function again
89+
err = sf.execute(mt, nil)
90+
assert.Nil(mt, err, "%v error: %v", sf.name, err)
91+
92+
// find cluster time sent to server and assert it is the same as the one in the previous response
93+
sentClusterTimeVal, err := mt.GetStartedEvent().Command.LookupErr("$clusterTime")
94+
assert.Nil(mt, err, "$clusterTime not sent")
95+
replyClusterTimeDoc := replyClusterTimeVal.Document()
96+
sentClusterTimeDoc := sentClusterTimeVal.Document()
97+
assert.Equal(mt, replyClusterTimeDoc, sentClusterTimeDoc,
98+
"expected cluster time %v, got %v", replyClusterTimeDoc, sentClusterTimeDoc)
99+
})
100+
}
101+
})
102+
mt.RunOpts("explicit implicit session arguments", noClientOpts, func(mt *mtest.T) {
103+
// lsid is included in commands with explicit and implicit sessions
104+
105+
sessionFunctions := createFunctionsSlice()
106+
for _, sf := range sessionFunctions {
107+
mt.Run(sf.name, func(mt *mtest.T) {
108+
// explicit session
109+
sess, err := mt.Client.StartSession()
110+
assert.Nil(mt, err, "StartSession error: %v", err)
111+
defer sess.EndSession(mtest.Background)
112+
mt.ClearEvents()
113+
114+
_ = sf.execute(mt, sess) // don't check error because we only care about lsid
115+
_, wantID := getSessionID(mt, sess).Lookup("id").Binary()
116+
gotID := extractSentSessionID(mt)
117+
assert.True(mt, bytes.Equal(wantID, gotID), "expected session ID %v, got %v", wantID, gotID)
118+
119+
// implicit session
120+
_ = sf.execute(mt, nil)
121+
gotID = extractSentSessionID(mt)
122+
assert.NotNil(mt, gotID, "expected lsid, got nil")
123+
})
124+
}
125+
})
126+
mt.Run("wrong client", func(mt *mtest.T) {
127+
// a session can only be used in commands associated with the client that created it
128+
129+
sessionFunctions := createFunctionsSlice()
130+
sess, err := mt.Client.StartSession()
131+
assert.Nil(mt, err, "StartSession error: %v", err)
132+
133+
for _, sf := range sessionFunctions {
134+
mt.Run(sf.name, func(mt *mtest.T) {
135+
err = sf.execute(mt, sess)
136+
assert.Equal(mt, mongo.ErrWrongClient, err, "expected error %v, got %v", mongo.ErrWrongClient, err)
137+
})
138+
}
139+
})
140+
mt.RunOpts("ended session", noClientOpts, func(mt *mtest.T) {
141+
// an ended session cannot be used in commands
142+
143+
sessionFunctions := createFunctionsSlice()
144+
for _, sf := range sessionFunctions {
145+
mt.Run(sf.name, func(mt *mtest.T) {
146+
sess, err := mt.Client.StartSession()
147+
assert.Nil(mt, err, "StartSession error: %v", err)
148+
sess.EndSession(mtest.Background)
149+
150+
err = sf.execute(mt, sess)
151+
assert.Equal(mt, session.ErrSessionEnded, err, "expected error %v, got %v", session.ErrSessionEnded, err)
152+
})
153+
}
154+
})
155+
mt.Run("implicit session returned", func(mt *mtest.T) {
156+
// implicit sessions are returned to the server session pool
157+
158+
doc := bson.D{{"x", 1}}
159+
_, err := mt.Coll.InsertOne(mtest.Background, doc)
160+
assert.Nil(mt, err, "InsertOne error: %v", err)
161+
_, err = mt.Coll.InsertOne(mtest.Background, doc)
162+
assert.Nil(mt, err, "InsertOne error: %v", err)
163+
164+
// create a cursor that will hold onto an implicit session and record the sent session ID
165+
mt.ClearEvents()
166+
cursor, err := mt.Coll.Find(mtest.Background, bson.D{})
167+
assert.Nil(mt, err, "Find error: %v", err)
168+
findID := extractSentSessionID(mt)
169+
assert.True(mt, cursor.Next(mtest.Background), "expected Next true, got false")
170+
171+
// execute another operation and verify the find session ID was reused
172+
_, err = mt.Coll.DeleteOne(mtest.Background, bson.D{})
173+
assert.Nil(mt, err, "DeleteOne error: %v", err)
174+
deleteID := extractSentSessionID(mt)
175+
assert.Equal(mt, findID, deleteID, "expected session ID %v, got %v", findID, deleteID)
176+
})
177+
mt.Run("implicit session returned from getMore", func(mt *mtest.T) {
178+
// Client-side cursor that exhausts the results after a getMore immediately returns the implicit session to the pool.
179+
180+
var docs []interface{}
181+
for i := 0; i < 5; i++ {
182+
docs = append(docs, bson.D{{"x", i}})
183+
}
184+
_, err := mt.Coll.InsertMany(mtest.Background, docs)
185+
assert.Nil(mt, err, "InsertMany error: %v", err)
186+
187+
// run a find that will hold onto the implicit session and record the session ID
188+
mt.ClearEvents()
189+
cursor, err := mt.Coll.Find(mtest.Background, bson.D{}, options.Find().SetBatchSize(3))
190+
assert.Nil(mt, err, "Find error: %v", err)
191+
findID := extractSentSessionID(mt)
192+
193+
// iterate past 4 documents, forcing a getMore. session should be returned to pool after getMore
194+
for i := 0; i < 4; i++ {
195+
assert.True(mt, cursor.Next(mtest.Background), "Next returned false on iteration %v", i)
196+
}
197+
198+
// execute another operation and verify the find session ID was reused
199+
_, err = mt.Coll.DeleteOne(mtest.Background, bson.D{})
200+
assert.Nil(mt, err, "DeleteOne error: %v", err)
201+
deleteID := extractSentSessionID(mt)
202+
assert.Equal(mt, findID, deleteID, "expected session ID %v, got %v", findID, deleteID)
203+
})
204+
mt.RunOpts("find and getMore use same ID", noClientOpts, func(mt *mtest.T) {
205+
testCases := []struct {
206+
name string
207+
rp *readpref.ReadPref
208+
topos []mtest.TopologyKind // if nil, all will be used
209+
}{
210+
{"primary", readpref.Primary(), nil},
211+
{"primaryPreferred", readpref.PrimaryPreferred(), nil},
212+
{"secondary", readpref.Secondary(), []mtest.TopologyKind{mtest.ReplicaSet}},
213+
{"secondaryPreferred", readpref.SecondaryPreferred(), nil},
214+
{"nearest", readpref.Nearest(), nil},
215+
}
216+
for _, tc := range testCases {
217+
clientOpts := options.Client().SetReadPreference(tc.rp).SetWriteConcern(mtest.MajorityWc)
218+
mt.RunOpts(tc.name, mtest.NewOptions().ClientOptions(clientOpts).Topologies(tc.topos...), func(mt *mtest.T) {
219+
var docs []interface{}
220+
for i := 0; i < 3; i++ {
221+
docs = append(docs, bson.D{{"x", i}})
222+
}
223+
_, err := mt.Coll.InsertMany(mtest.Background, docs)
224+
assert.Nil(mt, err, "InsertMany error: %v", err)
225+
226+
// run a find that will hold onto an implicit session and record the session ID
227+
mt.ClearEvents()
228+
cursor, err := mt.Coll.Find(mtest.Background, bson.D{}, options.Find().SetBatchSize(2))
229+
assert.Nil(mt, err, "Find error: %v", err)
230+
findID := extractSentSessionID(mt)
231+
assert.NotNil(mt, findID, "expected session ID for find, got nil")
232+
233+
// iterate over all documents and record the session ID of the getMore
234+
for i := 0; i < 3; i++ {
235+
assert.True(mt, cursor.Next(mtest.Background), "Next returned false on iteration %v", i)
236+
}
237+
getMoreID := extractSentSessionID(mt)
238+
assert.Equal(mt, findID, getMoreID, "expected session ID %v, got %v", findID, getMoreID)
239+
})
240+
}
241+
})
242+
}
243+
244+
type sessionFunction struct {
245+
name string
246+
target string
247+
fnName string
248+
params []interface{} // should not include context
249+
}
250+
251+
func (sf sessionFunction) execute(mt *mtest.T, sess mongo.Session) error {
252+
var target reflect.Value
253+
switch sf.target {
254+
case "client":
255+
target = reflect.ValueOf(mt.Client)
256+
case "database":
257+
// use a different database for drops because any executed after the drop will get "database not found"
258+
// errors on sharded clusters
259+
if sf.name != "drop database" {
260+
target = reflect.ValueOf(mt.DB)
261+
break
262+
}
263+
target = reflect.ValueOf(mt.Client.Database("sessionsTestsDropDatabase"))
264+
case "collection":
265+
target = reflect.ValueOf(mt.Coll)
266+
case "indexView":
267+
target = reflect.ValueOf(mt.Coll.Indexes())
268+
default:
269+
mt.Fatalf("unrecognized target: %v", sf.target)
270+
}
271+
272+
fn := target.MethodByName(sf.fnName)
273+
paramsValues := interfaceSliceToValueSlice(sf.params)
274+
275+
if sess != nil {
276+
return mongo.WithSession(mtest.Background, sess, func(sc mongo.SessionContext) error {
277+
valueArgs := []reflect.Value{reflect.ValueOf(sc)}
278+
valueArgs = append(valueArgs, paramsValues...)
279+
returnValues := fn.Call(valueArgs)
280+
return extractReturnError(returnValues)
281+
})
282+
}
283+
valueArgs := []reflect.Value{reflect.ValueOf(mtest.Background)}
284+
valueArgs = append(valueArgs, paramsValues...)
285+
returnValues := fn.Call(valueArgs)
286+
return extractReturnError(returnValues)
287+
}
288+
289+
func createFunctionsSlice() []sessionFunction {
290+
insertManyDocs := []interface{}{bson.D{{"x", 1}}}
291+
fooIndex := mongo.IndexModel{
292+
Keys: bson.D{{"foo", -1}},
293+
Options: options.Index().SetName("fooIndex"),
294+
}
295+
manyIndexes := []mongo.IndexModel{fooIndex}
296+
updateDoc := bson.D{{"$inc", bson.D{{"x", 1}}}}
297+
298+
return []sessionFunction{
299+
{"list databases", "client", "ListDatabases", []interface{}{bson.D{}}},
300+
{"insert one", "collection", "InsertOne", []interface{}{bson.D{{"x", 1}}}},
301+
{"insert many", "collection", "InsertMany", []interface{}{insertManyDocs}},
302+
{"delete one", "collection", "DeleteOne", []interface{}{bson.D{}}},
303+
{"delete many", "collection", "DeleteMany", []interface{}{bson.D{}}},
304+
{"update one", "collection", "UpdateOne", []interface{}{bson.D{}, updateDoc}},
305+
{"update many", "collection", "UpdateMany", []interface{}{bson.D{}, updateDoc}},
306+
{"replace one", "collection", "ReplaceOne", []interface{}{bson.D{}, bson.D{}}},
307+
{"aggregate", "collection", "Aggregate", []interface{}{mongo.Pipeline{}}},
308+
{"estimated document count", "collection", "EstimatedDocumentCount", nil},
309+
{"distinct", "collection", "Distinct", []interface{}{"field", bson.D{}}},
310+
{"find", "collection", "Find", []interface{}{bson.D{}}},
311+
{"find one and delete", "collection", "FindOneAndDelete", []interface{}{bson.D{}}},
312+
{"find one and replace", "collection", "FindOneAndReplace", []interface{}{bson.D{}, bson.D{}}},
313+
{"find one and update", "collection", "FindOneAndUpdate", []interface{}{bson.D{}, updateDoc}},
314+
{"drop collection", "collection", "Drop", nil},
315+
{"list collections", "database", "ListCollections", []interface{}{bson.D{}}},
316+
{"drop database", "database", "Drop", nil},
317+
{"create one index", "indexView", "CreateOne", []interface{}{fooIndex}},
318+
{"create many indexes", "indexView", "CreateMany", []interface{}{manyIndexes}},
319+
{"drop one index", "indexView", "DropOne", []interface{}{"barIndex"}},
320+
{"drop all indexes", "indexView", "DropAll", nil},
321+
{"list indexes", "indexView", "List", nil},
322+
}
323+
}
324+
325+
func sessionIDsEqual(mt *mtest.T, id1, id2 bsonx.Doc) bool {
326+
first, err := id1.LookupErr("id")
327+
assert.Nil(mt, err, "id not found in document %v", id1)
328+
second, err := id2.LookupErr("id")
329+
assert.Nil(mt, err, "id not found in document %v", id2)
330+
331+
_, firstUUID := first.Binary()
332+
_, secondUUID := second.Binary()
333+
return bytes.Equal(firstUUID, secondUUID)
334+
}
335+
336+
func interfaceSliceToValueSlice(args []interface{}) []reflect.Value {
337+
vals := make([]reflect.Value, 0, len(args))
338+
for _, arg := range args {
339+
vals = append(vals, reflect.ValueOf(arg))
340+
}
341+
return vals
342+
}
343+
344+
func extractReturnError(returnValues []reflect.Value) error {
345+
errVal := returnValues[len(returnValues)-1]
346+
switch converted := errVal.Interface().(type) {
347+
case error:
348+
return converted
349+
case *mongo.SingleResult:
350+
return converted.Err()
351+
default:
352+
return nil
353+
}
354+
}
355+
356+
func extractSentSessionID(mt *mtest.T) []byte {
357+
lsid, err := mt.GetStartedEvent().Command.LookupErr("lsid")
358+
if err != nil {
359+
return nil
360+
}
361+
362+
_, data := lsid.Document().Lookup("id").Binary()
363+
return data
364+
}

0 commit comments

Comments
 (0)