Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pkg/transport/proxy/httpsse/http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ func (p *HTTPSSEProxy) Stop(ctx context.Context) error {

// Stop the session manager cleanup routine
if p.sessionManager != nil {
p.sessionManager.Stop()
if err := p.sessionManager.Stop(); err != nil {
logger.Errorf("Failed to stop session manager: %v", err)
}
}

// Disconnect all active sessions
Expand Down Expand Up @@ -466,7 +468,9 @@ func (p *HTTPSSEProxy) removeClient(clientID string) {
}

// Remove the session from the manager
p.sessionManager.Delete(clientID)
if err := p.sessionManager.Delete(clientID); err != nil {
logger.Debugf("Failed to delete session %s: %v", clientID, err)
}

// Clean up closed clients map periodically (prevent memory leak)
p.closedClientsMutex.Lock()
Expand Down
8 changes: 6 additions & 2 deletions pkg/transport/proxy/streamable/streamable_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ func (p *HTTPProxy) Stop(ctx context.Context) error {

// Stop session manager cleanup and disconnect sessions
if p.sessionManager != nil {
p.sessionManager.Stop()
if err := p.sessionManager.Stop(); err != nil {
logger.Errorf("Failed to stop session manager: %v", err)
}
p.sessionManager.Range(func(_, value interface{}) bool {
if ss, ok := value.(*session.StreamableSession); ok {
ss.Disconnect()
Expand Down Expand Up @@ -202,7 +204,9 @@ func (p *HTTPProxy) handleDelete(w http.ResponseWriter, r *http.Request) {
writeHTTPError(w, http.StatusNotFound, "session not found")
return
}
p.sessionManager.Delete(sessID)
if err := p.sessionManager.Delete(sessID); err != nil {
logger.Debugf("Failed to delete session %s: %v", sessID, err)
}
w.WriteHeader(http.StatusNoContent)
}

Expand Down
152 changes: 97 additions & 55 deletions pkg/transport/session/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,34 @@
package session

import (
"context"
"fmt"
"sync"
"time"

"github.com/stacklok/toolhive/pkg/logger"
)

// Session interface
// Session interface defines the contract for all session types
type Session interface {
ID() string
Type() SessionType
CreatedAt() time.Time
UpdatedAt() time.Time
Touch()

// Data and metadata methods
GetData() interface{}
SetData(data interface{})
GetMetadata() map[string]string
SetMetadata(key, value string)
}

// Manager holds sessions with TTL cleanup.
type Manager struct {
sessions sync.Map
ttl time.Duration
stopCh chan struct{}
factory Factory
storage Storage
ttl time.Duration
stopCh chan struct{}
factory Factory
}

// Factory defines a function type for creating new sessions.
Expand Down Expand Up @@ -56,10 +65,10 @@ func NewManager(ttl time.Duration, factory interface{}) *Manager {
}

m := &Manager{
sessions: sync.Map{},
ttl: ttl,
stopCh: make(chan struct{}),
factory: f,
storage: NewLocalStorage(),
ttl: ttl,
stopCh: make(chan struct{}),
factory: f,
}
go m.cleanupRoutine()
return m
Expand All @@ -83,24 +92,30 @@ func NewTypedManager(ttl time.Duration, sessionType SessionType) *Manager {
return NewManager(ttl, factory)
}

// NewManagerWithStorage creates a session manager with a custom storage backend.
func NewManagerWithStorage(ttl time.Duration, factory Factory, storage Storage) *Manager {
m := &Manager{
storage: storage,
ttl: ttl,
stopCh: make(chan struct{}),
factory: factory,
}
go m.cleanupRoutine()
return m
}

func (m *Manager) cleanupRoutine() {
ticker := time.NewTicker(m.ttl / 2)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cutoff := time.Now().Add(-m.ttl)
m.sessions.Range(func(key, val any) bool {
sess, ok := val.(Session)
if !ok {
// Skip invalid value
return true
}
if sess.UpdatedAt().Before(cutoff) {
m.sessions.Delete(key)
}
return true
})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if err := m.storage.DeleteExpired(ctx, cutoff); err != nil {
logger.Errorf("Failed to delete expired sessions: %v", err)
}
cancel()
case <-m.stopCh:
return
}
Expand All @@ -113,13 +128,17 @@ func (m *Manager) AddWithID(id string) error {
if id == "" {
return fmt.Errorf("session ID cannot be empty")
}
// Use LoadOrStore: returns existing if already present
session := m.factory(id)
_, loaded := m.sessions.LoadOrStore(id, session)
if loaded {
// Check if session already exists
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if _, err := m.storage.Load(ctx, id); err == nil {
return fmt.Errorf("session ID %q already exists", id)
}
return nil

// Create and store new session
session := m.factory(id)
return m.storage.Store(ctx, session)
}

// AddSession adds an existing session to the manager.
Expand All @@ -132,62 +151,85 @@ func (m *Manager) AddSession(session Session) error {
return fmt.Errorf("session ID cannot be empty")
}

_, loaded := m.sessions.LoadOrStore(session.ID(), session)
if loaded {
// Check if session already exists
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if _, err := m.storage.Load(ctx, session.ID()); err == nil {
return fmt.Errorf("session ID %q already exists", session.ID())
}
return nil

return m.storage.Store(ctx, session)
}

// Get retrieves a session by ID. Returns (session, true) if found,
// and also updates its UpdatedAt timestamp.
func (m *Manager) Get(id string) (Session, bool) {
v, ok := m.sessions.Load(id)
if !ok {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

sess, err := m.storage.Load(ctx, id)
if err != nil {
return nil, false
}
sess, ok := v.(Session)
if !ok {
return nil, false // Invalid session type
}

// Touch the session to update its timestamp
sess.Touch()
return sess, true
}

// Delete removes a session by ID.
func (m *Manager) Delete(id string) {
m.sessions.Delete(id)
// Returns an error if the deletion fails.
func (m *Manager) Delete(id string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return m.storage.Delete(ctx, id)
}

// Stop stops the cleanup worker.
func (m *Manager) Stop() {
// Stop stops the cleanup worker and closes the storage backend.
// Returns an error if closing the storage backend fails.
func (m *Manager) Stop() error {
close(m.stopCh)
if m.storage != nil {
return m.storage.Close()
}
return nil
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Note: This method only works with LocalStorage backend. It will silently
// do nothing with other storage backends. Range is not part of the Storage
// interface because it's not feasible for distributed storage backends like
// Redis where iterating all keys can be prohibitively expensive or impractical.
//
// For distributed storage, consider using more targeted queries or maintaining
// a separate index of session IDs.
func (m *Manager) Range(f func(key, value interface{}) bool) {
m.sessions.Range(f)
if localStorage, ok := m.storage.(*LocalStorage); ok {
localStorage.Range(f)
}
}

// Count returns the number of active sessions.
//
// Note: This method only works with LocalStorage backend and returns 0 for
// other storage backends. Count is not part of the Storage interface because
// it's not feasible for distributed storage backends like Redis where counting
// all keys can be prohibitively expensive.
//
// For distributed storage, consider maintaining a counter or using approximate
// count mechanisms provided by the storage backend.
func (m *Manager) Count() int {
count := 0
m.sessions.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
if localStorage, ok := m.storage.(*LocalStorage); ok {
return localStorage.Count()
}
return 0
}

func (m *Manager) cleanupExpiredOnce() {
func (m *Manager) cleanupExpiredOnce() error {
cutoff := time.Now().Add(-m.ttl)
m.sessions.Range(func(key, val any) bool {
sess := val.(Session)
if sess.UpdatedAt().Before(cutoff) {
m.sessions.Delete(key)
}
return true
})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return m.storage.DeleteExpired(ctx, cutoff)
}
21 changes: 21 additions & 0 deletions pkg/transport/session/proxy_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,24 @@ func (s *ProxySession) DeleteMetadata(key string) {
defer s.mu.Unlock()
delete(s.metadata, key)
}

// setTimestamps updates the created and updated timestamps.
// This is used internally for deserialization to restore session state.
func (s *ProxySession) setTimestamps(created, updated time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
s.created = created
s.updated = updated
}

// setMetadataMap replaces the entire metadata map.
// This is used internally for deserialization to restore session state.
func (s *ProxySession) setMetadataMap(metadata map[string]string) {
s.mu.Lock()
defer s.mu.Unlock()
if metadata == nil {
s.metadata = make(map[string]string)
} else {
s.metadata = metadata
}
}
Loading
Loading