@@ -23,11 +23,17 @@ func TestBasicSessionStore(t *testing.T) {
23
23
_ = db .Close ()
24
24
})
25
25
26
- // Create a few sessions.
27
- s1 := newSession (t , db , clock , "session 1" , nil )
28
- s2 := newSession (t , db , clock , "session 2" , nil )
29
- s3 := newSession (t , db , clock , "session 3" , nil )
30
- s4 := newSession (t , db , clock , "session 4" , nil )
26
+ // Create a few sessions. We increment the time by one second between
27
+ // each session to ensure that the created at time is unique and hence
28
+ // that the ListSessions method returns the sessions in a deterministic
29
+ // order.
30
+ s1 := newSession (t , db , clock , "session 1" )
31
+ clock .SetTime (testTime .Add (time .Second ))
32
+ s2 := newSession (t , db , clock , "session 2" )
33
+ clock .SetTime (testTime .Add (2 * time .Second ))
34
+ s3 := newSession (t , db , clock , "session 3" , withType (TypeAutopilot ))
35
+ clock .SetTime (testTime .Add (3 * time .Second ))
36
+ s4 := newSession (t , db , clock , "session 4" )
31
37
32
38
// Persist session 1. This should now succeed.
33
39
require .NoError (t , db .CreateSession (s1 ))
@@ -50,6 +56,22 @@ func TestBasicSessionStore(t *testing.T) {
50
56
require .NoError (t , db .CreateSession (s2 ))
51
57
require .NoError (t , db .CreateSession (s3 ))
52
58
59
+ // Test the ListSessionsByType method.
60
+ sessions , err := db .ListSessionsByType (TypeMacaroonAdmin )
61
+ require .NoError (t , err )
62
+ require .Equal (t , 2 , len (sessions ))
63
+ assertEqualSessions (t , s1 , sessions [0 ])
64
+ assertEqualSessions (t , s2 , sessions [1 ])
65
+
66
+ sessions , err = db .ListSessionsByType (TypeAutopilot )
67
+ require .NoError (t , err )
68
+ require .Equal (t , 1 , len (sessions ))
69
+ assertEqualSessions (t , s3 , sessions [0 ])
70
+
71
+ sessions , err = db .ListSessionsByType (TypeMacaroonReadonly )
72
+ require .NoError (t , err )
73
+ require .Empty (t , sessions )
74
+
53
75
// Ensure that we can retrieve each session by both its local pub key
54
76
// and by its ID.
55
77
for _ , s := range []* Session {s1 , s2 , s3 } {
@@ -85,9 +107,44 @@ func TestBasicSessionStore(t *testing.T) {
85
107
86
108
// Now revoke the session and assert that the state is revoked.
87
109
require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
88
- session1 , err = db .GetSession (s1 .LocalPublicKey )
110
+ s1 , err = db .GetSession (s1 .LocalPublicKey )
111
+ require .NoError (t , err )
112
+ require .Equal (t , s1 .State , StateRevoked )
113
+
114
+ // Test that ListAllSessions works.
115
+ sessions , err = db .ListAllSessions ()
116
+ require .NoError (t , err )
117
+ require .Equal (t , 3 , len (sessions ))
118
+ assertEqualSessions (t , s1 , sessions [0 ])
119
+ assertEqualSessions (t , s2 , sessions [1 ])
120
+ assertEqualSessions (t , s3 , sessions [2 ])
121
+
122
+ // Test that ListSessionsByState works.
123
+ sessions , err = db .ListSessionsByState (StateRevoked )
124
+ require .NoError (t , err )
125
+ require .Equal (t , 1 , len (sessions ))
126
+ assertEqualSessions (t , s1 , sessions [0 ])
127
+
128
+ sessions , err = db .ListSessionsByState (StateCreated )
129
+ require .NoError (t , err )
130
+ require .Equal (t , 2 , len (sessions ))
131
+ assertEqualSessions (t , s2 , sessions [0 ])
132
+ assertEqualSessions (t , s3 , sessions [1 ])
133
+
134
+ sessions , err = db .ListSessionsByState (StateCreated , StateRevoked )
135
+ require .NoError (t , err )
136
+ require .Equal (t , 3 , len (sessions ))
137
+ assertEqualSessions (t , s1 , sessions [0 ])
138
+ assertEqualSessions (t , s2 , sessions [1 ])
139
+ assertEqualSessions (t , s3 , sessions [2 ])
140
+
141
+ sessions , err = db .ListSessionsByState ()
89
142
require .NoError (t , err )
90
- require .Equal (t , session1 .State , StateRevoked )
143
+ require .Empty (t , sessions )
144
+
145
+ sessions , err = db .ListSessionsByState (StateInUse )
146
+ require .NoError (t , err )
147
+ require .Empty (t , sessions )
91
148
}
92
149
93
150
// TestLinkingSessions tests that session linking works as expected.
@@ -101,10 +158,10 @@ func TestLinkingSessions(t *testing.T) {
101
158
})
102
159
103
160
// Create a new session with no previous link.
104
- s1 := newSession (t , db , clock , "session 1" , nil )
161
+ s1 := newSession (t , db , clock , "session 1" )
105
162
106
163
// Create another session and link it to the first.
107
- s2 := newSession (t , db , clock , "session 2" , & s1 .GroupID )
164
+ s2 := newSession (t , db , clock , "session 2" , withLinkedGroupID ( & s1 .GroupID ) )
108
165
109
166
// Try to persist the second session and assert that it fails due to the
110
167
// linked session not existing in the DB yet.
@@ -141,9 +198,9 @@ func TestLinkedSessions(t *testing.T) {
141
198
// after are all linked to the prior one. All these sessions belong to
142
199
// the same group. The group ID is equivalent to the session ID of the
143
200
// first session.
144
- s1 := newSession (t , db , clock , "session 1" , nil )
145
- s2 := newSession (t , db , clock , "session 2" , & s1 .GroupID )
146
- s3 := newSession (t , db , clock , "session 3" , & s2 .GroupID )
201
+ s1 := newSession (t , db , clock , "session 1" )
202
+ s2 := newSession (t , db , clock , "session 2" , withLinkedGroupID ( & s1 .GroupID ) )
203
+ s3 := newSession (t , db , clock , "session 3" , withLinkedGroupID ( & s2 .GroupID ) )
147
204
148
205
// Persist the sessions.
149
206
require .NoError (t , db .CreateSession (s1 ))
@@ -169,8 +226,8 @@ func TestLinkedSessions(t *testing.T) {
169
226
170
227
// To ensure that different groups don't interfere with each other,
171
228
// let's add another set of linked sessions not linked to the first.
172
- s4 := newSession (t , db , clock , "session 4" , nil )
173
- s5 := newSession (t , db , clock , "session 5" , & s4 .GroupID )
229
+ s4 := newSession (t , db , clock , "session 4" )
230
+ s5 := newSession (t , db , clock , "session 5" , withLinkedGroupID ( & s4 .GroupID ) )
174
231
175
232
require .NotEqual (t , s4 .GroupID , s1 .GroupID )
176
233
@@ -209,7 +266,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
209
266
// function is checked correctly.
210
267
211
268
// Add a new session to the DB.
212
- s1 := newSession (t , db , clock , "label 1" , nil )
269
+ s1 := newSession (t , db , clock , "label 1" )
213
270
require .NoError (t , db .CreateSession (s1 ))
214
271
215
272
// Check that the group passes against an appropriate predicate.
@@ -234,7 +291,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
234
291
require .NoError (t , db .RevokeSession (s1 .LocalPublicKey ))
235
292
236
293
// Add a new session to the same group as the first one.
237
- s2 := newSession (t , db , clock , "label 2" , & s1 .GroupID )
294
+ s2 := newSession (t , db , clock , "label 2" , withLinkedGroupID ( & s1 .GroupID ) )
238
295
require .NoError (t , db .CreateSession (s2 ))
239
296
240
297
// Check that the group passes against an appropriate predicate.
@@ -256,7 +313,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
256
313
require .False (t , ok )
257
314
258
315
// Add a new session that is not linked to the first one.
259
- s3 := newSession (t , db , clock , "completely different" , nil )
316
+ s3 := newSession (t , db , clock , "completely different" )
260
317
require .NoError (t , db .CreateSession (s3 ))
261
318
262
319
// Ensure that the first group is unaffected.
@@ -286,8 +343,24 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
286
343
require .True (t , ok )
287
344
}
288
345
346
+ // testSessionModifier is a functional option that can be used to modify the
347
+ // default test session created by newSession.
348
+ type testSessionModifier func (* Session )
349
+
350
+ func withLinkedGroupID (groupID * ID ) testSessionModifier {
351
+ return func (s * Session ) {
352
+ s .GroupID = * groupID
353
+ }
354
+ }
355
+
356
+ func withType (t Type ) testSessionModifier {
357
+ return func (s * Session ) {
358
+ s .Type = t
359
+ }
360
+ }
361
+
289
362
func newSession (t * testing.T , db Store , clock clock.Clock , label string ,
290
- linkedGroupID * ID ) * Session {
363
+ mods ... testSessionModifier ) * Session {
291
364
292
365
id , priv , err := db .GetUnusedIDAndKeyPair ()
293
366
require .NoError (t , err )
@@ -296,11 +369,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string,
296
369
id , priv , label , TypeMacaroonAdmin ,
297
370
clock .Now (),
298
371
time .Date (99999 , 1 , 1 , 0 , 0 , 0 , 0 , time .UTC ),
299
- "foo.bar.baz:1234" , true , nil , nil , nil , true , linkedGroupID ,
372
+ "foo.bar.baz:1234" , true , nil , nil , nil , true , nil ,
300
373
[]PrivacyFlag {ClearPubkeys },
301
374
)
302
375
require .NoError (t , err )
303
376
377
+ for _ , mod := range mods {
378
+ mod (session )
379
+ }
380
+
304
381
return session
305
382
}
306
383
0 commit comments