diff --git a/session/interface.go b/session/interface.go index 41bd354cd..6b27a1f38 100644 --- a/session/interface.go +++ b/session/interface.go @@ -161,8 +161,15 @@ type Store interface { // GetSession fetches the session with the given key. GetSession(key *btcec.PublicKey) (*Session, error) - // ListSessions returns all sessions currently known to the store. - ListSessions(filterFn func(s *Session) bool) ([]*Session, error) + // ListAllSessions returns all sessions currently known to the store. + ListAllSessions() ([]*Session, error) + + // ListSessionsByType returns all sessions of the given type. + ListSessionsByType(t Type) ([]*Session, error) + + // ListSessionsByState returns all sessions currently known to the store + // that are in the given states. + ListSessionsByState(...State) ([]*Session, error) // RevokeSession updates the state of the session with the given local // public key to be revoked. diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 19c7f7db2..2f5f252d5 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "path/filepath" + "sort" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -363,10 +364,46 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) { return session, nil } -// ListSessions returns all sessions currently known to the store. +// ListAllSessions returns all sessions currently known to the store. // // NOTE: this is part of the Store interface. -func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) { +func (db *BoltStore) ListAllSessions() ([]*Session, error) { + return db.listSessions(func(s *Session) bool { + return true + }) +} + +// ListSessionsByType returns all sessions currently known to the store that +// have the given type. +// +// NOTE: this is part of the Store interface. +func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) { + return db.listSessions(func(s *Session) bool { + return s.Type == t + }) +} + +// ListSessionsByState returns all sessions currently known to the store that +// are in the given states. +// +// NOTE: this is part of the Store interface. +func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) { + return db.listSessions(func(s *Session) bool { + for _, state := range states { + if s.State == state { + return true + } + } + + return false + }) +} + +// listSessions returns all sessions currently known to the store that pass the +// given filter function. +func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session, + error) { + var sessions []*Session err := db.View(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) @@ -399,6 +436,11 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e return nil, err } + // Make sure to sort the sessions by creation time. + sort.Slice(sessions, func(i, j int) bool { + return sessions[i].CreatedAt.Before(sessions[j].CreatedAt) + }) + return sessions, nil } diff --git a/session/store_test.go b/session/store_test.go index 18bd933d6..e5530fdda 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -23,11 +23,17 @@ func TestBasicSessionStore(t *testing.T) { _ = db.Close() }) - // Create a few sessions. - s1 := newSession(t, db, clock, "session 1", nil) - s2 := newSession(t, db, clock, "session 2", nil) - s3 := newSession(t, db, clock, "session 3", nil) - s4 := newSession(t, db, clock, "session 4", nil) + // Create a few sessions. We increment the time by one second between + // each session to ensure that the created at time is unique and hence + // that the ListSessions method returns the sessions in a deterministic + // order. + s1 := newSession(t, db, clock, "session 1") + clock.SetTime(testTime.Add(time.Second)) + s2 := newSession(t, db, clock, "session 2") + clock.SetTime(testTime.Add(2 * time.Second)) + s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot)) + clock.SetTime(testTime.Add(3 * time.Second)) + s4 := newSession(t, db, clock, "session 4") // Persist session 1. This should now succeed. require.NoError(t, db.CreateSession(s1)) @@ -50,6 +56,22 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, db.CreateSession(s2)) require.NoError(t, db.CreateSession(s3)) + // Test the ListSessionsByType method. + sessions, err := db.ListSessionsByType(TypeMacaroonAdmin) + require.NoError(t, err) + require.Equal(t, 2, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + assertEqualSessions(t, s2, sessions[1]) + + sessions, err = db.ListSessionsByType(TypeAutopilot) + require.NoError(t, err) + require.Equal(t, 1, len(sessions)) + assertEqualSessions(t, s3, sessions[0]) + + sessions, err = db.ListSessionsByType(TypeMacaroonReadonly) + require.NoError(t, err) + require.Empty(t, sessions) + // Ensure that we can retrieve each session by both its local pub key // and by its ID. for _, s := range []*Session{s1, s2, s3} { @@ -85,9 +107,44 @@ func TestBasicSessionStore(t *testing.T) { // Now revoke the session and assert that the state is revoked. require.NoError(t, db.RevokeSession(s1.LocalPublicKey)) - session1, err = db.GetSession(s1.LocalPublicKey) + s1, err = db.GetSession(s1.LocalPublicKey) + require.NoError(t, err) + require.Equal(t, s1.State, StateRevoked) + + // Test that ListAllSessions works. + sessions, err = db.ListAllSessions() + require.NoError(t, err) + require.Equal(t, 3, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + assertEqualSessions(t, s2, sessions[1]) + assertEqualSessions(t, s3, sessions[2]) + + // Test that ListSessionsByState works. + sessions, err = db.ListSessionsByState(StateRevoked) + require.NoError(t, err) + require.Equal(t, 1, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + + sessions, err = db.ListSessionsByState(StateCreated) + require.NoError(t, err) + require.Equal(t, 2, len(sessions)) + assertEqualSessions(t, s2, sessions[0]) + assertEqualSessions(t, s3, sessions[1]) + + sessions, err = db.ListSessionsByState(StateCreated, StateRevoked) + require.NoError(t, err) + require.Equal(t, 3, len(sessions)) + assertEqualSessions(t, s1, sessions[0]) + assertEqualSessions(t, s2, sessions[1]) + assertEqualSessions(t, s3, sessions[2]) + + sessions, err = db.ListSessionsByState() require.NoError(t, err) - require.Equal(t, session1.State, StateRevoked) + require.Empty(t, sessions) + + sessions, err = db.ListSessionsByState(StateInUse) + require.NoError(t, err) + require.Empty(t, sessions) } // TestLinkingSessions tests that session linking works as expected. @@ -101,10 +158,10 @@ func TestLinkingSessions(t *testing.T) { }) // Create a new session with no previous link. - s1 := newSession(t, db, clock, "session 1", nil) + s1 := newSession(t, db, clock, "session 1") // Create another session and link it to the first. - s2 := newSession(t, db, clock, "session 2", &s1.GroupID) + s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID)) // Try to persist the second session and assert that it fails due to the // linked session not existing in the DB yet. @@ -141,9 +198,9 @@ func TestLinkedSessions(t *testing.T) { // after are all linked to the prior one. All these sessions belong to // the same group. The group ID is equivalent to the session ID of the // first session. - s1 := newSession(t, db, clock, "session 1", nil) - s2 := newSession(t, db, clock, "session 2", &s1.GroupID) - s3 := newSession(t, db, clock, "session 3", &s2.GroupID) + s1 := newSession(t, db, clock, "session 1") + s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID)) + s3 := newSession(t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID)) // Persist the sessions. require.NoError(t, db.CreateSession(s1)) @@ -169,8 +226,8 @@ func TestLinkedSessions(t *testing.T) { // To ensure that different groups don't interfere with each other, // let's add another set of linked sessions not linked to the first. - s4 := newSession(t, db, clock, "session 4", nil) - s5 := newSession(t, db, clock, "session 5", &s4.GroupID) + s4 := newSession(t, db, clock, "session 4") + s5 := newSession(t, db, clock, "session 5", withLinkedGroupID(&s4.GroupID)) require.NotEqual(t, s4.GroupID, s1.GroupID) @@ -209,7 +266,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { // function is checked correctly. // Add a new session to the DB. - s1 := newSession(t, db, clock, "label 1", nil) + s1 := newSession(t, db, clock, "label 1") require.NoError(t, db.CreateSession(s1)) // Check that the group passes against an appropriate predicate. @@ -234,7 +291,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.NoError(t, db.RevokeSession(s1.LocalPublicKey)) // Add a new session to the same group as the first one. - s2 := newSession(t, db, clock, "label 2", &s1.GroupID) + s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID)) require.NoError(t, db.CreateSession(s2)) // Check that the group passes against an appropriate predicate. @@ -256,7 +313,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.False(t, ok) // Add a new session that is not linked to the first one. - s3 := newSession(t, db, clock, "completely different", nil) + s3 := newSession(t, db, clock, "completely different") require.NoError(t, db.CreateSession(s3)) // Ensure that the first group is unaffected. @@ -286,8 +343,24 @@ func TestCheckSessionGroupPredicate(t *testing.T) { require.True(t, ok) } +// testSessionModifier is a functional option that can be used to modify the +// default test session created by newSession. +type testSessionModifier func(*Session) + +func withLinkedGroupID(groupID *ID) testSessionModifier { + return func(s *Session) { + s.GroupID = *groupID + } +} + +func withType(t Type) testSessionModifier { + return func(s *Session) { + s.Type = t + } +} + func newSession(t *testing.T, db Store, clock clock.Clock, label string, - linkedGroupID *ID) *Session { + mods ...testSessionModifier) *Session { id, priv, err := db.GetUnusedIDAndKeyPair() require.NoError(t, err) @@ -296,11 +369,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string, id, priv, label, TypeMacaroonAdmin, clock.Now(), time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), - "foo.bar.baz:1234", true, nil, nil, nil, true, linkedGroupID, + "foo.bar.baz:1234", true, nil, nil, nil, true, nil, []PrivacyFlag{ClearPubkeys}, ) require.NoError(t, err) + for _, mod := range mods { + mod(session) + } + return session } diff --git a/session_rpcserver.go b/session_rpcserver.go index 666744cd0..e85d3578f 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -101,7 +101,10 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, // requests. This includes resuming all non-revoked sessions. func (s *sessionRpcServer) start(ctx context.Context) error { // Start up all previously created sessions. - sessions, err := s.cfg.db.ListSessions(nil) + sessions, err := s.cfg.db.ListSessionsByState( + session.StateCreated, + session.StateInUse, + ) if err != nil { return fmt.Errorf("error listing sessions: %v", err) } @@ -126,12 +129,6 @@ func (s *sessionRpcServer) start(ctx context.Context) error { continue } - if sess.State != session.StateInUse && - sess.State != session.StateCreated { - - continue - } - if sess.Expiry.Before(time.Now()) { continue } @@ -345,24 +342,13 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, }, nil } -// resumeSession tries to start an existing session if it is not expired, not -// revoked and a LiT session. +// resumeSession tries to start the given session if it is not expired. func (s *sessionRpcServer) resumeSession(ctx context.Context, sess *session.Session) error { pubKey := sess.LocalPublicKey pubKeyBytes := pubKey.SerializeCompressed() - // We only start non-revoked, non-expired LiT sessions. Everything else - // we just skip. - if sess.State != session.StateInUse && - sess.State != session.StateCreated { - - log.Debugf("Not resuming session %x with state %d", pubKeyBytes, - sess.State) - return nil - } - // Don't resume an expired session. if sess.Expiry.Before(time.Now()) { log.Debugf("Not resuming session %x with expiry %s", @@ -536,7 +522,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context, func (s *sessionRpcServer) ListSessions(_ context.Context, _ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) { - sessions, err := s.cfg.db.ListSessions(nil) + sessions, err := s.cfg.db.ListAllSessions() if err != nil { return nil, fmt.Errorf("error fetching sessions: %v", err) } @@ -1259,9 +1245,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context, _ *litrpc.ListAutopilotSessionsRequest) ( *litrpc.ListAutopilotSessionsResponse, error) { - sessions, err := s.cfg.db.ListSessions(func(s *session.Session) bool { - return s.Type == session.TypeAutopilot - }) + sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot) if err != nil { return nil, fmt.Errorf("error fetching sessions: %v", err) }