diff --git a/firewall/privacy_mapper.go b/firewall/privacy_mapper.go index 7a2f8fe42..af4f3b0af 100644 --- a/firewall/privacy_mapper.go +++ b/firewall/privacy_mapper.go @@ -190,7 +190,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context, uri string, req proto.Message, sessionID session.ID) (proto.Message, error) { - session, err := p.sessionDB.GetSessionByID(sessionID) + session, err := p.sessionDB.GetSessionByID(ctx, sessionID) if err != nil { return nil, err } @@ -220,7 +220,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context, func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string, resp proto.Message, sessionID session.ID) (proto.Message, error) { - session, err := p.sessionDB.GetSessionByID(sessionID) + session, err := p.sessionDB.GetSessionByID(ctx, sessionID) if err != nil { return nil, err } diff --git a/firewall/rule_enforcer.go b/firewall/rule_enforcer.go index 964baf32a..008af72c5 100644 --- a/firewall/rule_enforcer.go +++ b/firewall/rule_enforcer.go @@ -386,7 +386,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string, return nil, err } - session, err := r.sessionDB.GetSessionByID(sessionID) + session, err := r.sessionDB.GetSessionByID(ctx, sessionID) if err != nil { return nil, err } diff --git a/firewalldb/actions.go b/firewalldb/actions.go index 8ddd1e3b5..657f19cea 100644 --- a/firewalldb/actions.go +++ b/firewalldb/actions.go @@ -391,7 +391,7 @@ func (db *DB) ListSessionActions(sessionID session.ID, // pass the filterFn requirements. // // TODO: update to allow for pagination. -func (db *DB) ListGroupActions(groupID session.ID, +func (db *DB) ListGroupActions(ctx context.Context, groupID session.ID, filterFn ListActionsFilterFn) ([]*Action, error) { if filterFn == nil { @@ -400,7 +400,7 @@ func (db *DB) ListGroupActions(groupID session.ID, } } - sessionIDs, err := db.sessionIDIndex.GetSessionIDs(groupID) + sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID) if err != nil { return nil, err } @@ -629,11 +629,11 @@ type groupActionsReadDB struct { var _ ActionsDB = (*groupActionsReadDB)(nil) // ListActions will return all the Actions for a particular group. -func (s *groupActionsReadDB) ListActions(_ context.Context) ([]*RuleAction, +func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction, error) { sessionActions, err := s.db.ListGroupActions( - s.groupID, func(a *Action, _ bool) (bool, bool) { + ctx, s.groupID, func(a *Action, _ bool) (bool, bool) { return a.State == ActionStateDone, true }, ) @@ -660,11 +660,11 @@ var _ ActionsDB = (*groupFeatureActionsReadDB)(nil) // ListActions will return all the Actions for a particular group that were // executed by a particular feature. -func (a *groupFeatureActionsReadDB) ListActions(_ context.Context) ( +func (a *groupFeatureActionsReadDB) ListActions(ctx context.Context) ( []*RuleAction, error) { featureActions, err := a.db.ListGroupActions( - a.groupID, func(action *Action, _ bool) (bool, bool) { + ctx, a.groupID, func(action *Action, _ bool) (bool, bool) { return action.State == ActionStateDone && action.FeatureName == a.featureName, true }, diff --git a/firewalldb/actions_test.go b/firewalldb/actions_test.go index 72b6376b2..8b66529f4 100644 --- a/firewalldb/actions_test.go +++ b/firewalldb/actions_test.go @@ -1,6 +1,7 @@ package firewalldb import ( + "context" "fmt" "testing" "time" @@ -342,6 +343,9 @@ func TestListActions(t *testing.T) { // TestListGroupActions tests that the ListGroupActions correctly returns all // actions in a particular session group. func TestListGroupActions(t *testing.T) { + t.Parallel() + ctx := context.Background() + group1 := intToSessionID(0) // Link session 1 and session 2 to group 1. @@ -356,7 +360,7 @@ func TestListGroupActions(t *testing.T) { }) // There should not be any actions in group 1 yet. - al, err := db.ListGroupActions(group1, nil) + al, err := db.ListGroupActions(ctx, group1, nil) require.NoError(t, err) require.Empty(t, al) @@ -365,7 +369,7 @@ func TestListGroupActions(t *testing.T) { require.NoError(t, err) // There should now be one action in the group. - al, err = db.ListGroupActions(group1, nil) + al, err = db.ListGroupActions(ctx, group1, nil) require.NoError(t, err) require.Len(t, al, 1) require.Equal(t, sessionID1, al[0].SessionID) @@ -375,7 +379,7 @@ func TestListGroupActions(t *testing.T) { require.NoError(t, err) // There should now be actions in the group. - al, err = db.ListGroupActions(group1, nil) + al, err = db.ListGroupActions(ctx, group1, nil) require.NoError(t, err) require.Len(t, al, 2) require.Equal(t, sessionID1, al[0].SessionID) diff --git a/firewalldb/interface.go b/firewalldb/interface.go index 6e6509573..86e638b53 100644 --- a/firewalldb/interface.go +++ b/firewalldb/interface.go @@ -1,6 +1,10 @@ package firewalldb -import "github.com/lightninglabs/lightning-terminal/session" +import ( + "context" + + "github.com/lightninglabs/lightning-terminal/session" +) // SessionDB is an interface that abstracts the database operations needed for // the privacy mapper to function. @@ -8,5 +12,5 @@ type SessionDB interface { session.IDToGroupIndex // GetSessionByID returns the session for a specific id. - GetSessionByID(session.ID) (*session.Session, error) + GetSessionByID(context.Context, session.ID) (*session.Session, error) } diff --git a/firewalldb/mock.go b/firewalldb/mock.go index 4030dde3f..0213de864 100644 --- a/firewalldb/mock.go +++ b/firewalldb/mock.go @@ -1,6 +1,7 @@ package firewalldb import ( + "context" "fmt" "github.com/lightninglabs/lightning-terminal/session" @@ -33,7 +34,9 @@ func (m *mockSessionDB) AddPair(sessionID, groupID session.ID) { } // GetGroupID returns the group ID for the given session ID. -func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) { +func (m *mockSessionDB) GetGroupID(_ context.Context, sessionID session.ID) ( + session.ID, error) { + id, ok := m.sessionToGroupID[sessionID] if !ok { return session.ID{}, fmt.Errorf("no group ID found for " + @@ -44,7 +47,9 @@ func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) { } // GetSessionIDs returns the set of session IDs that are in the group -func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error) { +func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) ( + []session.ID, error) { + ids, ok := m.groupToSessionIDs[groupID] if !ok { return nil, fmt.Errorf("no session IDs found for group ID") @@ -54,8 +59,8 @@ func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error) } // GetSessionByID returns the session for a specific id. -func (m *mockSessionDB) GetSessionByID(sessionID session.ID) (*session.Session, - error) { +func (m *mockSessionDB) GetSessionByID(_ context.Context, + sessionID session.ID) (*session.Session, error) { s, ok := m.sessionToGroupID[sessionID] if !ok { diff --git a/session/interface.go b/session/interface.go index 7e84bf07e..a861f7e34 100644 --- a/session/interface.go +++ b/session/interface.go @@ -1,6 +1,7 @@ package session import ( + "context" "fmt" "time" @@ -260,11 +261,11 @@ func WithMacaroonRecipe(caveats []macaroon.Caveat, perms []bakery.Op) Option { // IDToGroupIndex defines an interface for the session ID to group ID index. type IDToGroupIndex interface { // GetGroupID will return the group ID for the given session ID. - GetGroupID(sessionID ID) (ID, error) + GetGroupID(ctx context.Context, sessionID ID) (ID, error) // GetSessionIDs will return the set of session IDs that are in the // group with the given ID. - GetSessionIDs(groupID ID) ([]ID, error) + GetSessionIDs(ctx context.Context, groupID ID) ([]ID, error) } // Store is the interface a persistent storage must implement for storing and @@ -273,37 +274,39 @@ type Store interface { // NewSession creates a new session with the given user-defined // parameters. The session will remain in the StateReserved state until // ShiftState is called to update the state. - NewSession(label string, typ Type, expiry time.Time, serverAddr string, - opts ...Option) (*Session, error) + NewSession(ctx context.Context, label string, typ Type, + expiry time.Time, serverAddr string, opts ...Option) (*Session, + error) // GetSession fetches the session with the given key. - GetSession(key *btcec.PublicKey) (*Session, error) + GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error) // ListAllSessions returns all sessions currently known to the store. - ListAllSessions() ([]*Session, error) + ListAllSessions(ctx context.Context) ([]*Session, error) // ListSessionsByType returns all sessions of the given type. - ListSessionsByType(t Type) ([]*Session, error) + ListSessionsByType(ctx context.Context, t Type) ([]*Session, error) // ListSessionsByState returns all sessions currently known to the store // that are in the given states. - ListSessionsByState(...State) ([]*Session, error) + ListSessionsByState(ctx context.Context, state ...State) ([]*Session, + error) // UpdateSessionRemotePubKey can be used to add the given remote pub key // to the session with the given local pub key. - UpdateSessionRemotePubKey(localPubKey, + UpdateSessionRemotePubKey(ctx context.Context, localPubKey, remotePubKey *btcec.PublicKey) error // GetSessionByID fetches the session with the given ID. - GetSessionByID(id ID) (*Session, error) + GetSessionByID(ctx context.Context, id ID) (*Session, error) // DeleteReservedSessions deletes all sessions that are in the // StateReserved state. - DeleteReservedSessions() error + DeleteReservedSessions(ctx context.Context) error // ShiftState updates the state of the session with the given ID to the // "dest" state. - ShiftState(id ID, dest State) error + ShiftState(ctx context.Context, id ID, dest State) error IDToGroupIndex } diff --git a/session/kvdb_store.go b/session/kvdb_store.go index 216d0fc1e..69b2eac87 100644 --- a/session/kvdb_store.go +++ b/session/kvdb_store.go @@ -2,6 +2,7 @@ package session import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -185,8 +186,8 @@ func getSessionKey(session *Session) []byte { // ShiftState is called with StateCreated. // // NOTE: this is part of the Store interface. -func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, - serverAddr string, opts ...Option) (*Session, error) { +func (db *BoltStore) NewSession(ctx context.Context, label string, typ Type, + expiry time.Time, serverAddr string, opts ...Option) (*Session, error) { var session *Session err := db.Update(func(tx *bbolt.Tx) error { @@ -285,7 +286,7 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time, // to the session with the given local pub key. // // NOTE: this is part of the Store interface. -func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey, +func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey, remotePubKey *btcec.PublicKey) error { key := localPubKey.SerializeCompressed() @@ -318,7 +319,9 @@ func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey, // GetSession fetches the session with the given key. // // NOTE: this is part of the Store interface. -func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) { +func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) ( + *Session, error) { + var session *Session err := db.View(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) @@ -348,7 +351,7 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) { // ListAllSessions returns all sessions currently known to the store. // // NOTE: this is part of the Store interface. -func (db *BoltStore) ListAllSessions() ([]*Session, error) { +func (db *BoltStore) ListAllSessions(_ context.Context) ([]*Session, error) { return db.listSessions(func(s *Session) bool { return true }) @@ -358,7 +361,9 @@ func (db *BoltStore) ListAllSessions() ([]*Session, error) { // have the given type. // // NOTE: this is part of the Store interface. -func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) { +func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session, + error) { + return db.listSessions(func(s *Session) bool { return s.Type == t }) @@ -368,7 +373,9 @@ func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) { // are in the given states. // // NOTE: this is part of the Store interface. -func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) { +func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) ( + []*Session, error) { + return db.listSessions(func(s *Session) bool { for _, state := range states { if s.State == state { @@ -429,7 +436,7 @@ func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session, // state. // // NOTE: this is part of the Store interface. -func (db *BoltStore) DeleteReservedSessions() error { +func (db *BoltStore) DeleteReservedSessions(_ context.Context) error { return db.Update(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) if err != nil { @@ -522,7 +529,7 @@ func (db *BoltStore) DeleteReservedSessions() error { // state. // // NOTE: this is part of the Store interface. -func (db *BoltStore) ShiftState(id ID, dest State) error { +func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error { return db.Update(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) if err != nil { @@ -562,7 +569,9 @@ func (db *BoltStore) ShiftState(id ID, dest State) error { // GetSessionByID fetches the session with the given ID. // // NOTE: this is part of the Store interface. -func (db *BoltStore) GetSessionByID(id ID) (*Session, error) { +func (db *BoltStore) GetSessionByID(_ context.Context, id ID) (*Session, + error) { + var session *Session err := db.View(func(tx *bbolt.Tx) error { sessionBucket, err := getBucket(tx, sessionBucketKey) @@ -615,7 +624,7 @@ func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey, // GetGroupID will return the group ID for the given session ID. // // NOTE: this is part of the IDToGroupIndex interface. -func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) { +func (db *BoltStore) GetGroupID(_ context.Context, sessionID ID) (ID, error) { var groupID ID err := db.View(func(tx *bbolt.Tx) error { sessionBkt, err := getBucket(tx, sessionBucketKey) @@ -655,7 +664,9 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) { // group with the given ID. // // NOTE: this is part of the IDToGroupIndex interface. -func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) { +func (db *BoltStore) GetSessionIDs(_ context.Context, groupID ID) ([]ID, + error) { + var ( sessionIDs []ID err error diff --git a/session/server.go b/session/server.go index ae7a50121..22de3bf8d 100644 --- a/session/server.go +++ b/session/server.go @@ -1,6 +1,7 @@ package session import ( + "context" "crypto/tls" "fmt" "sync" @@ -8,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightninglabs/lightning-node-connect/mailbox" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -21,8 +23,9 @@ type GRPCServerCreator func(opts ...grpc.ServerOption) *grpc.Server type mailboxSession struct { server *grpc.Server - wg sync.WaitGroup - quit chan struct{} + cancel fn.Option[context.CancelFunc] + wg sync.WaitGroup + quit chan struct{} } func newMailboxSession() *mailboxSession { @@ -33,7 +36,8 @@ func newMailboxSession() *mailboxSession { func (m *mailboxSession) start(session *Session, serverCreator GRPCServerCreator, authData []byte, - onUpdate func(local, remote *btcec.PublicKey) error, + onUpdate func(ctx context.Context, local, + remote *btcec.PublicKey) error, onNewStatus func(s mailbox.ServerStatus)) error { tlsConfig := &tls.Config{} @@ -43,10 +47,13 @@ func (m *mailboxSession) start(session *Session, ecdh := &keychain.PrivKeyECDH{PrivKey: session.LocalPrivateKey} + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = fn.Some(cancel) + keys := mailbox.NewConnData( ecdh, session.RemotePublicKey, session.PairingSecret[:], authData, func(key *btcec.PublicKey) error { - return onUpdate(session.LocalPublicKey, key) + return onUpdate(ctx, session.LocalPublicKey, key) }, nil, ) @@ -81,6 +88,7 @@ func (m *mailboxSession) run(mailboxServer *mailbox.Server) { } func (m *mailboxSession) stop() { + m.cancel.WhenSome(func(fn context.CancelFunc) { fn() }) m.server.Stop() close(m.quit) m.wg.Wait() @@ -104,7 +112,8 @@ func NewServer(serverCreator GRPCServerCreator) *Server { } func (s *Server) StartSession(session *Session, authData []byte, - onUpdate func(local, remote *btcec.PublicKey) error, + onUpdate func(ctx context.Context, local, + remote *btcec.PublicKey) error, onNewStatus func(s mailbox.ServerStatus)) (chan struct{}, error) { s.activeSessionsMtx.Lock() diff --git a/session/store_test.go b/session/store_test.go index 07695fe4e..a3c6c4289 100644 --- a/session/store_test.go +++ b/session/store_test.go @@ -1,6 +1,7 @@ package session import ( + "context" "testing" "time" @@ -14,12 +15,15 @@ var testTime = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) // TestBasicSessionStore tests the basic getters and setters of the session // store. func TestBasicSessionStore(t *testing.T) { + t.Parallel() + ctx := context.Background() + // Set up a new DB. clock := clock.NewTestClock(testTime) db := NewTestDB(t, clock) // Try fetch a session that doesn't exist yet. - _, err := db.GetSessionByID(ID{1, 3, 4, 4}) + _, err := db.GetSessionByID(ctx, ID{1, 3, 4, 4}) require.ErrorIs(t, err, ErrSessionNotFound) // Reserve a session. This should succeed. @@ -27,22 +31,22 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) // Show that the session starts in the reserved state. - s1, err = db.GetSessionByID(s1.ID) + s1, err = db.GetSessionByID(ctx, s1.ID) require.NoError(t, err) require.Equal(t, StateReserved, s1.State) // Move session 1 to the created state. This should succeed. - err = db.ShiftState(s1.ID, StateCreated) + err = db.ShiftState(ctx, s1.ID, StateCreated) require.NoError(t, err) // Show that the session is now in the created state. - s1, err = db.GetSessionByID(s1.ID) + s1, err = db.GetSessionByID(ctx, s1.ID) require.NoError(t, err) require.Equal(t, StateCreated, s1.State) // Trying to move session 1 again should have no effect since it is // already in the created state. - require.NoError(t, db.ShiftState(s1.ID, StateCreated)) + require.NoError(t, db.ShiftState(ctx, s1.ID, StateCreated)) // Reserve and create a few more sessions. We increment the time by one // second between each session to ensure that the created at time is @@ -54,35 +58,35 @@ func TestBasicSessionStore(t *testing.T) { s3 := createSession(t, db, "session 3", withType(TypeAutopilot)) // Test the ListSessionsByType method. - sessions, err := db.ListSessionsByType(TypeMacaroonAdmin) + sessions, err := db.ListSessionsByType(ctx, TypeMacaroonAdmin) require.NoError(t, err) require.Equal(t, 2, len(sessions)) assertEqualSessions(t, s1, sessions[0]) assertEqualSessions(t, s2, sessions[1]) - sessions, err = db.ListSessionsByType(TypeAutopilot) + sessions, err = db.ListSessionsByType(ctx, TypeAutopilot) require.NoError(t, err) require.Equal(t, 1, len(sessions)) assertEqualSessions(t, s3, sessions[0]) - sessions, err = db.ListSessionsByType(TypeMacaroonReadonly) + sessions, err = db.ListSessionsByType(ctx, TypeMacaroonReadonly) require.NoError(t, err) require.Empty(t, sessions) // Ensure that we can retrieve each session by both its local pub key // and by its ID. for _, s := range []*Session{s1, s2, s3} { - session, err := db.GetSession(s.LocalPublicKey) + session, err := db.GetSession(ctx, s.LocalPublicKey) require.NoError(t, err) assertEqualSessions(t, s, session) - session, err = db.GetSessionByID(s.ID) + session, err = db.GetSessionByID(ctx, s.ID) require.NoError(t, err) assertEqualSessions(t, s, session) } // Fetch session 1 and assert that it currently has no remote pub key. - session1, err := db.GetSession(s1.LocalPublicKey) + session1, err := db.GetSession(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Nil(t, session1.RemotePublicKey) @@ -91,11 +95,13 @@ func TestBasicSessionStore(t *testing.T) { require.NoError(t, err) remotePub := remotePriv.PubKey() - err = db.UpdateSessionRemotePubKey(session1.LocalPublicKey, remotePub) + err = db.UpdateSessionRemotePubKey( + ctx, session1.LocalPublicKey, remotePub, + ) require.NoError(t, err) // Assert that the session now does have the remote pub key. - session1, err = db.GetSession(s1.LocalPublicKey) + session1, err = db.GetSession(ctx, s1.LocalPublicKey) require.NoError(t, err) require.True(t, remotePub.IsEqual(session1.RemotePublicKey)) @@ -103,13 +109,13 @@ func TestBasicSessionStore(t *testing.T) { require.Equal(t, session1.State, StateCreated) // Now revoke the session and assert that the state is revoked. - require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) - s1, err = db.GetSession(s1.LocalPublicKey) + require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked)) + s1, err = db.GetSession(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, s1.State, StateRevoked) // Test that ListAllSessions works. - sessions, err = db.ListAllSessions() + sessions, err = db.ListAllSessions(ctx) require.NoError(t, err) require.Equal(t, 3, len(sessions)) assertEqualSessions(t, s1, sessions[0]) @@ -117,29 +123,29 @@ func TestBasicSessionStore(t *testing.T) { assertEqualSessions(t, s3, sessions[2]) // Test that ListSessionsByState works. - sessions, err = db.ListSessionsByState(StateRevoked) + sessions, err = db.ListSessionsByState(ctx, StateRevoked) require.NoError(t, err) require.Equal(t, 1, len(sessions)) assertEqualSessions(t, s1, sessions[0]) - sessions, err = db.ListSessionsByState(StateCreated) + sessions, err = db.ListSessionsByState(ctx, StateCreated) require.NoError(t, err) require.Equal(t, 2, len(sessions)) assertEqualSessions(t, s2, sessions[0]) assertEqualSessions(t, s3, sessions[1]) - sessions, err = db.ListSessionsByState(StateCreated, StateRevoked) + sessions, err = db.ListSessionsByState(ctx, StateCreated, StateRevoked) require.NoError(t, err) require.Equal(t, 3, len(sessions)) assertEqualSessions(t, s1, sessions[0]) assertEqualSessions(t, s2, sessions[1]) assertEqualSessions(t, s3, sessions[2]) - sessions, err = db.ListSessionsByState() + sessions, err = db.ListSessionsByState(ctx) require.NoError(t, err) require.Empty(t, sessions) - sessions, err = db.ListSessionsByState(StateReserved) + sessions, err = db.ListSessionsByState(ctx, StateReserved) require.NoError(t, err) require.Empty(t, sessions) @@ -147,9 +153,9 @@ func TestBasicSessionStore(t *testing.T) { // // Calling DeleteReservedSessions should have no effect yet since none // of the sessions are reserved. - require.NoError(t, db.DeleteReservedSessions()) + require.NoError(t, db.DeleteReservedSessions(ctx)) - sessions, err = db.ListSessionsByState(StateReserved) + sessions, err = db.ListSessionsByState(ctx, StateReserved) require.NoError(t, err) require.Empty(t, sessions) @@ -159,34 +165,34 @@ func TestBasicSessionStore(t *testing.T) { ) require.NoError(t, err) - sessions, err = db.ListSessionsByState(StateReserved) + sessions, err = db.ListSessionsByState(ctx, StateReserved) require.NoError(t, err) require.Equal(t, 1, len(sessions)) assertEqualSessions(t, s4, sessions[0]) // Show that the group ID/session ID index has also been populated with // this session. - groupID, err := db.GetGroupID(s4.ID) + groupID, err := db.GetGroupID(ctx, s4.ID) require.NoError(t, err) require.Equal(t, s1.ID, groupID) - sessIDs, err := db.GetSessionIDs(s4.GroupID) + sessIDs, err := db.GetSessionIDs(ctx, s4.GroupID) require.NoError(t, err) require.ElementsMatch(t, []ID{s4.ID, s1.ID}, sessIDs) // Now delete the reserved session and show that it is no longer in the // database and no longer in the group ID/session ID index. - require.NoError(t, db.DeleteReservedSessions()) + require.NoError(t, db.DeleteReservedSessions(ctx)) - sessions, err = db.ListSessionsByState(StateReserved) + sessions, err = db.ListSessionsByState(ctx, StateReserved) require.NoError(t, err) require.Empty(t, sessions) - _, err = db.GetGroupID(s4.ID) + _, err = db.GetGroupID(ctx, s4.ID) require.ErrorIs(t, err, ErrUnknownGroup) // Only session 1 should remain in this group. - sessIDs, err = db.GetSessionIDs(s4.GroupID) + sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID) require.NoError(t, err) require.ElementsMatch(t, []ID{s1.ID}, sessIDs) } @@ -194,6 +200,7 @@ func TestBasicSessionStore(t *testing.T) { // TestLinkingSessions tests that session linking works as expected. func TestLinkingSessions(t *testing.T) { t.Parallel() + ctx := context.Background() // Set up a new DB. clock := clock.NewTestClock(testTime) @@ -219,7 +226,7 @@ func TestLinkingSessions(t *testing.T) { require.ErrorIs(t, err, ErrSessionsInGroupStillActive) // Revoke the first session. - require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) + require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked)) // Persisting the second linked session should now work. _, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID)) @@ -231,6 +238,7 @@ func TestLinkingSessions(t *testing.T) { // of the GetGroupID and GetSessionIDs methods. func TestLinkedSessions(t *testing.T) { t.Parallel() + ctx := context.Background() // Set up a new DB. clock := clock.NewTestClock(testTime) @@ -242,48 +250,51 @@ func TestLinkedSessions(t *testing.T) { // first session. s1 := createSession(t, db, "session 1") - require.NoError(t, db.ShiftState(s1.ID, StateRevoked)) + require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked)) s2 := createSession(t, db, "session 2", withLinkedGroupID(&s1.GroupID)) - require.NoError(t, db.ShiftState(s2.ID, StateRevoked)) + require.NoError(t, db.ShiftState(ctx, s2.ID, StateRevoked)) s3 := createSession(t, db, "session 3", withLinkedGroupID(&s2.GroupID)) // Assert that the session ID to group ID index works as expected. for _, s := range []*Session{s1, s2, s3} { - groupID, err := db.GetGroupID(s.ID) + groupID, err := db.GetGroupID(ctx, s.ID) require.NoError(t, err) require.Equal(t, s1.ID, groupID) require.Equal(t, s.GroupID, groupID) } // Assert that the group ID to session ID index works as expected. - sIDs, err := db.GetSessionIDs(s1.GroupID) + sIDs, err := db.GetSessionIDs(ctx, s1.GroupID) require.NoError(t, err) require.EqualValues(t, []ID{s1.ID, s2.ID, s3.ID}, sIDs) // To ensure that different groups don't interfere with each other, // let's add another set of linked sessions not linked to the first. s4 := createSession(t, db, "session 4") - require.NoError(t, db.ShiftState(s4.ID, StateRevoked)) + require.NoError(t, db.ShiftState(ctx, s4.ID, StateRevoked)) s5 := createSession(t, db, "session 5", withLinkedGroupID(&s4.GroupID)) require.NotEqual(t, s4.GroupID, s1.GroupID) // Assert that the session ID to group ID index works as expected. for _, s := range []*Session{s4, s5} { - groupID, err := db.GetGroupID(s.ID) + groupID, err := db.GetGroupID(ctx, s.ID) require.NoError(t, err) require.Equal(t, s4.ID, groupID) require.Equal(t, s.GroupID, groupID) } // Assert that the group ID to session ID index works as expected. - sIDs, err = db.GetSessionIDs(s5.GroupID) + sIDs, err = db.GetSessionIDs(ctx, s5.GroupID) require.NoError(t, err) require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs) } // TestStateShift tests that the ShiftState method works as expected. func TestStateShift(t *testing.T) { + t.Parallel() + ctx := context.Background() + // Set up a new DB. clock := clock.NewTestClock(testTime) db := NewTestDB(t, clock) @@ -293,18 +304,18 @@ func TestStateShift(t *testing.T) { // Check that the session is in the StateCreated state. Also check that // the "RevokedAt" time has not yet been set. - s1, err := db.GetSession(s1.LocalPublicKey) + s1, err := db.GetSession(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, StateCreated, s1.State) require.Equal(t, time.Time{}, s1.RevokedAt) // Shift the state of the session to StateRevoked. - err = db.ShiftState(s1.ID, StateRevoked) + err = db.ShiftState(ctx, s1.ID, StateRevoked) require.NoError(t, err) // This should have worked. Since it is now in a terminal state, the // "RevokedAt" time should be set. - s1, err = db.GetSession(s1.LocalPublicKey) + s1, err = db.GetSession(ctx, s1.LocalPublicKey) require.NoError(t, err) require.Equal(t, StateRevoked, s1.State) require.True(t, clock.Now().Equal(s1.RevokedAt)) @@ -314,13 +325,13 @@ func TestStateShift(t *testing.T) { // should not have changed though. prevTime := clock.Now() clock.SetTime(prevTime.Add(time.Second)) - err = db.ShiftState(s1.ID, StateRevoked) + err = db.ShiftState(ctx, s1.ID, StateRevoked) require.NoError(t, err) require.True(t, prevTime.Equal(s1.RevokedAt)) // Trying to shift the state from a terminal state back to StateCreated // should also fail since this is not a legal state transition. - err = db.ShiftState(s1.ID, StateCreated) + err = db.ShiftState(ctx, s1.ID, StateCreated) require.ErrorContains(t, err, "illegal session state transition") } @@ -360,7 +371,8 @@ func reserveSession(db Store, label string, mod(opts) } - return db.NewSession(label, opts.sessType, + return db.NewSession( + context.Background(), label, opts.sessType, time.Date(99999, 1, 1, 0, 0, 0, 0, time.UTC), "foo.bar.baz:1234", WithDevServer(), @@ -375,10 +387,10 @@ func createSession(t *testing.T, db Store, label string, s, err := reserveSession(db, label, mods...) require.NoError(t, err) - err = db.ShiftState(s.ID, StateCreated) + err = db.ShiftState(context.Background(), s.ID, StateCreated) require.NoError(t, err) - s, err = db.GetSessionByID(s.ID) + s, err = db.GetSessionByID(context.Background(), s.ID) require.NoError(t, err) return s diff --git a/session_rpcserver.go b/session_rpcserver.go index 46d8d162b..7362d8c7f 100644 --- a/session_rpcserver.go +++ b/session_rpcserver.go @@ -96,15 +96,14 @@ func newSessionRPCServer(cfg *sessionRpcServerConfig) (*sessionRpcServer, // requests. This includes resuming all non-revoked sessions. func (s *sessionRpcServer) start(ctx context.Context) error { // Delete all sessions in the Reserved state. - err := s.cfg.db.DeleteReservedSessions() + err := s.cfg.db.DeleteReservedSessions(ctx) if err != nil { return fmt.Errorf("error deleting reserved sessions: %v", err) } // Start up all previously created sessions. sessions, err := s.cfg.db.ListSessionsByState( - session.StateCreated, - session.StateInUse, + ctx, session.StateCreated, session.StateInUse, ) if err != nil { return fmt.Errorf("error listing sessions: %v", err) @@ -150,7 +149,7 @@ func (s *sessionRpcServer) start(ctx context.Context) error { if perm { err := s.cfg.db.ShiftState( - sess.ID, session.StateRevoked, + ctx, sess.ID, session.StateRevoked, ) if err != nil { log.Errorf("error revoking "+ @@ -317,13 +316,14 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, } sess, err := s.cfg.db.NewSession( - req.Label, typ, expiry, req.MailboxServerAddr, sessOpts..., + ctx, req.Label, typ, expiry, req.MailboxServerAddr, + sessOpts..., ) if err != nil { return nil, fmt.Errorf("error creating new session: %v", err) } - err = s.cfg.db.ShiftState(sess.ID, session.StateCreated) + err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateCreated) if err != nil { return nil, fmt.Errorf("error shifting session state to "+ "Created: %v", err) @@ -335,7 +335,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context, // Re-fetch the session to get the latest state of it before marshaling // it. - sess, err = s.cfg.db.GetSessionByID(sess.ID) + sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID) if err != nil { return nil, fmt.Errorf("error fetching session: %v", err) } @@ -362,7 +362,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context, log.Debugf("Not resuming session %x with expiry %s", pubKeyBytes, sess.Expiry) - err := s.cfg.db.ShiftState(sess.ID, session.StateExpired) + err := s.cfg.db.ShiftState(ctx, sess.ID, session.StateExpired) if err != nil { return fmt.Errorf("error revoking session: %v", err) } @@ -440,7 +440,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context, "passed. Revoking session", pubKeyBytes) return s.cfg.db.ShiftState( - sess.ID, session.StateRevoked, + ctx, sess.ID, session.StateRevoked, ) } @@ -520,7 +520,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context, log.Debugf("Error stopping session: %v", err) } - err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked) + err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateRevoked) if err != nil { log.Debugf("error revoking session: %v", err) } @@ -530,10 +530,10 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context, } // ListSessions returns all sessions known to the session store. -func (s *sessionRpcServer) ListSessions(_ context.Context, +func (s *sessionRpcServer) ListSessions(ctx context.Context, _ *litrpc.ListSessionsRequest) (*litrpc.ListSessionsResponse, error) { - sessions, err := s.cfg.db.ListAllSessions() + sessions, err := s.cfg.db.ListAllSessions(ctx) if err != nil { return nil, fmt.Errorf("error fetching sessions: %v", err) } @@ -562,12 +562,12 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context, return nil, fmt.Errorf("error parsing public key: %v", err) } - sess, err := s.cfg.db.GetSession(pubKey) + sess, err := s.cfg.db.GetSession(ctx, pubKey) if err != nil { return nil, fmt.Errorf("error fetching session: %v", err) } - err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked) + err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateRevoked) if err != nil { return nil, fmt.Errorf("error revoking session: %v", err) } @@ -587,7 +587,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context, // PrivacyMapConversion can be used map real values to their pseudo counterpart // and vice versa. -func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context, +func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context, req *litrpc.PrivacyMapConversionRequest) ( *litrpc.PrivacyMapConversionResponse, error) { @@ -606,7 +606,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context, return nil, err } - groupID, err = s.cfg.db.GetGroupID(sessionID) + groupID, err = s.cfg.db.GetGroupID(ctx, sessionID) if err != nil { return nil, err } @@ -643,7 +643,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context, // stored if the actions are interceptor actions, otherwise only the URI and // timestamp of the actions will be stored. The "full" mode will persist all // request data for all actions. -func (s *sessionRpcServer) ListActions(_ context.Context, +func (s *sessionRpcServer) ListActions(ctx context.Context, req *litrpc.ListActionsRequest) (*litrpc.ListActionsResponse, error) { // If no maximum number of actions is given, use a default of 100. @@ -746,7 +746,7 @@ func (s *sessionRpcServer) ListActions(_ context.Context, return nil, err } - actions, err = db.ListGroupActions(groupID, filterFn) + actions, err = db.ListGroupActions(ctx, groupID, filterFn) if err != nil { return nil, err } @@ -867,7 +867,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, copy(groupID[:], req.LinkedGroupId) // Check that the group actually does exist. - groupSess, err := s.cfg.db.GetSessionByID(groupID) + groupSess, err := s.cfg.db.GetSessionByID(ctx, groupID) if err != nil { return nil, err } @@ -1148,8 +1148,8 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, } sess, err := s.cfg.db.NewSession( - req.Label, session.TypeAutopilot, expiry, req.MailboxServerAddr, - sessOpts..., + ctx, req.Label, session.TypeAutopilot, expiry, + req.MailboxServerAddr, sessOpts..., ) if err != nil { return nil, fmt.Errorf("error creating new session: %v", err) @@ -1230,7 +1230,9 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, "autopilot server: %v", err) } - err = s.cfg.db.UpdateSessionRemotePubKey(sess.LocalPublicKey, remoteKey) + err = s.cfg.db.UpdateSessionRemotePubKey( + ctx, sess.LocalPublicKey, remoteKey, + ) if err != nil { return nil, fmt.Errorf("error setting remote pubkey: %v", err) } @@ -1240,7 +1242,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, // We only activate the session if the Autopilot server registration // was successful. - err = s.cfg.db.ShiftState(sess.ID, session.StateCreated) + err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateCreated) if err != nil { return nil, fmt.Errorf("error shifting session state to "+ "Created: %v", err) @@ -1252,7 +1254,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, // Re-fetch the session to get the latest state of it before marshaling // it. - sess, err = s.cfg.db.GetSessionByID(sess.ID) + sess, err = s.cfg.db.GetSessionByID(ctx, sess.ID) if err != nil { return nil, fmt.Errorf("error fetching session: %v", err) } @@ -1269,11 +1271,11 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context, // ListAutopilotSessions fetches and returns all the sessions from the DB that // are of type TypeAutopilot. -func (s *sessionRpcServer) ListAutopilotSessions(_ context.Context, +func (s *sessionRpcServer) ListAutopilotSessions(ctx context.Context, _ *litrpc.ListAutopilotSessionsRequest) ( *litrpc.ListAutopilotSessionsResponse, error) { - sessions, err := s.cfg.db.ListSessionsByType(session.TypeAutopilot) + sessions, err := s.cfg.db.ListSessionsByType(ctx, session.TypeAutopilot) if err != nil { return nil, fmt.Errorf("error fetching sessions: %v", err) } @@ -1302,7 +1304,7 @@ func (s *sessionRpcServer) RevokeAutopilotSession(ctx context.Context, return nil, fmt.Errorf("error parsing public key: %v", err) } - sess, err := s.cfg.db.GetSession(pubKey) + sess, err := s.cfg.db.GetSession(ctx, pubKey) if err != nil { return nil, err }