Skip to content

Commit c468beb

Browse files
committed
session: pass contexts through to all IDToGroupIndex methods
1 parent 0bffa6b commit c468beb

File tree

6 files changed

+25
-18
lines changed

6 files changed

+25
-18
lines changed

firewalldb/actions.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ func (db *DB) ListSessionActions(sessionID session.ID,
391391
// pass the filterFn requirements.
392392
//
393393
// TODO: update to allow for pagination.
394-
func (db *DB) ListGroupActions(_ context.Context, groupID session.ID,
394+
func (db *DB) ListGroupActions(ctx context.Context, groupID session.ID,
395395
filterFn ListActionsFilterFn) ([]*Action, error) {
396396

397397
if filterFn == nil {
@@ -400,7 +400,7 @@ func (db *DB) ListGroupActions(_ context.Context, groupID session.ID,
400400
}
401401
}
402402

403-
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(groupID)
403+
sessionIDs, err := db.sessionIDIndex.GetSessionIDs(ctx, groupID)
404404
if err != nil {
405405
return nil, err
406406
}

firewalldb/mock.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ func (m *mockSessionDB) AddPair(sessionID, groupID session.ID) {
3434
}
3535

3636
// GetGroupID returns the group ID for the given session ID.
37-
func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
37+
func (m *mockSessionDB) GetGroupID(_ context.Context, sessionID session.ID) (
38+
session.ID, error) {
39+
3840
id, ok := m.sessionToGroupID[sessionID]
3941
if !ok {
4042
return session.ID{}, fmt.Errorf("no group ID found for " +
@@ -45,7 +47,9 @@ func (m *mockSessionDB) GetGroupID(sessionID session.ID) (session.ID, error) {
4547
}
4648

4749
// GetSessionIDs returns the set of session IDs that are in the group
48-
func (m *mockSessionDB) GetSessionIDs(groupID session.ID) ([]session.ID, error) {
50+
func (m *mockSessionDB) GetSessionIDs(_ context.Context, groupID session.ID) (
51+
[]session.ID, error) {
52+
4953
ids, ok := m.groupToSessionIDs[groupID]
5054
if !ok {
5155
return nil, fmt.Errorf("no session IDs found for group ID")

session/interface.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,11 @@ func buildSession(id ID, localPrivKey *btcec.PrivateKey, label string, typ Type,
178178
// IDToGroupIndex defines an interface for the session ID to group ID index.
179179
type IDToGroupIndex interface {
180180
// GetGroupID will return the group ID for the given session ID.
181-
GetGroupID(sessionID ID) (ID, error)
181+
GetGroupID(ctx context.Context, sessionID ID) (ID, error)
182182

183183
// GetSessionIDs will return the set of session IDs that are in the
184184
// group with the given ID.
185-
GetSessionIDs(groupID ID) ([]ID, error)
185+
GetSessionIDs(ctx context.Context, groupID ID) ([]ID, error)
186186
}
187187

188188
// Store is the interface a persistent storage must implement for storing and

session/kvdb_store.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ func getUnusedIDAndKeyPair(bucket *bbolt.Bucket) (ID, *btcec.PrivateKey,
625625
// GetGroupID will return the group ID for the given session ID.
626626
//
627627
// NOTE: this is part of the IDToGroupIndex interface.
628-
func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
628+
func (db *BoltStore) GetGroupID(_ context.Context, sessionID ID) (ID, error) {
629629
var groupID ID
630630
err := db.View(func(tx *bbolt.Tx) error {
631631
sessionBkt, err := getBucket(tx, sessionBucketKey)
@@ -665,7 +665,9 @@ func (db *BoltStore) GetGroupID(sessionID ID) (ID, error) {
665665
// group with the given ID.
666666
//
667667
// NOTE: this is part of the IDToGroupIndex interface.
668-
func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
668+
func (db *BoltStore) GetSessionIDs(_ context.Context, groupID ID) ([]ID,
669+
error) {
670+
669671
var (
670672
sessionIDs []ID
671673
err error

session/store_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ func TestBasicSessionStore(t *testing.T) {
170170

171171
// Show that the group ID/session ID index has also been populated with
172172
// this session.
173-
groupID, err := db.GetGroupID(s4.ID)
173+
groupID, err := db.GetGroupID(ctx, s4.ID)
174174
require.NoError(t, err)
175175
require.Equal(t, s1.ID, groupID)
176176

177-
sessIDs, err := db.GetSessionIDs(s4.GroupID)
177+
sessIDs, err := db.GetSessionIDs(ctx, s4.GroupID)
178178
require.NoError(t, err)
179179
require.ElementsMatch(t, []ID{s4.ID, s1.ID}, sessIDs)
180180

@@ -186,11 +186,11 @@ func TestBasicSessionStore(t *testing.T) {
186186
require.NoError(t, err)
187187
require.Empty(t, sessions)
188188

189-
_, err = db.GetGroupID(s4.ID)
189+
_, err = db.GetGroupID(ctx, s4.ID)
190190
require.ErrorContains(t, err, "no index entry")
191191

192192
// Only session 1 should remain in this group.
193-
sessIDs, err = db.GetSessionIDs(s4.GroupID)
193+
sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID)
194194
require.NoError(t, err)
195195
require.ElementsMatch(t, []ID{s1.ID}, sessIDs)
196196
}
@@ -239,6 +239,7 @@ func TestLinkingSessions(t *testing.T) {
239239
// of the GetGroupID and GetSessionIDs methods.
240240
func TestLinkedSessions(t *testing.T) {
241241
t.Parallel()
242+
ctx := context.Background()
242243

243244
// Set up a new DB.
244245
clock := clock.NewTestClock(testTime)
@@ -262,14 +263,14 @@ func TestLinkedSessions(t *testing.T) {
262263

263264
// Assert that the session ID to group ID index works as expected.
264265
for _, s := range []*Session{s1, s2, s3} {
265-
groupID, err := db.GetGroupID(s.ID)
266+
groupID, err := db.GetGroupID(ctx, s.ID)
266267
require.NoError(t, err)
267268
require.Equal(t, s1.ID, groupID)
268269
require.Equal(t, s.GroupID, groupID)
269270
}
270271

271272
// Assert that the group ID to session ID index works as expected.
272-
sIDs, err := db.GetSessionIDs(s1.GroupID)
273+
sIDs, err := db.GetSessionIDs(ctx, s1.GroupID)
273274
require.NoError(t, err)
274275
require.EqualValues(t, []ID{s1.ID, s2.ID, s3.ID}, sIDs)
275276

@@ -282,14 +283,14 @@ func TestLinkedSessions(t *testing.T) {
282283

283284
// Assert that the session ID to group ID index works as expected.
284285
for _, s := range []*Session{s4, s5} {
285-
groupID, err := db.GetGroupID(s.ID)
286+
groupID, err := db.GetGroupID(ctx, s.ID)
286287
require.NoError(t, err)
287288
require.Equal(t, s4.ID, groupID)
288289
require.Equal(t, s.GroupID, groupID)
289290
}
290291

291292
// Assert that the group ID to session ID index works as expected.
292-
sIDs, err = db.GetSessionIDs(s5.GroupID)
293+
sIDs, err = db.GetSessionIDs(ctx, s5.GroupID)
293294
require.NoError(t, err)
294295
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
295296
}

session_rpcserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
580580

581581
// PrivacyMapConversion can be used map real values to their pseudo counterpart
582582
// and vice versa.
583-
func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context,
583+
func (s *sessionRpcServer) PrivacyMapConversion(ctx context.Context,
584584
req *litrpc.PrivacyMapConversionRequest) (
585585
*litrpc.PrivacyMapConversionResponse, error) {
586586

@@ -599,7 +599,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context,
599599
return nil, err
600600
}
601601

602-
groupID, err = s.cfg.db.GetGroupID(sessionID)
602+
groupID, err = s.cfg.db.GetGroupID(ctx, sessionID)
603603
if err != nil {
604604
return nil, err
605605
}

0 commit comments

Comments
 (0)