Skip to content

Commit b57779a

Browse files
committed
session: implement DeleteReservedSession
1 parent 94644c7 commit b57779a

File tree

4 files changed

+159
-50
lines changed

4 files changed

+159
-50
lines changed

session/interface.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,10 @@ type Store interface {
325325
// StateReserved state.
326326
DeleteReservedSessions(ctx context.Context) error
327327

328+
// DeleteReservedSession deletes the session with the given ID if it is
329+
// in the StateReserved state.
330+
DeleteReservedSession(ctx context.Context, id ID) error
331+
328332
// ShiftState updates the state of the session with the given ID to the
329333
// "dest" state.
330334
ShiftState(ctx context.Context, id ID, dest State) error

session/kvdb_store.go

Lines changed: 104 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,10 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
442442
return err
443443
}
444444

445-
return sessionBucket.ForEach(func(k, v []byte) error {
445+
// We create a copy of the sessions to delete so that we are
446+
// not iterating and modifying the bucket at the same time.
447+
var sessionsToDelete []*Session
448+
err = sessionBucket.ForEach(func(k, v []byte) error {
446449
// We'll also get buckets here, skip those (identified
447450
// by nil value).
448451
if v == nil {
@@ -458,69 +461,120 @@ func (db *BoltStore) DeleteReservedSessions(_ context.Context) error {
458461
return nil
459462
}
460463

461-
err = sessionBucket.Delete(k)
462-
if err != nil {
463-
return err
464-
}
464+
sessionsToDelete = append(sessionsToDelete, session)
465465

466-
idIndexBkt := sessionBucket.Bucket(idIndexKey)
467-
if idIndexBkt == nil {
468-
return ErrDBInitErr
469-
}
466+
return nil
467+
})
468+
if err != nil {
469+
return err
470+
}
470471

471-
// Delete the entire session ID bucket.
472-
err = idIndexBkt.DeleteBucket(session.ID[:])
473-
if err != nil {
472+
for _, session := range sessionsToDelete {
473+
if err := deleteSession(sessionBucket,
474+
session); err != nil {
474475
return err
475476
}
477+
}
476478

477-
groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey)
478-
if groupIdIndexBkt == nil {
479-
return ErrDBInitErr
480-
}
479+
return nil
480+
})
481+
}
481482

482-
groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:])
483-
if groupBkt == nil {
484-
return ErrDBInitErr
485-
}
483+
// deleteSession deletes all the parts of a session from the database. This
484+
// assumes that the session has already been fetched from the db.
485+
func deleteSession(sessionBucket *bbolt.Bucket, session *Session) error {
486+
sessionKey := getSessionKey(session)
487+
err := sessionBucket.Delete(sessionKey)
488+
if err != nil {
489+
return err
490+
}
486491

487-
sessionIDsBkt := groupBkt.Bucket(sessionIDKey)
488-
if sessionIDsBkt == nil {
489-
return ErrDBInitErr
490-
}
492+
idIndexBkt := sessionBucket.Bucket(idIndexKey)
493+
if idIndexBkt == nil {
494+
return ErrDBInitErr
495+
}
491496

492-
var (
493-
seqKey []byte
494-
numSessions int
495-
)
496-
err = sessionIDsBkt.ForEach(func(k, v []byte) error {
497-
numSessions++
497+
// Delete the entire session ID bucket.
498+
err = idIndexBkt.DeleteBucket(session.ID[:])
499+
if err != nil {
500+
return err
501+
}
498502

499-
if !bytes.Equal(v, session.ID[:]) {
500-
return nil
501-
}
503+
groupIdIndexBkt := sessionBucket.Bucket(groupIDIndexKey)
504+
if groupIdIndexBkt == nil {
505+
return ErrDBInitErr
506+
}
502507

503-
seqKey = k
508+
groupBkt := groupIdIndexBkt.Bucket(session.GroupID[:])
509+
if groupBkt == nil {
510+
return ErrDBInitErr
511+
}
504512

505-
return nil
506-
})
507-
if err != nil {
508-
return err
509-
}
513+
sessionIDsBkt := groupBkt.Bucket(sessionIDKey)
514+
if sessionIDsBkt == nil {
515+
return ErrDBInitErr
516+
}
510517

511-
if numSessions == 0 {
512-
return fmt.Errorf("no sessions found for "+
513-
"group ID %x", session.GroupID)
514-
}
518+
var (
519+
seqKey []byte
520+
numSessions int
521+
)
522+
err = sessionIDsBkt.ForEach(func(k, v []byte) error {
523+
numSessions++
515524

516-
if numSessions == 1 {
517-
// Delete the whole group bucket.
518-
return groupBkt.DeleteBucket(sessionIDKey)
519-
}
525+
if !bytes.Equal(v, session.ID[:]) {
526+
return nil
527+
}
520528

521-
// Else, delete just the session ID entry.
522-
return sessionIDsBkt.Delete(seqKey)
523-
})
529+
seqKey = k
530+
531+
return nil
532+
})
533+
if err != nil {
534+
return err
535+
}
536+
537+
if numSessions == 0 {
538+
return fmt.Errorf("no sessions found for "+
539+
"group ID %x", session.GroupID)
540+
}
541+
542+
if numSessions == 1 {
543+
// If this is the last session in the group, we can delete the
544+
// whole group bucket.
545+
return groupIdIndexBkt.DeleteBucket(session.GroupID[:])
546+
}
547+
548+
// Else, delete just the session ID entry from the group.
549+
return sessionIDsBkt.Delete(seqKey)
550+
}
551+
552+
// DeleteReservedSession removes a given session that is in the reserved state
553+
// from the database.
554+
//
555+
// NOTE: This is part of the Store interface.
556+
func (db *BoltStore) DeleteReservedSession(_ context.Context, id ID) error {
557+
return db.Update(func(tx *bbolt.Tx) error {
558+
sessionBucket, err := getBucket(tx, sessionBucketKey)
559+
if err != nil {
560+
return err
561+
}
562+
563+
// We'll first get the session to make sure it's actually in the
564+
// reserved state before deleting. This gives us a slightly
565+
// better error message than just trying to delete and getting a
566+
// "not found" if the session was in another state.
567+
session, err := getSessionByID(sessionBucket, id)
568+
if err != nil {
569+
return err
570+
}
571+
572+
if session.State != StateReserved {
573+
return fmt.Errorf("session not in reserved state, is "+
574+
"%v", session.State)
575+
}
576+
577+
return deleteSession(sessionBucket, session)
524578
})
525579
}
526580

