Skip to content

Commit c749126

Browse files
authored
Merge pull request #970 from ellemouton/sql12Sessions4
[sql-12] sessions: make ListSession methods SQL ready
2 parents d56ba68 + 14cb0be commit c749126

File tree

4 files changed

+156
-46
lines changed

4 files changed

+156
-46
lines changed

session/interface.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,15 @@ type Store interface {
161161
// GetSession fetches the session with the given key.
162162
GetSession(key *btcec.PublicKey) (*Session, error)
163163

164-
// ListSessions returns all sessions currently known to the store.
165-
ListSessions(filterFn func(s *Session) bool) ([]*Session, error)
164+
// ListAllSessions returns all sessions currently known to the store.
165+
ListAllSessions() ([]*Session, error)
166+
167+
// ListSessionsByType returns all sessions of the given type.
168+
ListSessionsByType(t Type) ([]*Session, error)
169+
170+
// ListSessionsByState returns all sessions currently known to the store
171+
// that are in the given states.
172+
ListSessionsByState(...State) ([]*Session, error)
166173

167174
// RevokeSession updates the state of the session with the given local
168175
// public key to be revoked.

session/kvdb_store.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"os"
99
"path/filepath"
10+
"sort"
1011
"time"
1112

1213
"github.com/btcsuite/btcd/btcec/v2"
@@ -363,10 +364,46 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
363364
return session, nil
364365
}
365366

366-
// ListSessions returns all sessions currently known to the store.
367+
// ListAllSessions returns all sessions currently known to the store.
367368
//
368369
// NOTE: this is part of the Store interface.
369-
func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
370+
func (db *BoltStore) ListAllSessions() ([]*Session, error) {
371+
return db.listSessions(func(s *Session) bool {
372+
return true
373+
})
374+
}
375+
376+
// ListSessionsByType returns all sessions currently known to the store that
377+
// have the given type.
378+
//
379+
// NOTE: this is part of the Store interface.
380+
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
381+
return db.listSessions(func(s *Session) bool {
382+
return s.Type == t
383+
})
384+
}
385+
386+
// ListSessionsByState returns all sessions currently known to the store that
387+
// are in the given states.
388+
//
389+
// NOTE: this is part of the Store interface.
390+
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
391+
return db.listSessions(func(s *Session) bool {
392+
for _, state := range states {
393+
if s.State == state {
394+
return true
395+
}
396+
}
397+
398+
return false
399+
})
400+
}
401+
402+
// listSessions returns all sessions currently known to the store that pass the
403+
// given filter function.
404+
func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,
405+
error) {
406+
370407
var sessions []*Session
371408
err := db.View(func(tx *bbolt.Tx) error {
372409
sessionBucket, err := getBucket(tx, sessionBucketKey)
@@ -399,6 +436,11 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e
399436
return nil, err
400437
}
401438

439+
// Make sure to sort the sessions by creation time.
440+
sort.Slice(sessions, func(i, j int) bool {
441+
return sessions[i].CreatedAt.Before(sessions[j].CreatedAt)
442+
})
443+
402444
return sessions, nil
403445
}
404446

session/store_test.go

Lines changed: 96 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,17 @@ func TestBasicSessionStore(t *testing.T) {
2323
_ = db.Close()
2424
})
2525

26-
// Create a few sessions.
27-
s1 := newSession(t, db, clock, "session 1", nil)
28-
s2 := newSession(t, db, clock, "session 2", nil)
29-
s3 := newSession(t, db, clock, "session 3", nil)
30-
s4 := newSession(t, db, clock, "session 4", nil)
26+
// Create a few sessions. We increment the time by one second between
27+
// each session to ensure that the created at time is unique and hence
28+
// that the ListSessions method returns the sessions in a deterministic
29+
// order.
30+
s1 := newSession(t, db, clock, "session 1")
31+
clock.SetTime(testTime.Add(time.Second))
32+
s2 := newSession(t, db, clock, "session 2")
33+
clock.SetTime(testTime.Add(2 * time.Second))
34+
s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot))
35+
clock.SetTime(testTime.Add(3 * time.Second))
36+
s4 := newSession(t, db, clock, "session 4")
3137

