Skip to content

Commit 001da85

Browse files
committed
session: remove Session Group Predicate method
This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession.
1 parent d65d04c commit 001da85

File tree

4 files changed

+2
-175
lines changed

4 files changed

+2
-175
lines changed

session/interface.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,6 @@ type Store interface {
201201
// GetSessionByID fetches the session with the given ID.
202202
GetSessionByID(id ID) (*Session, error)
203203

204-
// CheckSessionGroupPredicate iterates over all the sessions in a group
205-
// and checks if each one passes the given predicate function. True is
206-
// returned if each session passes.
207-
CheckSessionGroupPredicate(groupID ID,
208-
fn func(s *Session) bool) (bool, error)
209-
210204
// DeleteReservedSessions deletes all sessions that are in the
211205
// StateReserved state.
212206
DeleteReservedSessions() error

session/kvdb_store.go

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
247247

248248
// Ensure that the session is no longer active.
249249
if sess.State == StateCreated ||
250-
sess.State == StateInUse {
250+
sess.State == StateInUse ||
251+
sess.State == StateReserved {
251252

252253
return fmt.Errorf("session (id=%x) "+
253254
"in group %x is still active",
@@ -692,65 +693,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
692693
return sessionIDs, nil
693694
}
694695

695-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
696-
// checks if each one passes the given predicate function. True is returned if
697-
// each session passes.
698-
//
699-
// NOTE: this is part of the Store interface.
700-
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
701-
fn func(s *Session) bool) (bool, error) {
702-
703-
var (
704-
pass bool
705-
errFailedPred = errors.New("session failed predicate")
706-
)
707-
err := db.View(func(tx *bbolt.Tx) error {
708-
sessionBkt, err := getBucket(tx, sessionBucketKey)
709-
if err != nil {
710-
return err
711-
}
712-
713-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
714-
if err != nil {
715-
return err
716-
}
717-
718-
// Iterate over all the sessions.
719-
for _, id := range sessionIDs {
720-
key, err := getKeyForID(sessionBkt, id)
721-
if err != nil {
722-
return err
723-
}
724-
725-
v := sessionBkt.Get(key)
726-
if len(v) == 0 {
727-
return ErrSessionNotFound
728-
}
729-
730-
session, err := DeserializeSession(bytes.NewReader(v))
731-
if err != nil {
732-
return err
733-
}
734-
735-
if !fn(session) {
736-
return errFailedPred
737-
}
738-
}
739-
740-
pass = true
741-
742-
return nil
743-
})
744-
if errors.Is(err, errFailedPred) {
745-
return pass, nil
746-
}
747-
if err != nil {
748-
return pass, err
749-
}
750-
751-
return pass, nil
752-
}
753-
754696
// getSessionIDs returns all the session IDs associated with the given group ID.
755697
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
756698
var sessionIDs []ID

session/store_test.go

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package session
22

33
import (
4-
"strings"
54
"testing"
65
"time"
76

@@ -283,97 +282,6 @@ func TestLinkedSessions(t *testing.T) {
283282
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
284283
}
285284

286-
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
287-
// method correctly checks if each session in a group passes a predicate.
288-
func TestCheckSessionGroupPredicate(t *testing.T) {
289-
t.Parallel()
290-
291-
// Set up a new DB.
292-
clock := clock.NewTestClock(testTime)
293-
db, err := NewDB(t.TempDir(), "test.db", clock)
294-
require.NoError(t, err)
295-
t.Cleanup(func() {
296-
_ = db.Close()
297-
})
298-
299-
// We will use the Label of the Session to test that the predicate
300-
// function is checked correctly.
301-
302-
// Add a new session to the DB.
303-
s1 := createSession(t, db, "label 1")
304-
305-
// Check that the group passes against an appropriate predicate.
306-
ok, err := db.CheckSessionGroupPredicate(
307-
s1.GroupID, func(s *Session) bool {
308-
return strings.Contains(s.Label, "label 1")
309-
},
310-
)
311-
require.NoError(t, err)
312-
require.True(t, ok)
313-
314-
// Check that the group fails against an appropriate predicate.
315-
ok, err = db.CheckSessionGroupPredicate(
316-
s1.GroupID, func(s *Session) bool {
317-
return strings.Contains(s.Label, "label 2")
318-
},
319-
)
320-
require.NoError(t, err)
321-
require.False(t, ok)
322-
323-
// Revoke the first session.
324-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
325-
326-
// Add a new session to the same group as the first one.
327-
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))
328-
329-
// Check that the group passes against an appropriate predicate.
330-
ok, err = db.CheckSessionGroupPredicate(
331-
s1.GroupID, func(s *Session) bool {
332-
return strings.Contains(s.Label, "label")
333-
},
334-
)
335-
require.NoError(t, err)
336-
require.True(t, ok)
337-
338-
// Check that the group fails against an appropriate predicate.
339-
ok, err = db.CheckSessionGroupPredicate(
340-
s1.GroupID, func(s *Session) bool {
341-
return strings.Contains(s.Label, "label 1")
342-
},
343-
)
344-
require.NoError(t, err)
345-
require.False(t, ok)
346-
347-
// Add a new session that is not linked to the first one.
348-
s3 := createSession(t, db, "completely different")
349-
350-
// Ensure that the first group is unaffected.
351-
ok, err = db.CheckSessionGroupPredicate(
352-
s1.GroupID, func(s *Session) bool {
353-
return strings.Contains(s.Label, "label")
354-
},
355-
)
356-
require.NoError(t, err)
357-
require.True(t, ok)
358-
359-
// And that the new session is evaluated separately.
360-
ok, err = db.CheckSessionGroupPredicate(
361-
s3.GroupID, func(s *Session) bool {
362-
return strings.Contains(s.Label, "label")
363-
},
364-
)
365-
require.NoError(t, err)
366-
require.False(t, ok)
367-
368-
ok, err = db.CheckSessionGroupPredicate(
369-
s3.GroupID, func(s *Session) bool {
370-
return strings.Contains(s.Label, "different")
371-
},
372-
)
373-
require.NoError(t, err)
374-
require.True(t, ok)
375-
}
376-
377285
type testSessionOpts struct {
378286
groupID *ID
379287
sessType Type

session_rpcserver.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -858,23 +858,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
858858
"group %x", groupSess.ID, groupSess.GroupID)
859859
}
860860

861-
// Now we need to check that all the sessions in the group are
862-
// no longer active.
863-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
864-
groupID, func(s *session.Session) bool {
865-
return s.State == session.StateRevoked ||
866-
s.State == session.StateExpired
867-
},
868-
)
869-
if err != nil {
870-
return nil, err
871-
}
872-
873-
if !ok {
874-
return nil, fmt.Errorf("a linked session in group "+
875-
"%x is still active", groupID)
876-
}
877-
878861
linkedGroupID = &groupID
879862
linkedGroupSession = groupSess
880863

0 commit comments

Comments
 (0)