session/sql_store.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ type SQLQueries interface {
4545
SetSessionGroupID(ctx context.Context, arg sqlc.SetSessionGroupIDParams) error
4646
UpdateSessionState(ctx context.Context, arg sqlc.UpdateSessionStateParams) error
4747
DeleteSessionsWithState(ctx context.Context, state int16) error
48+
DeleteSession(ctx context.Context, id int64) error
4849
GetAccountIDByAlias(ctx context.Context, alias int64) (int64, error)
4950
GetAccount(ctx context.Context, id int64) (sqlc.Account, error)
5051
}
@@ -431,6 +432,30 @@ func (s *SQLStore) DeleteReservedSessions(ctx context.Context) error {
431432
})
432433
}
433434

435+
// DeleteReservedSession removes a given session that is in the reserved state
436+
// from the database.
437+
//
438+
// NOTE: This is part of the Store interface.
439+
func (s *SQLStore) DeleteReservedSession(ctx context.Context, id ID) error {
440+
var writeTxOpts db.QueriesTxOptions
441+
return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error {
442+
session, err := db.GetSessionByAlias(ctx, id[:])
443+
if errors.Is(err, sql.ErrNoRows) {
444+
return fmt.Errorf("%w: unable to get session: %w",
445+
ErrSessionNotFound, err)
446+
} else if err != nil {
447+
return fmt.Errorf("unable to get session: %w", err)
448+
}
449+
450+
if State(session.State) != StateReserved {
451+
return fmt.Errorf("session not in reserved state, is "+
452+
"%v", State(session.State))
453+
}
454+
455+
return db.DeleteSession(ctx, session.ID)
456+
})
457+
}
458+
434459
// GetSessionByLocalPub fetches the session with the given local pub key.
435460
//
436461
// NOTE: This is part of the Store interface.

session/store_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ func TestBasicSessionStore(t *testing.T) {
156156
// of the sessions are reserved.
157157
require.NoError(t, db.DeleteReservedSessions(ctx))
158158

159+
// Explicitly trying to delete session 1 should fail as it's not
160+
// reserved.
161+
require.Error(t, db.DeleteReservedSession(ctx, s1.ID))
162+
159163
sessions, err = db.ListSessionsByState(ctx, StateReserved)
160164
require.NoError(t, err)
161165
require.Empty(t, sessions)
@@ -192,6 +196,28 @@ func TestBasicSessionStore(t *testing.T) {
192196
_, err = db.GetGroupID(ctx, s4.ID)
193197
require.ErrorIs(t, err, ErrSessionNotFound)
194198

199+
// Reserve a new session and link it to session 1.
200+
s5, err := reserveSession(
201+
db, "session 5", withLinkedGroupID(&session1.GroupID),
202+
)
203+
require.NoError(t, err)
204+
sessions, err = db.ListSessionsByState(ctx, StateReserved)
205+
require.NoError(t, err)
206+
require.Equal(t, 1, len(sessions))
207+
assertEqualSessions(t, s5, sessions[0])
208+
209+
// Now delete the reserved session by its ID and show that it is no
210+
// longer in the database and no longer in the group ID/session ID
211+
// index.
212+
require.NoError(t, db.DeleteReservedSession(ctx, s5.ID))
213+
214+
sessions, err = db.ListSessionsByState(ctx, StateReserved)
215+
require.NoError(t, err)
216+
require.Empty(t, sessions)
217+
218+
_, err = db.GetGroupID(ctx, s5.ID)
219+
require.ErrorIs(t, err, ErrSessionNotFound)
220+
195221
// Only session 1 should remain in this group.
196222
sessIDs, err = db.GetSessionIDs(ctx, s4.GroupID)
197223
require.NoError(t, err)

0 commit comments

Comments
 (0)