Skip to content

Commit

Permalink
Merge pull request #970 from ellemouton/sql12Sessions4
Browse files Browse the repository at this point in the history
[sql-12] sessions: make ListSession methods SQL ready
  • Loading branch information
ellemouton authored Feb 13, 2025
2 parents d56ba68 + 14cb0be commit c749126
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 46 deletions.
11 changes: 9 additions & 2 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,15 @@ type Store interface {
// GetSession fetches the session with the given key.
GetSession(key *btcec.PublicKey) (*Session, error)

// ListSessions returns all sessions currently known to the store.
ListSessions(filterFn func(s *Session) bool) ([]*Session, error)
// ListAllSessions returns all sessions currently known to the store.
ListAllSessions() ([]*Session, error)

// ListSessionsByType returns all sessions of the given type.
ListSessionsByType(t Type) ([]*Session, error)

// ListSessionsByState returns all sessions currently known to the store
// that are in the given states.
ListSessionsByState(...State) ([]*Session, error)

// RevokeSession updates the state of the session with the given local
// public key to be revoked.
Expand Down
46 changes: 44 additions & 2 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"time"

"github.com/btcsuite/btcd/btcec/v2"
Expand Down Expand Up @@ -363,10 +364,46 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
return session, nil
}

// ListSessions returns all sessions currently known to the store.
// ListAllSessions returns all sessions currently known to the store.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, error) {
func (db *BoltStore) ListAllSessions() ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
return true
})
}

// ListSessionsByType returns all sessions currently known to the store that
// have the given type.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
return s.Type == t
})
}

// ListSessionsByState returns all sessions currently known to the store that
// are in the given states.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
for _, state := range states {
if s.State == state {
return true
}
}

return false
})
}

// listSessions returns all sessions currently known to the store that pass the
// given filter function.
func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,
error) {

var sessions []*Session
err := db.View(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -399,6 +436,11 @@ func (db *BoltStore) ListSessions(filterFn func(s *Session) bool) ([]*Session, e
return nil, err
}

// Make sure to sort the sessions by creation time.
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].CreatedAt.Before(sessions[j].CreatedAt)
})

return sessions, nil
}

Expand Down
115 changes: 96 additions & 19 deletions session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@ func TestBasicSessionStore(t *testing.T) {
_ = db.Close()
})

// Create a few sessions.
s1 := newSession(t, db, clock, "session 1", nil)
s2 := newSession(t, db, clock, "session 2", nil)
s3 := newSession(t, db, clock, "session 3", nil)
s4 := newSession(t, db, clock, "session 4", nil)
// Create a few sessions. We increment the time by one second between
// each session to ensure that the created at time is unique and hence
// that the ListSessions method returns the sessions in a deterministic
// order.
s1 := newSession(t, db, clock, "session 1")
clock.SetTime(testTime.Add(time.Second))
s2 := newSession(t, db, clock, "session 2")
clock.SetTime(testTime.Add(2 * time.Second))
s3 := newSession(t, db, clock, "session 3", withType(TypeAutopilot))
clock.SetTime(testTime.Add(3 * time.Second))
s4 := newSession(t, db, clock, "session 4")

