Skip to content

Commit 81dd9fc

Browse files
committed
session: add context to ShiftState
1 parent c468beb commit 81dd9fc

File tree

4 files changed

+21
-20
lines changed

4 files changed

+21
-20
lines changed

session/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ type Store interface {
225225

226226
// ShiftState updates the state of the session with the given ID to the
227227
// "dest" state.
228-
ShiftState(id ID, dest State) error
228+
ShiftState(ctx context.Context, id ID, dest State) error
229229

230230
IDToGroupIndex
231231
}

session/kvdb_store.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
530530
// state.
531531
//
532532
// NOTE: this is part of the Store interface.
533-
func (db *BoltStore) ShiftState(id ID, dest State) error {
533+
func (db *BoltStore) ShiftState(_ context.Context, id ID, dest State) error {
534534
return db.Update(func(tx *bbolt.Tx) error {
535535
sessionBucket, err := getBucket(tx, sessionBucketKey)
536536
if err != nil {

session/store_test.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestBasicSessionStore(t *testing.T) {
3636
require.Equal(t, StateReserved, s1.State)
3737

3838
// Move session 1 to the created state. This should succeed.
39-
err = db.ShiftState(s1.ID, StateCreated)
39+
err = db.ShiftState(ctx, s1.ID, StateCreated)
4040
require.NoError(t, err)
4141

4242
// Show that the session is now in the created state.
@@ -46,7 +46,7 @@ func TestBasicSessionStore(t *testing.T) {
4646

4747
// Trying to move session 1 again should have no effect since it is
4848
// already in the created state.
49-
require.NoError(t, db.ShiftState(s1.ID, StateCreated))
49+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateCreated))
5050

5151
// Reserve and create a few more sessions. We increment the time by one
5252
// second between each session to ensure that the created at time is
@@ -107,7 +107,7 @@ func TestBasicSessionStore(t *testing.T) {
107107
require.Equal(t, session1.State, StateCreated)
108108

109109
// Now revoke the session and assert that the state is revoked.
110-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
110+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked))
111111
s1, err = db.GetSession(ctx, s1.LocalPublicKey)
112112
require.NoError(t, err)
113113
require.Equal(t, s1.State, StateRevoked)
@@ -198,6 +198,7 @@ func TestBasicSessionStore(t *testing.T) {
198198
// TestLinkingSessions tests that session linking works as expected.
199199
func TestLinkingSessions(t *testing.T) {
200200
t.Parallel()
201+
ctx := context.Background()
201202

202203
// Set up a new DB.
203204
clock := clock.NewTestClock(testTime)
@@ -227,7 +228,7 @@ func TestLinkingSessions(t *testing.T) {
227228
require.ErrorContains(t, err, "is still active")
228229

229230
// Revoke the first session.
230-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
231+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked))
231232

232233
// Persisting the second linked session should now work.
233234
_, err = reserveSession(db, "session 2", withLinkedGroupID(&s1.GroupID))
@@ -255,10 +256,10 @@ func TestLinkedSessions(t *testing.T) {
255256
// first session.
256257
s1 := createSession(t, db, "session 1")
257258

258-
require.NoError(t, db.ShiftState(s1.ID, StateRevoked))
259+
require.NoError(t, db.ShiftState(ctx, s1.ID, StateRevoked))
259260
s2 := createSession(t, db, "session 2", withLinkedGroupID(&s1.GroupID))
260261

261-
require.NoError(t, db.ShiftState(s2.ID, StateRevoked))
262+
require.NoError(t, db.ShiftState(ctx, s2.ID, StateRevoked))
262263
s3 := createSession(t, db, "session 3", withLinkedGroupID(&s2.GroupID))
263264

264265
// Assert that the session ID to group ID index works as expected.
@@ -277,7 +278,7 @@ func TestLinkedSessions(t *testing.T) {
277278
// To ensure that different groups don't interfere with each other,
278279
// let's add another set of linked sessions not linked to the first.
279280
s4 := createSession(t, db, "session 4")
280-
require.NoError(t, db.ShiftState(s4.ID, StateRevoked))
281+
require.NoError(t, db.ShiftState(ctx, s4.ID, StateRevoked))
281282
s5 := createSession(t, db, "session 5", withLinkedGroupID(&s4.GroupID))
282283
require.NotEqual(t, s4.GroupID, s1.GroupID)
283284

@@ -319,7 +320,7 @@ func TestStateShift(t *testing.T) {
319320
require.Equal(t, time.Time{}, s1.RevokedAt)
320321

321322
// Shift the state of the session to StateRevoked.
322-
err = db.ShiftState(s1.ID, StateRevoked)
323+
err = db.ShiftState(ctx, s1.ID, StateRevoked)
323324
require.NoError(t, err)
324325

325326
// This should have worked. Since it is now in a terminal state, the
@@ -334,13 +335,13 @@ func TestStateShift(t *testing.T) {
334335
// should not have changed though.
335336
prevTime := clock.Now()
336337
clock.SetTime(prevTime.Add(time.Second))
337-
err = db.ShiftState(s1.ID, StateRevoked)
338+
err = db.ShiftState(ctx, s1.ID, StateRevoked)
338339
require.NoError(t, err)
339340
require.True(t, prevTime.Equal(s1.RevokedAt))
340341

341342
// Trying to shift the state from a terminal state back to StateCreated
342343
// should also fail since this is not a legal state transition.
343-
err = db.ShiftState(s1.ID, StateCreated)
344+
err = db.ShiftState(ctx, s1.ID, StateCreated)
344345
require.ErrorContains(t, err, "illegal session state transition")
345346
}
346347

@@ -394,7 +395,7 @@ func createSession(t *testing.T, db Store, label string,
394395
s, err := reserveSession(db, label, mods...)
395396
require.NoError(t, err)
396397

397-
err = db.ShiftState(s.ID, StateCreated)
398+
err = db.ShiftState(context.Background(), s.ID, StateCreated)
398399
require.NoError(t, err)
399400

400401
s, err = db.GetSessionByID(context.Background(), s.ID)

session_rpcserver.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (s *sessionRpcServer) start(ctx context.Context) error {
149149

150150
if perm {
151151
err := s.cfg.db.ShiftState(
152-
sess.ID, session.StateRevoked,
152+
ctx, sess.ID, session.StateRevoked,
153153
)
154154
if err != nil {
155155
log.Errorf("error revoking "+
@@ -316,7 +316,7 @@ func (s *sessionRpcServer) AddSession(ctx context.Context,
316316
return nil, fmt.Errorf("error creating new session: %v", err)
317317
}
318318

319-
err = s.cfg.db.ShiftState(sess.ID, session.StateCreated)
319+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateCreated)
320320
if err != nil {
321321
return nil, fmt.Errorf("error shifting session state to "+
322322
"Created: %v", err)
@@ -355,7 +355,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
355355
log.Debugf("Not resuming session %x with expiry %s",
356356
pubKeyBytes, sess.Expiry)
357357

358-
err := s.cfg.db.ShiftState(sess.ID, session.StateExpired)
358+
err := s.cfg.db.ShiftState(ctx, sess.ID, session.StateExpired)
359359
if err != nil {
360360
return fmt.Errorf("error revoking session: %v", err)
361361
}
@@ -433,7 +433,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
433433
"passed. Revoking session", pubKeyBytes)
434434

435435
return s.cfg.db.ShiftState(
436-
sess.ID, session.StateRevoked,
436+
ctx, sess.ID, session.StateRevoked,
437437
)
438438
}
439439

@@ -513,7 +513,7 @@ func (s *sessionRpcServer) resumeSession(ctx context.Context,
513513
log.Debugf("Error stopping session: %v", err)
514514
}
515515

516-
err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked)
516+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateRevoked)
517517
if err != nil {
518518
log.Debugf("error revoking session: %v", err)
519519
}
@@ -560,7 +560,7 @@ func (s *sessionRpcServer) RevokeSession(ctx context.Context,
560560
return nil, fmt.Errorf("error fetching session: %v", err)
561561
}
562562

563-
err = s.cfg.db.ShiftState(sess.ID, session.StateRevoked)
563+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateRevoked)
564564
if err != nil {
565565
return nil, fmt.Errorf("error revoking session: %v", err)
566566
}
@@ -1213,7 +1213,7 @@ func (s *sessionRpcServer) AddAutopilotSession(ctx context.Context,
12131213

12141214
// We only activate the session if the Autopilot server registration
12151215
// was successful.
1216-
err = s.cfg.db.ShiftState(sess.ID, session.StateCreated)
1216+
err = s.cfg.db.ShiftState(ctx, sess.ID, session.StateCreated)
12171217
if err != nil {
12181218
return nil, fmt.Errorf("error shifting session state to "+
12191219
"Created: %v", err)

0 commit comments

Comments
 (0)