Skip to content

Commit 85b049b

Browse files
committed
session: remove Session Group Predicate method
This was used to check that all linked sessions are no longer active before attempting to register an autopilot session. But this is no longer needed since this is done within NewSession.
1 parent 4c87858 commit 85b049b

File tree

4 files changed

+2
-175
lines changed

4 files changed

+2
-175
lines changed

session/interface.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,6 @@ type Store interface {
203203
// GetSessionByID fetches the session with the given ID.
204204
GetSessionByID(id ID) (*Session, error)
205205

206-
// CheckSessionGroupPredicate iterates over all the sessions in a group
207-
// and checks if each one passes the given predicate function. True is
208-
// returned if each session passes.
209-
CheckSessionGroupPredicate(groupID ID,
210-
fn func(s *Session) bool) (bool, error)
211-
212206
// DeleteReservedSessions deletes all sessions that are in the
213207
// StateReserved state.
214208
DeleteReservedSessions() error

session/kvdb_store.go

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ func (db *BoltStore) NewSession(label string, typ Type, expiry time.Time,
247247

248248
// Ensure that the session is no longer active.
249249
if sess.State == StateCreated ||
250-
sess.State == StateInUse {
250+
sess.State == StateInUse ||
251+
sess.State == StateReserved {
251252

252253
return fmt.Errorf("session (id=%x) "+
253254
"in group %x is still active",
@@ -699,65 +700,6 @@ func (db *BoltStore) GetSessionIDs(groupID ID) ([]ID, error) {
699700
return sessionIDs, nil
700701
}
701702

702-
// CheckSessionGroupPredicate iterates over all the sessions in a group and
703-
// checks if each one passes the given predicate function. True is returned if
704-
// each session passes.
705-
//
706-
// NOTE: this is part of the Store interface.
707-
func (db *BoltStore) CheckSessionGroupPredicate(groupID ID,
708-
fn func(s *Session) bool) (bool, error) {
709-
710-
var (
711-
pass bool
712-
errFailedPred = errors.New("session failed predicate")
713-
)
714-
err := db.View(func(tx *bbolt.Tx) error {
715-
sessionBkt, err := getBucket(tx, sessionBucketKey)
716-
if err != nil {
717-
return err
718-
}
719-
720-
sessionIDs, err := getSessionIDs(sessionBkt, groupID)
721-
if err != nil {
722-
return err
723-
}
724-
725-
// Iterate over all the sessions.
726-
for _, id := range sessionIDs {
727-
key, err := getKeyForID(sessionBkt, id)
728-
if err != nil {
729-
return err
730-
}
731-
732-
v := sessionBkt.Get(key)
733-
if len(v) == 0 {
734-
return ErrSessionNotFound
735-
}
736-
737-
session, err := DeserializeSession(bytes.NewReader(v))
738-
if err != nil {
739-
return err
740-
}
741-
742-
if !fn(session) {
743-
return errFailedPred
744-
}
745-
}
746-
747-
pass = true
748-
749-
return nil
750-
})
751-
if errors.Is(err, errFailedPred) {
752-
return pass, nil
753-
}
754-
if err != nil {
755-
return pass, err
756-
}
757-
758-
return pass, nil
759-
}
760-
761703
// getSessionIDs returns all the session IDs associated with the given group ID.
762704
func getSessionIDs(sessionBkt *bbolt.Bucket, groupID ID) ([]ID, error) {
763705
var sessionIDs []ID

session/store_test.go

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package session
22

33
import (
4-
"strings"
54
"testing"
65
"time"
76

@@ -288,97 +287,6 @@ func TestLinkedSessions(t *testing.T) {
288287
require.EqualValues(t, []ID{s4.ID, s5.ID}, sIDs)
289288
}
290289

291-
// TestCheckSessionGroupPredicate asserts that the CheckSessionGroupPredicate
292-
// method correctly checks if each session in a group passes a predicate.
293-
func TestCheckSessionGroupPredicate(t *testing.T) {
294-
t.Parallel()
295-
296-
// Set up a new DB.
297-
clock := clock.NewTestClock(testTime)
298-
db, err := NewDB(t.TempDir(), "test.db", clock)
299-
require.NoError(t, err)
300-
t.Cleanup(func() {
301-
_ = db.Close()
302-
})
303-
304-
// We will use the Label of the Session to test that the predicate
305-
// function is checked correctly.
306-
307-
// Add a new session to the DB.
308-
s1 := createSession(t, db, "label 1")
309-
310-
// Check that the group passes against an appropriate predicate.
311-
ok, err := db.CheckSessionGroupPredicate(
312-
s1.GroupID, func(s *Session) bool {
313-
return strings.Contains(s.Label, "label 1")
314-
},
315-
)
316-
require.NoError(t, err)
317-
require.True(t, ok)
318-
319-
// Check that the group fails against an appropriate predicate.
320-
ok, err = db.CheckSessionGroupPredicate(
321-
s1.GroupID, func(s *Session) bool {
322-
return strings.Contains(s.Label, "label 2")
323-
},
324-
)
325-
require.NoError(t, err)
326-
require.False(t, ok)
327-
328-
// Revoke the first session.
329-
require.NoError(t, db.RevokeSession(s1.LocalPublicKey))
330-
331-
// Add a new session to the same group as the first one.
332-
_ = createSession(t, db, "label 2", withLinkedGroupID(&s1.GroupID))
333-
334-
// Check that the group passes against an appropriate predicate.
335-
ok, err = db.CheckSessionGroupPredicate(
336-
s1.GroupID, func(s *Session) bool {
337-
return strings.Contains(s.Label, "label")
338-
},
339-
)
340-
require.NoError(t, err)
341-
require.True(t, ok)
342-
343-
// Check that the group fails against an appropriate predicate.
344-
ok, err = db.CheckSessionGroupPredicate(
345-
s1.GroupID, func(s *Session) bool {
346-
return strings.Contains(s.Label, "label 1")
347-
},
348-
)
349-
require.NoError(t, err)
350-
require.False(t, ok)
351-
352-
// Add a new session that is not linked to the first one.
353-
s3 := createSession(t, db, "completely different")
354-
355-
// Ensure that the first group is unaffected.
356-
ok, err = db.CheckSessionGroupPredicate(
357-
s1.GroupID, func(s *Session) bool {
358-
return strings.Contains(s.Label, "label")
359-
},
360-
)
361-
require.NoError(t, err)
362-
require.True(t, ok)
363-
364-
// And that the new session is evaluated separately.
365-
ok, err = db.CheckSessionGroupPredicate(
366-
s3.GroupID, func(s *Session) bool {
367-
return strings.Contains(s.Label, "label")
368-
},
369-
)
370-
require.NoError(t, err)
371-
require.False(t, ok)
372-
373-
ok, err = db.CheckSessionGroupPredicate(
374-
s3.GroupID, func(s *Session) bool {
375-
return strings.Contains(s.Label, "different")
376-
},
377-
)
378-
require.NoError(t, err)
379-
require.True(t, ok)
380-
}
381-
382290
type testSessionOpts struct {
383291
groupID *ID
384292
sessType Type

session_rpcserver.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -857,23 +857,6 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
857857
"group %x", groupSess.ID, groupSess.GroupID)
858858
}
859859

860-
// Now we need to check that all the sessions in the group are
861-
// no longer active.
862-
ok, err := s.cfg.db.CheckSessionGroupPredicate(
863-
groupID, func(s *session.Session) bool {
864-
return s.State == session.StateRevoked ||
865-
s.State == session.StateExpired
866-
},
867-
)
868-
if err != nil {
869-
return nil, err
870-
}
871-
872-
if !ok {
873-
return nil, fmt.Errorf("a linked session in group "+
874-
"%x is still active", groupID)
875-
}
876-
877860
linkedGroupID = &groupID
878861
linkedGroupSession = groupSess
879862

0 commit comments

Comments
 (0)