Skip to content

[sql-16] sessions: update Store interface methods to take a context #986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 28, 2025
4 changes: 2 additions & 2 deletions firewall/privacy_mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
uri string, req proto.Message, sessionID session.ID) (proto.Message,
error) {

session, err := p.sessionDB.GetSessionByID(sessionID)
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (p *PrivacyMapper) checkAndReplaceIncomingRequest(ctx context.Context,
func (p *PrivacyMapper) replaceOutgoingResponse(ctx context.Context, uri string,
resp proto.Message, sessionID session.ID) (proto.Message, error) {

session, err := p.sessionDB.GetSessionByID(sessionID)
session, err := p.sessionDB.GetSessionByID(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion firewall/rule_enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string,
return nil, err
}

session, err := r.sessionDB.GetSessionByID(sessionID)
session, err := r.sessionDB.GetSessionByID(ctx, sessionID)
if err != nil {
return nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions firewalldb/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ func (db *DB) ListSessionActions(sessionID session.ID,
// pass the filterFn requirements.
//
// TODO: update to allow for pagination.
func (db *DB) ListGroupActions(groupID session.ID,
func (db *DB) ListGroupActions(ctx context.Context, groupID session.ID,
filterFn ListActionsFilterFn) ([]*Action, error) {

if filterFn == nil {
Expand All @@ -400,7 +400,7 @@ func (db *DB) ListGroupActions(groupID session.ID,
}
}

sessionIDs, err := db.sessionIDIndex.GetSessionIDs(groupID)
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -629,11 +629,11 @@ type groupActionsReadDB struct {
var _ ActionsDB = (*groupActionsReadDB)(nil)

// ListActions will return all the Actions for a particular group.
func (s *groupActionsReadDB) ListActions(_ context.Context) ([]*RuleAction,
func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction,
error) {

sessionActions, err := s.db.ListGroupActions(
s.groupID, func(a *Action, _ bool) (bool, bool) {
ctx, s.groupID, func(a *Action, _ bool) (bool, bool) {
return a.State == ActionStateDone, true
},
)
Expand All @@ -660,11 +660,11 @@ var _ ActionsDB = (*groupFeatureActionsReadDB)(nil)

// ListActions will return all the Actions for a particular group that were
// executed by a particular feature.
func (a *groupFeatureActionsReadDB) ListActions(_ context.Context) (
func (a *groupFeatureActionsReadDB) ListActions(ctx context.Context) (
[]*RuleAction, error) {

featureActions, err := a.db.ListGroupActions(
a.groupID, func(action *Action, _ bool) (bool, bool) {
ctx, a.groupID, func(action *Action, _ bool) (bool, bool) {
return action.State == ActionStateDone &&
action.FeatureName == a.featureName, true
},
Expand Down
10 changes: 7 additions & 3 deletions firewalldb/actions_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package firewalldb

import (
"context"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -342,6 +343,9 @@ func TestListActions(t *testing.T) {
// TestListGroupActions tests that the ListGroupActions correctly returns all
// actions in a particular session group.
func TestListGroupActions(t *testing.T) {
t.Parallel()
ctx := context.Background()

group1 := intToSessionID(0)

// Link session 1 and session 2 to group 1.
Expand All @@ -356,7 +360,7 @@ func TestListGroupActions(t *testing.T) {
})

// There should not be any actions in group 1 yet.
al, err := db.ListGroupActions(group1, nil)
al, err := db.ListGroupActions(ctx, group1, nil)
require.NoError(t, err)
require.Empty(t, al)

Expand All @@ -365,7 +369,7 @@ func TestListGroupActions(t *testing.T) {
require.NoError(t, err)

// There should now be one action in the group.
al, err = db.ListGroupActions(group1, nil)
al, err = db.ListGroupActions(ctx, group1, nil)
require.NoError(t, err)
require.Len(t, al, 1)
require.Equal(t, sessionID1, al[0].SessionID)
Expand All @@ -375,7 +379,7 @@ func TestListGroupActions(t *testing.T) {
require.NoError(t, err)

// There should now be actions in the group.
al, err = db.ListGroupActions(group1, nil)
al, err = db.ListGroupActions(ctx, group1, nil)
require.NoError(t, err)
require.Len(t, al, 2)
require.Equal(t, sessionID1, al[0].SessionID)
Expand Down
8 changes: 6 additions & 2 deletions firewalldb/interface.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package firewalldb

import "github.com/lightninglabs/lightning-terminal/session"
import (
"context"

"github.com/lightninglabs/lightning-terminal/session"
)

// SessionDB is an interface that abstracts the database operations needed for
// the privacy mapper to function.
type SessionDB interface {
session.IDToGroupIndex

// GetSessionByID returns the session for a specific id.
GetSessionByID(session.ID) (*session.Session, error)
GetSessionByID(context.Context, session.ID) (*session.Session, error)
}
13 changes: 9 additions & 4 deletions firewalldb/mock.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package firewalldb

import (
"context"
"fmt"

"github.com/lightninglabs/lightning-terminal/session"
Expand Down Expand Up @@ -33,7 +34,9 @@ func (m *mockSessionDB) AddPair(sessionID, groupID session.ID) {
}

// GetGroupID returns the group ID for the given session ID.
func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
func (m *mockSessionDB) GetGroupID(_ context.Context, sessionID session.ID) (
session.ID, error) {

id, ok := m.sessionToGroupID[sessionID]
if !ok {
return session.ID{}, fmt.Errorf("no group ID found for " +
Expand All @@ -44,7 +47,9 @@ func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
}

// GetSessionIDs returns the set of session IDs that are in the group
func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error) {
func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) (
[]session.ID, error) {

ids, ok := m.groupToSessionIDs[groupID]
if !ok {
return nil, fmt.Errorf("no session IDs found for group ID")
Expand All @@ -54,8 +59,8 @@ func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error)
}

// GetSessionByID returns the session for a specific id.
func (m *mockSessionDB) GetSessionByID(sessionID session.ID) (*session.Session,
error) {
func (m *mockSessionDB) GetSessionByID(_ context.Context,
sessionID session.ID) (*session.Session, error) {

s, ok := m.sessionToGroupID[sessionID]
if !ok {
Expand Down
27 changes: 15 additions & 12 deletions session/interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package session

import (
"context"
"fmt"
"time"

Expand Down Expand Up @@ -260,11 +261,11 @@ func WithMacaroonRecipe(caveats []macaroon.Caveat, perms []bakery.Op) Option {
// IDToGroupIndex defines an interface for the session ID to group ID index.
type IDToGroupIndex interface {
// GetGroupID will return the group ID for the given session ID.
GetGroupID(sessionID ID) (ID, error)
GetGroupID(ctx context.Context, sessionID ID) (ID, error)

// GetSessionIDs will return the set of session IDs that are in the
// group with the given ID.
GetSessionIDs(groupID ID) ([]ID, error)
GetSessionIDs(ctx context.Context, groupID ID) ([]ID, error)
}

// Store is the interface a persistent storage must implement for storing and
Expand All @@ -273,37 +274,39 @@ type Store interface {
// NewSession creates a new session with the given user-defined
// parameters. The session will remain in the StateReserved state until
// ShiftState is called to update the state.
NewSession(label string, typ Type, expiry time.Time, serverAddr string,
opts ...Option) (*Session, error)
NewSession(ctx context.Context, label string, typ Type,
expiry time.Time, serverAddr string, opts ...Option) (*Session,
error)

// GetSession fetches the session with the given key.
GetSession(key *btcec.PublicKey) (*Session, error)
GetSession(ctx context.Context, key *btcec.PublicKey) (*Session, error)

// ListAllSessions returns all sessions currently known to the store.
ListAllSessions() ([]*Session, error)
ListAllSessions(ctx context.Context) ([]*Session, error)

// ListSessionsByType returns all sessions of the given type.
ListSessionsByType(t Type) ([]*Session, error)
ListSessionsByType(ctx context.Context, t Type) ([]*Session, error)

// ListSessionsByState returns all sessions currently known to the store
// that are in the given states.
ListSessionsByState(...State) ([]*Session, error)
ListSessionsByState(ctx context.Context, state ...State) ([]*Session,
error)

// UpdateSessionRemotePubKey can be used to add the given remote pub key
// to the session with the given local pub key.
UpdateSessionRemotePubKey(localPubKey,
UpdateSessionRemotePubKey(ctx context.Context, localPubKey,
remotePubKey *btcec.PublicKey) error

// GetSessionByID fetches the session with the given ID.
GetSessionByID(id ID) (*Session, error)
GetSessionByID(ctx context.Context, id ID) (*Session, error)

// DeleteReservedSessions deletes all sessions that are in the
// StateReserved state.
DeleteReservedSessions() error
DeleteReservedSessions(ctx context.Context) error

// ShiftState updates the state of the session with the given ID to the
// "dest" state.
ShiftState(id ID, dest State) error
ShiftState(ctx context.Context, id ID, dest State) error

IDToGroupIndex
}
35 changes: 23 additions & 12 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package session

import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
Expand Down Expand Up @@ -185,8 +186,8 @@ func getSessionKey(session *Session) []byte {
// ShiftState is called with StateCreated.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
serverAddr string, opts ...Option) (*Session, error) {
func (db *BoltStore) NewSession(ctx context.Context, label string, typ Type,
expiry time.Time, serverAddr string, opts ...Option) (*Session, error) {

var session *Session
err := db.Update(func(tx *bbolt.Tx) error {
Expand Down Expand Up @@ -285,7 +286,7 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
// to the session with the given local pub key.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey,
func (db *BoltStore) UpdateSessionRemotePubKey(_ context.Context, localPubKey,
remotePubKey *btcec.PublicKey) error {

key := localPubKey.SerializeCompressed()
Expand Down Expand Up @@ -318,7 +319,9 @@ func (db *BoltStore) UpdateSessionRemotePubKey(localPubKey,
// GetSession fetches the session with the given key.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
func (db *BoltStore) GetSession(_ context.Context, key *btcec.PublicKey) (
*Session, error) {

var session *Session
err := db.View(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -348,7 +351,7 @@ func (db *BoltStore) GetSession(key *btcec.PublicKey) (*Session, error) {
// ListAllSessions returns all sessions currently known to the store.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListAllSessions() ([]*Session, error) {
func (db *BoltStore) ListAllSessions(_ context.Context) ([]*Session, error) {
return db.listSessions(func(s *Session) bool {
return true
})
Expand All @@ -358,7 +361,9 @@ func (db *BoltStore) ListAllSessions() ([]*Session, error) {
// have the given type.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
func (db *BoltStore) ListSessionsByType(_ context.Context, t Type) ([]*Session,
error) {

return db.listSessions(func(s *Session) bool {
return s.Type == t
})
Expand All @@ -368,7 +373,9 @@ func (db *BoltStore) ListSessionsByType(t Type) ([]*Session, error) {
// are in the given states.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ListSessionsByState(states ...State) ([]*Session, error) {
func (db *BoltStore) ListSessionsByState(_ context.Context, states ...State) (
[]*Session, error) {

return db.listSessions(func(s *Session) bool {
for _, state := range states {
if s.State == state {
Expand Down Expand Up @@ -429,7 +436,7 @@ func (db *BoltStore) listSessions(filterFn func(s *Session) bool) ([]*Session,
// state.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) DeleteReservedSessions() error {
func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
return db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
Expand Down Expand Up @@ -522,7 +529,7 @@ func (db *BoltStore) DeleteReservedSessions() error {
// state.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) ShiftState(id ID, dest State) error {
func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error {
return db.Update(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
if err != nil {
Expand Down Expand Up @@ -562,7 +569,9 @@ func (db *BoltStore) ShiftState(id ID, dest State) error {
// GetSessionByID fetches the session with the given ID.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) GetSessionByID(id ID) (*Session, error) {
func (db *BoltStore) GetSessionByID(_ context.Context, id ID) (*Session,
error) {

var session *Session
err := db.View(func(tx *bbolt.Tx) error {
sessionBucket, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -615,7 +624,7 @@ func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey,
// GetGroupID will return the group ID for the given session ID.
//
// NOTE: this is part of the IDToGroupIndex interface.
func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
func (db *BoltStore) GetGroupID(_ context.Context, sessionID ID) (ID, error) {
var groupID ID
err := db.View(func(tx *bbolt.Tx) error {
sessionBkt, err := getBucket(tx, sessionBucketKey)
Expand Down Expand Up @@ -655,7 +664,9 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
// group with the given ID.
//
// NOTE: this is part of the IDToGroupIndex interface.
func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
func (db *BoltStore) GetSessionIDs(_ context.Context, groupID ID) ([]ID,
error) {

var (
sessionIDs []ID
err error
Expand Down
Loading
Loading