3238
// Persist session 1. This should now succeed.
3339
require.NoError(t, db.CreateSession(s1))
@@ -50,6 +56,22 @@ func TestBasicSessionStore(t *testing.T) {
5056
require.NoError(t, db.CreateSession(s2))
5157
require.NoError(t, db.CreateSession(s3))
5258

59+
// Test the ListSessionsByType method.
60+
sessions, err := db.ListSessionsByType(TypeMacaroonAdmin)
61+
require.NoError(t, err)
62+
require.Equal(t, 2, len(sessions))
63+
assertEqualSessions(t, s1, sessions[0])
64+
assertEqualSessions(t, s2, sessions[1])
65+
66+
sessions, err = db.ListSessionsByType(TypeAutopilot)
67+
require.NoError(t, err)
68+
require.Equal(t, 1, len(sessions))
69+
assertEqualSessions(t, s3, sessions[0])
70+
71+
sessions, err = db.ListSessionsByType(TypeMacaroonReadonly)
72+
require.NoError(t, err)
73+
require.Empty(t, sessions)
74+
5375
// Ensure that we can retrieve each session by both its local pub key
5476
// and by its ID.
5577
for _, s := range []*Session{s1, s2, s3} {
@@ -85,9 +107,44 @@ func TestBasicSessionStore(t *testing.T) {
85107

86108
// Now revoke the session and assert that the state is revoked.
87109
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
88-
session1, err = db.GetSession(s1.LocalPublicKey)
110+
s1, err = db.GetSession(s1.LocalPublicKey)
111+
require.NoError(t, err)
112+
require.Equal(t, s1.State, StateRevoked)
113+
114+
// Test that ListAllSessions works.
115+
sessions, err = db.ListAllSessions()
116+
require.NoError(t, err)
117+
require.Equal(t, 3, len(sessions))
118+
assertEqualSessions(t, s1, sessions[0])
119+
assertEqualSessions(t, s2, sessions[1])
120+
assertEqualSessions(t, s3, sessions[2])
121+
122+
// Test that ListSessionsByState works.
123+
sessions, err = db.ListSessionsByState(StateRevoked)
124+
require.NoError(t, err)
125+
require.Equal(t, 1, len(sessions))
126+
assertEqualSessions(t, s1, sessions[0])
127+
128+
sessions, err = db.ListSessionsByState(StateCreated)
129+
require.NoError(t, err)
130+
require.Equal(t, 2, len(sessions))
131+
assertEqualSessions(t, s2, sessions[0])
132+
assertEqualSessions(t, s3, sessions[1])
133+
134+
sessions, err = db.ListSessionsByState(StateCreated, StateRevoked)
135+
require.NoError(t, err)
136+
require.Equal(t, 3, len(sessions))
137+
assertEqualSessions(t, s1, sessions[0])
138+
assertEqualSessions(t, s2, sessions[1])
139+
assertEqualSessions(t, s3, sessions[2])
140+
141+
sessions, err = db.ListSessionsByState()
89142
require.NoError(t, err)
90-
require.Equal(t, session1.State, StateRevoked)
143+
require.Empty(t, sessions)
144+
145+
sessions, err = db.ListSessionsByState(StateInUse)
146+
require.NoError(t, err)
147+
require.Empty(t, sessions)
91148
}
92149

93150
// TestLinkingSessions tests that session linking works as expected.
@@ -101,10 +158,10 @@ func TestLinkingSessions(t *testing.T) {
101158
})
102159

103160
// Create a new session with no previous link.
104-
s1 := newSession(t, db, clock, "session 1", nil)
161+
s1 := newSession(t, db, clock, "session 1")
105162

106163
// Create another session and link it to the first.
107-
s2 := newSession(t, db, clock, "session 2", &s1.GroupID)
164+
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))
108165

109166
// Try to persist the second session and assert that it fails due to the
110167
// linked session not existing in the DB yet.
@@ -141,9 +198,9 @@ func TestLinkedSessions(t *testing.T) {
141198
// after are all linked to the prior one. All these sessions belong to
142199
// the same group. The group ID is equivalent to the session ID of the
143200
// first session.
144-
s1 := newSession(t, db, clock, "session 1", nil)
145-
s2 := newSession(t, db, clock, "session 2", &s1.GroupID)
146-
s3 := newSession(t, db, clock, "session 3", &s2.GroupID)
201+
s1 := newSession(t, db, clock, "session 1")
202+
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))
203+
s3 := newSession(t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID))
147204

148205
// Persist the sessions.
149206
require.NoError(t, db.CreateSession(s1))
@@ -169,8 +226,8 @@ func TestLinkedSessions(t *testing.T) {
169226

170227
// To ensure that different groups don't interfere with each other,
171228
// let's add another set of linked sessions not linked to the first.
172-
s4 := newSession(t, db, clock, "session 4", nil)
173-
s5 := newSession(t, db, clock, "session 5", &s4.GroupID)
229+
s4 := newSession(t, db, clock, "session 4")
230+
s5 := newSession(t, db, clock, "session 5", withLinkedGroupID(&s4.GroupID))
174231

175232
require.NotEqual(t, s4.GroupID, s1.GroupID)
176233

@@ -209,7 +266,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
209266
// function is checked correctly.
210267

211268
// Add a new session to the DB.
212-
s1 := newSession(t, db, clock, "label 1", nil)
269+
s1 := newSession(t, db, clock, "label 1")
213270
require.NoError(t, db.CreateSession(s1))
214271

215272
// Check that the group passes against an appropriate predicate.
@@ -234,7 +291,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
234291
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
235292

236293
// Add a new session to the same group as the first one.
237-
s2 := newSession(t, db, clock, "label 2", &s1.GroupID)
294+
s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID))
238295
require.NoError(t, db.CreateSession(s2))
239296

