Skip to content

Commit

Permalink
session: remove Session Group Predicate method
Browse files Browse the repository at this point in the history
This was used to check that all linked sessions are no longer
active before attempting to register an autopilot session. But this is
no longer needed since this is done within NewSession.
  • Loading branch information
ellemouton committed Feb 12, 2025
1 parent d65d04c commit 001da85
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 175 deletions.
6 changes: 0 additions & 6 deletions session/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,6 @@ type Store interface {
// GetSessionByID fetches the session with the given ID.
GetSessionByID(id ID) (*Session, error)

// CheckSessionGroupPredicate iterates over all the sessions in a group
// and checks if each one passes the given predicate function. True is
// returned if each session passes.
CheckSessionGroupPredicate(groupID ID,
fn func(s *Session) bool) (bool, error)

// DeleteReservedSessions deletes all sessions that are in the
// StateReserved state.
DeleteReservedSessions() error
Expand Down
62 changes: 2 additions & 60 deletions session/kvdb_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,

// Ensure that the session is no longer active.
if sess.State == StateCreated ||
sess.State == StateInUse {
sess.State == StateInUse ||
sess.State == StateReserved {

return fmt.Errorf("session (id=%x) "+
"in group %x is still active",
Expand Down Expand Up @@ -692,65 +693,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
return sessionIDs, nil
}

// CheckSessionGroupPredicate iterates over all the sessions in a group and
// checks if each one passes the given predicate function. True is returned if
// each session passes.
//
// NOTE: this is part of the Store interface.
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
fn func(s *Session) bool) (bool, error) {

var (
pass bool
errFailedPred = errors.New("session failed predicate")
)
err := db.View(func(tx *bbolt.Tx) error {
sessionBkt, err := getBucket(tx, sessionBucketKey)
if err != nil {
return err
}

sessionIDs, err := getSessionIDs(sessionBkt, groupID)
if err != nil {
return err
}

// Iterate over all the sessions.
for _, id := range sessionIDs {
key, err := getKeyForID(sessionBkt, id)
if err != nil {
return err
}

v := sessionBkt.Get(key)
if len(v) == 0 {
return ErrSessionNotFound
}

session, err := DeserializeSession(bytes.NewReader(v))
if err != nil {
return err
}

if !fn(session) {
return errFailedPred
}
}

pass = true

return nil
})
if errors.Is(err, errFailedPred) {
return pass, nil
}
if err != nil {
return pass, err
}

return pass, nil
}

// getSessionIDs returns all the session IDs associated with the given group ID.
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
var sessionIDs []ID
Expand Down
92 changes: 0 additions & 92 deletions session/store_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package session

import (
"strings"
"testing"
"time"

Expand Down Expand Up @@ -283,97 +282,6 @@ func TestLinkedSessions(t *testing.T) {
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
}

// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
// method correctly checks if each session in a group passes a predicate.
func TestCheckSessionGroupPredicate(t *testing.T) {
t.Parallel()

// Set up a new DB.
clock := clock.NewTestClock(testTime)
db, err := NewDB(t.TempDir(), "test.db", clock)
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})

// We will use the Label of the Session to test that the predicate
// function is checked correctly.

// Add a new session to the DB.
s1 := createSession(t, db, "label 1")

// Check that the group passes against an appropriate predicate.
ok, err := db.CheckSessionGroupPredicate(
s1.GroupID, func(s *Session) bool {
return strings.Contains(s.Label, "label 1")
},
)
require.NoError(t, err)
require.True(t, ok)

// Check that the group fails against an appropriate predicate.
ok, err = db.CheckSessionGroupPredicate(
s1.GroupID, func(s *Session) bool {
return strings.Contains(s.Label, "label 2")
},
)
require.NoError(t, err)
require.False(t, ok)

// Revoke the first session.
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))

// Add a new session to the same group as the first one.
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))

// Check that the group passes against an appropriate predicate.
ok, err = db.CheckSessionGroupPredicate(
s1.GroupID, func(s *Session) bool {
return strings.Contains(s.Label, "label")
},
)
require.NoError(t, err)
require.True(t, ok)

// Check that the group fails against an appropriate predicate.
ok, err = db.CheckSessionGroupPredicate(
s1.GroupID, func(s *Session) bool {
return strings.Contains(s.Label, "label 1")
},
)
require.NoError(t, err)
require.False(t, ok)

// Add a new session that is not linked to the first one.
s3 := createSession(t, db, "completely different")

// Ensure that the first group is unaffected.
ok, err = db.CheckSessionGroupPredicate(
s1.GroupID, func(s *Session) bool {
return strings.Contains(s.Label, "label")
},
)
require.NoError(t, err)
require.True(t, ok)

// And that the new session is evaluated separately.
ok, err = db.CheckSessionGroupPredicate(
s3.GroupID, func(s *Session) bool {
return strings.Contains(s.Label, "label")
},
)
require.NoError(t, err)
require.False(t, ok)

ok, err = db.CheckSessionGroupPredicate(
s3.GroupID, func(s *Session) bool {
return strings.Contains(s.Label, "different")
},
)
require.NoError(t, err)
require.True(t, ok)
}

type testSessionOpts struct {
groupID *ID
sessType Type
Expand Down
17 changes: 0 additions & 17 deletions session_rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -858,23 +858,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
"group %x", groupSess.ID, groupSess.GroupID)
}

// Now we need to check that all the sessions in the group are
// no longer active.
ok, err := s.cfg.db.CheckSessionGroupPredicate(
groupID, func(s *session.Session) bool {
return s.State == session.StateRevoked ||
s.State == session.StateExpired
},
)
if err != nil {
return nil, err
}

if !ok {
return nil, fmt.Errorf("a linked session in group "+
"%x is still active", groupID)
}

linkedGroupID = &groupID
linkedGroupSession = groupSess

Expand Down

0 comments on commit 001da85

Please sign in to comment.