Skip to content

Commit 48db1ad

Browse files
JAORMXclaude
andauthored
feat: Add pluggable storage backend for session management (#1989)
* feat: Add pluggable storage backend for session management Refactors the session management system to use a pluggable storage interface, enabling future support for distributed storage backends like Redis/Valkey while maintaining backward compatibility. What Changed - Introduced a Storage interface that abstracts session persistence - Refactored Manager to use the Storage interface instead of directly using sync.Map - Created LocalStorage implementation that maintains the existing in-memory behavior - Added JSON serialization support for sessions to enable future network storage - Extended Session interface with Type() and metadata methods that were already implemented in concrete types Why The previous implementation was tightly coupled to in-memory storage, making it impossible to share sessions across multiple ToolHive instances. This refactoring enables: - Horizontal scaling with shared session state - Session persistence across restarts - Future Redis/Valkey backend support without breaking changes Testing Added comprehensive unit tests covering: - LocalStorage implementation - Session serialization/deserialization - Manager with pluggable storage - All existing session types (ProxySession, SSESession, StreamableSession) All tests pass and the implementation maintains full backward compatibility. Signed-off-by: Juan Antonio Osorio <[email protected]> * Address PR feedback: fix race condition and encapsulation issues - Fix race condition in LocalStorage.Close() by collecting keys before deletion - Update Close() comment to reflect actual behavior (clears sessions, not a no-op) - Add setter methods (setTimestamps, setMetadataMap) to ProxySession for proper encapsulation - Update serialization to use setter methods instead of direct field access - Fix StreamableSession constructor to use NewTypedProxySession for proper initialization - Add type assertion check in StreamableSession deserialization 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> * Fix incorrect string conversion in test session ID generation - Replace string(rune('a' + i)) with fmt.Sprintf("session-%d", i) - Previous approach only worked for i values 0-25 and produced unexpected Unicode characters for larger values - Add fmt import to storage_test.go 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> * Address PR feedback: improve session storage interface design This commit addresses all feedback from PR review comments: 1. Separate Touch from Load operations - Remove auto-touch behavior from Storage.Load() - Manager.Get() now explicitly touches sessions - Gives callers control over when sessions are touched 2. Document Range/Count design decision - Add comprehensive documentation explaining why Range/Count are not part of Storage interface - These operations don't make sense for distributed storage 3. Add consistent context timeout usage - All storage operations now use consistent timeouts - 5 seconds for quick operations (Get, Delete, Add) - 30 seconds for cleanup operations 4. Proper error handling throughout - Manager.Delete() and Manager.Stop() now return errors - Cleanup routine logs errors instead of ignoring them - Proxy implementations use debug logging for non-critical errors All tests pass with these changes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> --------- Signed-off-by: Juan Antonio Osorio <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent d492159 commit 48db1ad

File tree

10 files changed

+1138
-60
lines changed

10 files changed

+1138
-60
lines changed

pkg/transport/proxy/httpsse/http_proxy.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,9 @@ func (p *HTTPSSEProxy) Stop(ctx context.Context) error {
206206

207207
// Stop the session manager cleanup routine
208208
if p.sessionManager != nil {
209-
p.sessionManager.Stop()
209+
if err := p.sessionManager.Stop(); err != nil {
210+
logger.Errorf("Failed to stop session manager: %v", err)
211+
}
210212
}
211213

212214
// Disconnect all active sessions
@@ -466,7 +468,9 @@ func (p *HTTPSSEProxy) removeClient(clientID string) {
466468
}
467469

468470
// Remove the session from the manager
469-
p.sessionManager.Delete(clientID)
471+
if err := p.sessionManager.Delete(clientID); err != nil {
472+
logger.Debugf("Failed to delete session %s: %v", clientID, err)
473+
}
470474

471475
// Clean up closed clients map periodically (prevent memory leak)
472476
p.closedClientsMutex.Lock()

pkg/transport/proxy/streamable/streamable_proxy.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ func (p *HTTPProxy) Stop(ctx context.Context) error {
117117

118118
// Stop session manager cleanup and disconnect sessions
119119
if p.sessionManager != nil {
120-
p.sessionManager.Stop()
120+
if err := p.sessionManager.Stop(); err != nil {
121+
logger.Errorf("Failed to stop session manager: %v", err)
122+
}
121123
p.sessionManager.Range(func(_, value interface{}) bool {
122124
if ss, ok := value.(*session.StreamableSession); ok {
123125
ss.Disconnect()
@@ -202,7 +204,9 @@ func (p *HTTPProxy) handleDelete(w http.ResponseWriter, r *http.Request) {
202204
writeHTTPError(w, http.StatusNotFound, "session not found")
203205
return
204206
}
205-
p.sessionManager.Delete(sessID)
207+
if err := p.sessionManager.Delete(sessID); err != nil {
208+
logger.Debugf("Failed to delete session %s: %v", sessID, err)
209+
}
206210
w.WriteHeader(http.StatusNoContent)
207211
}
208212

pkg/transport/session/manager.go

Lines changed: 97 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,34 @@
22
package session
33

44
import (
5+
"context"
56
"fmt"
6-
"sync"
77
"time"
8+
9+
"github.com/stacklok/toolhive/pkg/logger"
810
)
911

10-
// Session interface
12+
// Session interface defines the contract for all session types
1113
type Session interface {
1214
ID() string
15+
Type() SessionType
1316
CreatedAt() time.Time
1417
UpdatedAt() time.Time
1518
Touch()
19+
20+
// Data and metadata methods
21+
GetData() interface{}
22+
SetData(data interface{})
23+
GetMetadata() map[string]string
24+
SetMetadata(key, value string)
1625
}
1726

1827
// Manager holds sessions with TTL cleanup.
1928
type Manager struct {
20-
sessions sync.Map
21-
ttl time.Duration
22-
stopCh chan struct{}
23-
factory Factory
29+
storage Storage
30+
ttl time.Duration
31+
stopCh chan struct{}
32+
factory Factory
2433
}
2534

2635
// Factory defines a function type for creating new sessions.
@@ -56,10 +65,10 @@ func NewManager(ttl time.Duration, factory interface{}) *Manager {
5665
}
5766

5867
m := &Manager{
59-
sessions: sync.Map{},
60-
ttl: ttl,
61-
stopCh: make(chan struct{}),
62-
factory: f,
68+
storage: NewLocalStorage(),
69+
ttl: ttl,
70+
stopCh: make(chan struct{}),
71+
factory: f,
6372
}
6473
go m.cleanupRoutine()
6574
return m
@@ -83,24 +92,30 @@ func NewTypedManager(ttl time.Duration, sessionType SessionType) *Manager {
8392
return NewManager(ttl, factory)
8493
}
8594

95+
// NewManagerWithStorage creates a session manager with a custom storage backend.
96+
func NewManagerWithStorage(ttl time.Duration, factory Factory, storage Storage) *Manager {
97+
m := &Manager{
98+
storage: storage,
99+
ttl: ttl,
100+
stopCh: make(chan struct{}),
101+
factory: factory,
102+
}
103+
go m.cleanupRoutine()
104+
return m
105+
}
106+
86107
func (m *Manager) cleanupRoutine() {
87108
ticker := time.NewTicker(m.ttl / 2)
88109
defer ticker.Stop()
89110
for {
90111
select {
91112
case <-ticker.C:
92113
cutoff := time.Now().Add(-m.ttl)
93-
m.sessions.Range(func(key, val any) bool {
94-
sess, ok := val.(Session)
95-
if !ok {
96-
// Skip invalid value
97-
return true
98-
}
99-
if sess.UpdatedAt().Before(cutoff) {
100-
m.sessions.Delete(key)
101-
}
102-
return true
103-
})
114+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
115+
if err := m.storage.DeleteExpired(ctx, cutoff); err != nil {
116+
logger.Errorf("Failed to delete expired sessions: %v", err)
117+
}
118+
cancel()
104119
case <-m.stopCh:
105120
return
106121
}
@@ -113,13 +128,17 @@ func (m *Manager) AddWithID(id string) error {
113128
if id == "" {
114129
return fmt.Errorf("session ID cannot be empty")
115130
}
116-
// Use LoadOrStore: returns existing if already present
117-
session := m.factory(id)
118-
_, loaded := m.sessions.LoadOrStore(id, session)
119-
if loaded {
131+
// Check if session already exists
132+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
133+
defer cancel()
134+
135+
if _, err := m.storage.Load(ctx, id); err == nil {
120136
return fmt.Errorf("session ID %q already exists", id)
121137
}
122-
return nil
138+
139+
// Create and store new session
140+
session := m.factory(id)
141+
return m.storage.Store(ctx, session)
123142
}
124143

125144
// AddSession adds an existing session to the manager.
@@ -132,62 +151,85 @@ func (m *Manager) AddSession(session Session) error {
132151
return fmt.Errorf("session ID cannot be empty")
133152
}
134153

135-
_, loaded := m.sessions.LoadOrStore(session.ID(), session)
136-
if loaded {
154+
// Check if session already exists
155+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
156+
defer cancel()
157+
158+
if _, err := m.storage.Load(ctx, session.ID()); err == nil {
137159
return fmt.Errorf("session ID %q already exists", session.ID())
138160
}
139-
return nil
161+
162+
return m.storage.Store(ctx, session)
140163
}
141164

142165
// Get retrieves a session by ID. Returns (session, true) if found,
143166
// and also updates its UpdatedAt timestamp.
144167
func (m *Manager) Get(id string) (Session, bool) {
145-
v, ok := m.sessions.Load(id)
146-
if !ok {
168+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
169+
defer cancel()
170+
171+
sess, err := m.storage.Load(ctx, id)
172+
if err != nil {
147173
return nil, false
148174
}
149-
sess, ok := v.(Session)
150-
if !ok {
151-
return nil, false // Invalid session type
152-
}
153-
175+
// Touch the session to update its timestamp
154176
sess.Touch()
155177
return sess, true
156178
}
157179

158180
// Delete removes a session by ID.
159-
func (m *Manager) Delete(id string) {
160-
m.sessions.Delete(id)
181+
// Returns an error if the deletion fails.
182+
func (m *Manager) Delete(id string) error {
183+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
184+
defer cancel()
185+
return m.storage.Delete(ctx, id)
161186
}
162187

163-
// Stop stops the cleanup worker.
164-
func (m *Manager) Stop() {
188+
// Stop stops the cleanup worker and closes the storage backend.
189+
// Returns an error if closing the storage backend fails.
190+
func (m *Manager) Stop() error {
165191
close(m.stopCh)
192+
if m.storage != nil {
193+
return m.storage.Close()
194+
}
195+
return nil
166196
}
167197

168198
// Range calls f sequentially for each key and value present in the map.
169199
// If f returns false, range stops the iteration.
200+
//
201+
// Note: This method only works with LocalStorage backend. It will silently
202+
// do nothing with other storage backends. Range is not part of the Storage
203+
// interface because it's not feasible for distributed storage backends like
204+
// Redis where iterating all keys can be prohibitively expensive or impractical.
205+
//
206+
// For distributed storage, consider using more targeted queries or maintaining
207+
// a separate index of session IDs.
170208
func (m *Manager) Range(f func(key, value interface{}) bool) {
171-
m.sessions.Range(f)
209+
if localStorage, ok := m.storage.(*LocalStorage); ok {
210+
localStorage.Range(f)
211+
}
172212
}
173213

174214
// Count returns the number of active sessions.
215+
//
216+
// Note: This method only works with LocalStorage backend and returns 0 for
217+
// other storage backends. Count is not part of the Storage interface because
218+
// it's not feasible for distributed storage backends like Redis where counting
219+
// all keys can be prohibitively expensive.
220+
//
221+
// For distributed storage, consider maintaining a counter or using approximate
222+
// count mechanisms provided by the storage backend.
175223
func (m *Manager) Count() int {
176-
count := 0
177-
m.sessions.Range(func(_, _ interface{}) bool {
178-
count++
179-
return true
180-
})
181-
return count
224+
if localStorage, ok := m.storage.(*LocalStorage); ok {
225+
return localStorage.Count()
226+
}
227+
return 0
182228
}
183229

184-
func (m *Manager) cleanupExpiredOnce() {
230+
func (m *Manager) cleanupExpiredOnce() error {
185231
cutoff := time.Now().Add(-m.ttl)
186-
m.sessions.Range(func(key, val any) bool {
187-
sess := val.(Session)
188-
if sess.UpdatedAt().Before(cutoff) {
189-
m.sessions.Delete(key)
190-
}
191-
return true
192-
})
232+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
233+
defer cancel()
234+
return m.storage.DeleteExpired(ctx, cutoff)
193235
}

pkg/transport/session/proxy_session.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,24 @@ func (s *ProxySession) DeleteMetadata(key string) {
134134
defer s.mu.Unlock()
135135
delete(s.metadata, key)
136136
}
137+
138+
// setTimestamps updates the created and updated timestamps.
139+
// This is used internally for deserialization to restore session state.
140+
func (s *ProxySession) setTimestamps(created, updated time.Time) {
141+
s.mu.Lock()
142+
defer s.mu.Unlock()
143+
s.created = created
144+
s.updated = updated
145+
}
146+
147+
// setMetadataMap replaces the entire metadata map.
148+
// This is used internally for deserialization to restore session state.
149+
func (s *ProxySession) setMetadataMap(metadata map[string]string) {
150+
s.mu.Lock()
151+
defer s.mu.Unlock()
152+
if metadata == nil {
153+
s.metadata = make(map[string]string)
154+
} else {
155+
s.metadata = metadata
156+
}
157+
}

0 commit comments

Comments
 (0)