240297
// Check that the group passes against an appropriate predicate.
@@ -256,7 +313,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
256313
require.False(t, ok)
257314

258315
// Add a new session that is not linked to the first one.
259-
s3 := newSession(t, db, clock, "completely different", nil)
316+
s3 := newSession(t, db, clock, "completely different")
260317
require.NoError(t, db.CreateSession(s3))
261318

262319
// Ensure that the first group is unaffected.
@@ -286,8 +343,24 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
286343
require.True(t, ok)
287344
}
288345

346+
// testSessionModifier is a functional option that can be used to modify the
347+
// default test session created by newSession.
348+
type testSessionModifier func(*Session)
349+
350+
func withLinkedGroupID(groupID *ID) testSessionModifier {
351+
return func(s *Session) {
352+
s.GroupID = *groupID
353+
}
354+
}
355+
356+
func withType(t Type) testSessionModifier {
357+
return func(s *Session) {
358+
s.Type = t
359+
}
360+
}
361+
289362
func newSession(t *testing.T, db Store, clock clock.Clock, label string,
290-
linkedGroupID *ID) *Session {
363+
mods ...testSessionModifier) *Session {
291364

292365
id, priv, err := db.GetUnusedIDAndKeyPair()
293366
require.NoError(t, err)
@@ -296,11 +369,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string,
296369
id, priv, label, TypeMacaroonAdmin,
297370
clock.Now(),
298371
time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC),
299-
"foo.bar.baz:1234", true, nil, nil, nil, true, linkedGroupID,
372+
"foo.bar.baz:1234", true, nil, nil, nil, true, nil,
300373
[]PrivacyFlag{ClearPubkeys},
301374
)
302375
require.NoError(t, err)
303376

377+
for _, mod := range mods {
378+
mod(session)
379+
}
380+
304381
return session
305382
}
306383

session_rpcserver.go

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
101101
// requests. This includes resuming all non-revoked sessions.
102102
func (s *sessionRpcServer) start(ctx context.Context) error {
103103
// Start up all previously created sessions.
104-
sessions, err := s.cfg.db.ListSessions(nil)
104+
sessions, err := s.cfg.db.ListSessionsByState(
105+
session.StateCreated,
106+
session.StateInUse,
107+
)
105108
if err != nil {
106109
return fmt.Errorf("error listing sessions: %v", err)
107110
}
@@ -126,12 +129,6 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
126129
continue
127130
}
128131

129-
if sess.State != session.StateInUse &&
130-
sess.State != session.StateCreated {
131-
132-
continue
133-
}
134-
135132
if sess.Expiry.Before(time.Now()) {
136133
continue
137134
}
@@ -345,24 +342,13 @@ func (s *sessionRpcServer) AddSession(ctx context.Context,
345342
}, nil
346343
}
347344

348-
// resumeSession tries to start an existing session if it is not expired, not
349-
// revoked and a LiT session.
345+
// resumeSession tries to start the given session if it is not expired.
350346
func (s *sessionRpcServer) resumeSession(ctx context.Context,
351347
sess *session.Session) error {
352348

353349
pubKey := sess.LocalPublicKey
354350
pubKeyBytes := pubKey.SerializeCompressed()
355351

356-
// We only start non-revoked, non-expired LiT sessions. Everything else
357-
// we just skip.
358-
if sess.State != session.StateInUse &&
359-
sess.State != session.StateCreated {
360-
361-
log.Debugf("Not resuming session %x with state %d", pubKeyBytes,
362-
sess.State)
363-
return nil
364-
}
365-
366352
// Don't resume an expired session.
367353
if sess.Expiry.Before(time.Now()) {
368354
log.Debugf("Not resuming session %x with expiry %s",
@@ -536,7 +522,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
536522
func (s *sessionRpcServer) ListSessions(_ context.Context,
537523
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {
538524

539-
sessions, err := s.cfg.db.ListSessions(nil)
525+
sessions, err := s.cfg.db.ListAllSessions()
540526
if err != nil {
541527
return nil, fmt.Errorf("error fetching sessions: %v", err)
542528
}
@@ -1259,9 +1245,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context,
12591245
_ *litrpc.ListAutopilotSessionsRequest) (
12601246
*litrpc.ListAutopilotSessionsResponse, error) {
12611247

1262-
sessions, err := s.cfg.db.ListSessions(func(s *session.Session) bool {
1263-
return s.Type == session.TypeAutopilot
1264-
})
1248+
sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot)
12651249
if err != nil {
12661250
return nil, fmt.Errorf("error fetching sessions: %v", err)
12671251
}

0 commit comments

Comments
 (0)