// Persist session 1. This should now succeed.
require.NoError(t, db.CreateSession(s1))
Expand All @@ -50,6 +56,22 @@ func TestBasicSessionStore(t *testing.T) {
require.NoError(t, db.CreateSession(s2))
require.NoError(t, db.CreateSession(s3))

// Test the ListSessionsByType method.
sessions, err := db.ListSessionsByType(TypeMacaroonAdmin)
require.NoError(t, err)
require.Equal(t, 2, len(sessions))
assertEqualSessions(t, s1, sessions[0])
assertEqualSessions(t, s2, sessions[1])

sessions, err = db.ListSessionsByType(TypeAutopilot)
require.NoError(t, err)
require.Equal(t, 1, len(sessions))
assertEqualSessions(t, s3, sessions[0])

sessions, err = db.ListSessionsByType(TypeMacaroonReadonly)
require.NoError(t, err)
require.Empty(t, sessions)

// Ensure that we can retrieve each session by both its local pub key
// and by its ID.
for _, s := range []*Session{s1, s2, s3} {
Expand Down Expand Up @@ -85,9 +107,44 @@ func TestBasicSessionStore(t *testing.T) {

// Now revoke the session and assert that the state is revoked.
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
session1, err = db.GetSession(s1.LocalPublicKey)
s1, err = db.GetSession(s1.LocalPublicKey)
require.NoError(t, err)
require.Equal(t, s1.State, StateRevoked)

// Test that ListAllSessions works.
sessions, err = db.ListAllSessions()
require.NoError(t, err)
require.Equal(t, 3, len(sessions))
assertEqualSessions(t, s1, sessions[0])
assertEqualSessions(t, s2, sessions[1])
assertEqualSessions(t, s3, sessions[2])

// Test that ListSessionsByState works.
sessions, err = db.ListSessionsByState(StateRevoked)
require.NoError(t, err)
require.Equal(t, 1, len(sessions))
assertEqualSessions(t, s1, sessions[0])

sessions, err = db.ListSessionsByState(StateCreated)
require.NoError(t, err)
require.Equal(t, 2, len(sessions))
assertEqualSessions(t, s2, sessions[0])
assertEqualSessions(t, s3, sessions[1])

sessions, err = db.ListSessionsByState(StateCreated, StateRevoked)
require.NoError(t, err)
require.Equal(t, 3, len(sessions))
assertEqualSessions(t, s1, sessions[0])
assertEqualSessions(t, s2, sessions[1])
assertEqualSessions(t, s3, sessions[2])

sessions, err = db.ListSessionsByState()
require.NoError(t, err)
require.Equal(t, session1.State, StateRevoked)
require.Empty(t, sessions)

sessions, err = db.ListSessionsByState(StateInUse)
require.NoError(t, err)
require.Empty(t, sessions)
}

// TestLinkingSessions tests that session linking works as expected.
Expand All @@ -101,10 +158,10 @@ func TestLinkingSessions(t *testing.T) {
})

// Create a new session with no previous link.
s1 := newSession(t, db, clock, "session 1", nil)
s1 := newSession(t, db, clock, "session 1")

// Create another session and link it to the first.
s2 := newSession(t, db, clock, "session 2", &s1.GroupID)
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))

// Try to persist the second session and assert that it fails due to the
// linked session not existing in the DB yet.
Expand Down Expand Up @@ -141,9 +198,9 @@ func TestLinkedSessions(t *testing.T) {
// after are all linked to the prior one. All these sessions belong to
// the same group. The group ID is equivalent to the session ID of the
// first session.
s1 := newSession(t, db, clock, "session 1", nil)
s2 := newSession(t, db, clock, "session 2", &s1.GroupID)
s3 := newSession(t, db, clock, "session 3", &s2.GroupID)
s1 := newSession(t, db, clock, "session 1")
s2 := newSession(t, db, clock, "session 2", withLinkedGroupID(&s1.GroupID))
s3 := newSession(t, db, clock, "session 3", withLinkedGroupID(&s2.GroupID))

// Persist the sessions.
require.NoError(t, db.CreateSession(s1))
Expand All @@ -169,8 +226,8 @@ func TestLinkedSessions(t *testing.T) {

// To ensure that different groups don't interfere with each other,
// let's add another set of linked sessions not linked to the first.
s4 := newSession(t, db, clock, "session 4", nil)
s5 := newSession(t, db, clock, "session 5", &s4.GroupID)
s4 := newSession(t, db, clock, "session 4")
s5 := newSession(t, db, clock, "session 5", withLinkedGroupID(&s4.GroupID))

require.NotEqual(t, s4.GroupID, s1.GroupID)

Expand Down Expand Up @@ -209,7 +266,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
// function is checked correctly.

// Add a new session to the DB.
s1 := newSession(t, db, clock, "label 1", nil)
s1 := newSession(t, db, clock, "label 1")
require.NoError(t, db.CreateSession(s1))

// Check that the group passes against an appropriate predicate.
Expand All @@ -234,7 +291,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))

// Add a new session to the same group as the first one.
s2 := newSession(t, db, clock, "label 2", &s1.GroupID)
s2 := newSession(t, db, clock, "label 2", withLinkedGroupID(&s1.GroupID))
require.NoError(t, db.CreateSession(s2))

