Skip to content

Commit 980a36d

Browse files
committed
session: remove the filter fn in ListSessions
And instead let the caller pass in a list of States they are interested in. This will make SQL queries much more efficient since we can index by state.
1 parent 9590aeb commit 980a36d

File tree

4 files changed

+55
-18
lines changed

4 files changed

+55
-18
lines changed

session/interface.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,10 @@ type Store interface {
165165
// GetSession fetches the session with the given key.
166166
GetSession(key *btcec.PublicKey) (*Session, error)
167167

168-
// ListSessions returns all sessions currently known to the store.
169-
ListSessions(filterFn func(s *Session) bool) ([]*Session, error)
168+
// ListSessions returns all sessions currently known to the store that
169+
// are in the given states. If no states are provided, all sessions are
170+
// returned.
171+
ListSessions(states ...State) ([]*Session, error)
170172

171173
// ListSessionsByType returns all sessions of the given type.
172174
ListSessionsByType(t Type) ([]*Session, error)

session/store.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,24 @@ func (db *DB) GetSession(key *btcec.PublicKey) (*Session, error) {
245245
return session, nil
246246
}
247247

248-
// ListSessions returns all sessions currently known to the store.
248+
// ListSessions returns all sessions currently known to the store that are in
249+
// the given states. If no states are provided, all sessions are returned.
249250
//
250251
// NOTE: this is part of the Store interface.
251-
func (db *DB) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
252-
return db.listSessions(filterFn)
252+
func (db *DB) ListSessions(states ...State) ([]*Session, error) {
253+
return db.listSessions(func(s *Session) bool {
254+
if len(states) == 0 {
255+
return true
256+
}
257+
258+
for _, state := range states {
259+
if s.State == state {
260+
return true
261+
}
262+
}
263+
264+
return false
265+
})
253266
}
254267

255268
// ListSessionsByType returns all sessions currently known to the store that

session/store_test.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,8 @@ func TestBasicSessionStore(t *testing.T) {
5656
require.NoError(t, db.CreateSession(s2))
5757
require.NoError(t, db.CreateSession(s3))
5858

59-
// Check that all sessions are returned in ListSessions.
60-
sessions, err := db.ListSessions(nil)
61-
require.NoError(t, err)
62-
require.Equal(t, 3, len(sessions))
63-
assertEqualSessions(t, s1, sessions[0])
64-
assertEqualSessions(t, s2, sessions[1])
65-
assertEqualSessions(t, s3, sessions[2])
66-
6759
// Test the ListSessionsByType method.
68-
sessions, err = db.ListSessionsByType(TypeMacaroonAdmin)
60+
sessions, err := db.ListSessionsByType(TypeMacaroonAdmin)
6961
require.NoError(t, err)
7062
require.Equal(t, 2, len(sessions))
7163
assertEqualSessions(t, s1, sessions[0])
@@ -115,9 +107,39 @@ func TestBasicSessionStore(t *testing.T) {
115107

116108
// Now revoke the session and assert that the state is revoked.
117109
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
118-
session1, err = db.GetSession(s1.LocalPublicKey)
110+
s1, err = db.GetSession(s1.LocalPublicKey)
119111
require.NoError(t, err)
120-
require.Equal(t, session1.State, StateRevoked)
112+
require.Equal(t, s1.State, StateRevoked)
113+
114+
// Test that ListSessions by certain states works.
115+
sessions, err = db.ListSessions(StateRevoked)
116+
require.NoError(t, err)
117+
require.Equal(t, 1, len(sessions))
118+
assertEqualSessions(t, s1, sessions[0])
119+
120+
sessions, err = db.ListSessions(StateCreated)
121+
require.NoError(t, err)
122+
require.Equal(t, 2, len(sessions))
123+
assertEqualSessions(t, s2, sessions[0])
124+
assertEqualSessions(t, s3, sessions[1])
125+
126+
sessions, err = db.ListSessions(StateCreated, StateRevoked)
127+
require.NoError(t, err)
128+
require.Equal(t, 3, len(sessions))
129+
assertEqualSessions(t, s1, sessions[0])
130+
assertEqualSessions(t, s2, sessions[1])
131+
assertEqualSessions(t, s3, sessions[2])
132+
133+
sessions, err = db.ListSessions()
134+
require.NoError(t, err)
135+
require.Equal(t, 3, len(sessions))
136+
assertEqualSessions(t, s1, sessions[0])
137+
assertEqualSessions(t, s2, sessions[1])
138+
assertEqualSessions(t, s3, sessions[2])
139+
140+
sessions, err = db.ListSessions(StateInUse)
141+
require.NoError(t, err)
142+
require.Empty(t, sessions)
121143
}
122144

123145
// TestLinkingSessions tests that session linking works as expected.

session_rpcserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
100100
// requests. This includes resuming all non-revoked sessions.
101101
func (s *sessionRpcServer) start(ctx context.Context) error {
102102
// Start up all previously created sessions.
103-
sessions, err := s.cfg.db.ListSessions(nil)
103+
sessions, err := s.cfg.db.ListSessions()
104104
if err != nil {
105105
return fmt.Errorf("error listing sessions: %v", err)
106106
}
@@ -543,7 +543,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
543543
func (s *sessionRpcServer) ListSessions(_ context.Context,
544544
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {
545545

546-
sessions, err := s.cfg.db.ListSessions(nil)
546+
sessions, err := s.cfg.db.ListSessions()
547547
if err != nil {
548548
return nil, fmt.Errorf("error fetching sessions: %v", err)
549549
}

0 commit comments

Comments
 (0)