Skip to content

Commit 0bffa6b

Browse files
committed
firewalldb: let ListGroupActions take a context
In preparation for it needing to pass one to GetSessionIDs.
1 parent 5fc13f9 commit 0bffa6b

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

firewalldb/actions.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ func (db *DB) ListSessionActions(sessionID session.ID,
391391
// pass the filterFn requirements.
392392
//
393393
// TODO: update to allow for pagination.
394-
func (db *DB) ListGroupActions(groupID session.ID,
394+
func (db *DB) ListGroupActions(_ context.Context, groupID session.ID,
395395
filterFn ListActionsFilterFn) ([]*Action, error) {
396396

397397
if filterFn == nil {
@@ -629,11 +629,11 @@ type groupActionsReadDB struct {
629629
var _ ActionsDB = (*groupActionsReadDB)(nil)
630630

631631
// ListActions will return all the Actions for a particular group.
632-
func (s *groupActionsReadDB) ListActions(_ context.Context) ([]*RuleAction,
632+
func (s *groupActionsReadDB) ListActions(ctx context.Context) ([]*RuleAction,
633633
error) {
634634

635635
sessionActions, err := s.db.ListGroupActions(
636-
s.groupID, func(a *Action, _ bool) (bool, bool) {
636+
ctx, s.groupID, func(a *Action, _ bool) (bool, bool) {
637637
return a.State == ActionStateDone, true
638638
},
639639
)
@@ -660,11 +660,11 @@ var _ ActionsDB = (*groupFeatureActionsReadDB)(nil)
660660

661661
// ListActions will return all the Actions for a particular group that were
662662
// executed by a particular feature.
663-
func (a *groupFeatureActionsReadDB) ListActions(_ context.Context) (
663+
func (a *groupFeatureActionsReadDB) ListActions(ctx context.Context) (
664664
[]*RuleAction, error) {
665665

666666
featureActions, err := a.db.ListGroupActions(
667-
a.groupID, func(action *Action, _ bool) (bool, bool) {
667+
ctx, a.groupID, func(action *Action, _ bool) (bool, bool) {
668668
return action.State == ActionStateDone &&
669669
action.FeatureName == a.featureName, true
670670
},

firewalldb/actions_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package firewalldb
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67
"time"
@@ -342,6 +343,9 @@ func TestListActions(t *testing.T) {
342343
// TestListGroupActions tests that the ListGroupActions correctly returns all
343344
// actions in a particular session group.
344345
func TestListGroupActions(t *testing.T) {
346+
t.Parallel()
347+
ctx := context.Background()
348+
345349
group1 := intToSessionID(0)
346350

347351
// Link session 1 and session 2 to group 1.
@@ -356,7 +360,7 @@ func TestListGroupActions(t *testing.T) {
356360
})
357361

358362
// There should not be any actions in group 1 yet.
359-
al, err := db.ListGroupActions(group1, nil)
363+
al, err := db.ListGroupActions(ctx, group1, nil)
360364
require.NoError(t, err)
361365
require.Empty(t, al)
362366

@@ -365,7 +369,7 @@ func TestListGroupActions(t *testing.T) {
365369
require.NoError(t, err)
366370

367371
// There should now be one action in the group.
368-
al, err = db.ListGroupActions(group1, nil)
372+
al, err = db.ListGroupActions(ctx, group1, nil)
369373
require.NoError(t, err)
370374
require.Len(t, al, 1)
371375
require.Equal(t, sessionID1, al[0].SessionID)
@@ -375,7 +379,7 @@ func TestListGroupActions(t *testing.T) {
375379
require.NoError(t, err)
376380

377381
// There should now be actions in the group.
378-
al, err = db.ListGroupActions(group1, nil)
382+
al, err = db.ListGroupActions(ctx, group1, nil)
379383
require.NoError(t, err)
380384
require.Len(t, al, 2)
381385
require.Equal(t, sessionID1, al[0].SessionID)

session_rpcserver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ func (s *sessionRpcServer) PrivacyMapConversion(_ context.Context,
636636
// stored if the actions are interceptor actions, otherwise only the URI and
637637
// timestamp of the actions will be stored. The "full" mode will persist all
638638
// request data for all actions.
639-
func (s *sessionRpcServer) ListActions(_ context.Context,
639+
func (s *sessionRpcServer) ListActions(ctx context.Context,
640640
req *litrpc.ListActionsRequest) (*litrpc.ListActionsResponse, error) {
641641

642642
// If no maximum number of actions is given, use a default of 100.
@@ -739,7 +739,7 @@ func (s *sessionRpcServer) ListActions(_ context.Context,
739739
return nil, err
740740
}
741741

742-
actions, err = db.ListGroupActions(groupID, filterFn)
742+
actions, err = db.ListGroupActions(ctx, groupID, filterFn)
743743
if err != nil {
744744
return nil, err
745745
}

0 commit comments

Comments
 (0)