// Check that the group passes against an appropriate predicate.
Expand All @@ -256,7 +313,7 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
require.False(t, ok)

// Add a new session that is not linked to the first one.
s3 := newSession(t, db, clock, "completely different", nil)
s3 := newSession(t, db, clock, "completely different")
require.NoError(t, db.CreateSession(s3))

// Ensure that the first group is unaffected.
Expand Down Expand Up @@ -286,8 +343,24 @@ func TestCheckSessionGroupPredicate(t *testing.T) {
require.True(t, ok)
}

// testSessionModifier is a functional option that can be used to modify the
// default test session created by newSession.
type testSessionModifier func(*Session)

func withLinkedGroupID(groupID *ID) testSessionModifier {
return func(s *Session) {
s.GroupID = *groupID
}
}

func withType(t Type) testSessionModifier {
return func(s *Session) {
s.Type = t
}
}

func newSession(t *testing.T, db Store, clock clock.Clock, label string,
linkedGroupID *ID) *Session {
mods ...testSessionModifier) *Session {

id, priv, err := db.GetUnusedIDAndKeyPair()
require.NoError(t, err)
Expand All @@ -296,11 +369,15 @@ func newSession(t *testing.T, db Store, clock clock.Clock, label string,
id, priv, label, TypeMacaroonAdmin,
clock.Now(),
time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC),
"foo.bar.baz:1234", true, nil, nil, nil, true, linkedGroupID,
"foo.bar.baz:1234", true, nil, nil, nil, true, nil,
[]PrivacyFlag{ClearPubkeys},
)
require.NoError(t, err)

for _, mod := range mods {
mod(session)
}

return session
}

Expand Down
30 changes: 7 additions & 23 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer,
// requests. This includes resuming all non-revoked sessions.
func (s *sessionRpcServer) start(ctx context.Context) error {
// Start up all previously created sessions.
sessions, err := s.cfg.db.ListSessions(nil)
sessions, err := s.cfg.db.ListSessionsByState(
session.StateCreated,
session.StateInUse,
)
if err != nil {
return fmt.Errorf("error listing sessions: %v", err)
}
Expand All @@ -126,12 +129,6 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
continue
}

if sess.State != session.StateInUse &&
sess.State != session.StateCreated {

continue
}

if sess.Expiry.Before(time.Now()) {
continue
}
Expand Down Expand Up @@ -345,24 +342,13 @@ func (s *sessionRpcServer) AddSession(ctx context.Context,
}, nil
}

// resumeSession tries to start an existing session if it is not expired, not
// revoked and a LiT session.
// resumeSession tries to start the given session if it is not expired.
func (s *sessionRpcServer) resumeSession(ctx context.Context,
sess *session.Session) error {

pubKey := sess.LocalPublicKey
pubKeyBytes := pubKey.SerializeCompressed()

// We only start non-revoked, non-expired LiT sessions. Everything else
// we just skip.
if sess.State != session.StateInUse &&
sess.State != session.StateCreated {

log.Debugf("Not resuming session %x with state %d", pubKeyBytes,
sess.State)
return nil
}

// Don't resume an expired session.
if sess.Expiry.Before(time.Now()) {
log.Debugf("Not resuming session %x with expiry %s",
Expand Down Expand Up @@ -536,7 +522,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
func (s *sessionRpcServer) ListSessions(_ context.Context,
_ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) {

sessions, err := s.cfg.db.ListSessions(nil)
sessions, err := s.cfg.db.ListAllSessions()
if err != nil {
return nil, fmt.Errorf("error fetching sessions: %v", err)
}
Expand Down Expand Up @@ -1259,9 +1245,7 @@ func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context,
_ *litrpc.ListAutopilotSessionsRequest) (
*litrpc.ListAutopilotSessionsResponse, error) {

sessions, err := s.cfg.db.ListSessions(func(s *session.Session) bool {
return s.Type == session.TypeAutopilot
})
sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot)
if err != nil {
return nil, fmt.Errorf("error fetching sessions: %v", err)
}
Expand Down

0 comments on commit c749126

Please sign in to comment.