Add configurable MaxSessions limit (default: 10000) to SessionStore. When the limit is reached, new session creation returns ErrTooManySessions and HTTP transport responds with 503 Service Unavailable. This prevents attackers from exhausting server memory by creating unlimited sessions through repeated initialize requests. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
232 lines
4.9 KiB
Go
232 lines
4.9 KiB
Go
package mcp
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Session represents an MCP client session.
|
|
type Session struct {
|
|
ID string
|
|
CreatedAt time.Time
|
|
LastActivity time.Time
|
|
Initialized bool
|
|
|
|
// notifications is a channel for server-initiated notifications.
|
|
notifications chan *Response
|
|
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewSession creates a new session with a cryptographically secure random ID.
|
|
func NewSession() (*Session, error) {
|
|
id, err := generateSessionID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
now := time.Now()
|
|
return &Session{
|
|
ID: id,
|
|
CreatedAt: now,
|
|
LastActivity: now,
|
|
notifications: make(chan *Response, 100),
|
|
}, nil
|
|
}
|
|
|
|
// Touch updates the session's last activity time.
|
|
func (s *Session) Touch() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.LastActivity = time.Now()
|
|
}
|
|
|
|
// SetInitialized marks the session as initialized.
|
|
func (s *Session) SetInitialized() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.Initialized = true
|
|
}
|
|
|
|
// IsInitialized returns whether the session has been initialized.
|
|
func (s *Session) IsInitialized() bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.Initialized
|
|
}
|
|
|
|
// Notifications returns the channel for server-initiated notifications.
|
|
func (s *Session) Notifications() <-chan *Response {
|
|
return s.notifications
|
|
}
|
|
|
|
// SendNotification sends a notification to the session.
|
|
// Returns false if the channel is full.
|
|
func (s *Session) SendNotification(notification *Response) bool {
|
|
select {
|
|
case s.notifications <- notification:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Close closes the session's notification channel.
|
|
func (s *Session) Close() {
|
|
close(s.notifications)
|
|
}
|
|
|
|
// SessionStore manages active sessions with TTL-based cleanup.
|
|
type SessionStore struct {
|
|
sessions map[string]*Session
|
|
ttl time.Duration
|
|
maxSessions int
|
|
mu sync.RWMutex
|
|
stopClean chan struct{}
|
|
cleanDone chan struct{}
|
|
}
|
|
|
|
var (
	// ErrTooManySessions is returned by SessionStore.Create when the store
	// already holds its maximum number of sessions.
	ErrTooManySessions = fmt.Errorf("too many active sessions")
)
|
|
|
|
const (
	// DefaultMaxSessions is the session cap used by NewSessionStore, and by
	// NewSessionStoreWithLimit when given a non-positive limit.
	DefaultMaxSessions = 10_000
)
|
|
|
|
// NewSessionStore creates a new session store with the given TTL.
|
|
func NewSessionStore(ttl time.Duration) *SessionStore {
|
|
return NewSessionStoreWithLimit(ttl, DefaultMaxSessions)
|
|
}
|
|
|
|
// NewSessionStoreWithLimit creates a new session store with TTL and max session limit.
|
|
func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore {
|
|
if maxSessions <= 0 {
|
|
maxSessions = DefaultMaxSessions
|
|
}
|
|
s := &SessionStore{
|
|
sessions: make(map[string]*Session),
|
|
ttl: ttl,
|
|
maxSessions: maxSessions,
|
|
stopClean: make(chan struct{}),
|
|
cleanDone: make(chan struct{}),
|
|
}
|
|
go s.cleanupLoop()
|
|
return s
|
|
}
|
|
|
|
// Create creates a new session and adds it to the store.
|
|
// Returns ErrTooManySessions if the maximum session limit is reached.
|
|
func (s *SessionStore) Create() (*Session, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Check session limit
|
|
if len(s.sessions) >= s.maxSessions {
|
|
return nil, ErrTooManySessions
|
|
}
|
|
|
|
session, err := NewSession()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
s.sessions[session.ID] = session
|
|
return session, nil
|
|
}
|
|
|
|
// Get retrieves a session by ID. Returns nil if not found or expired.
|
|
func (s *SessionStore) Get(id string) *Session {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
session, ok := s.sessions[id]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
// Check if expired
|
|
session.mu.RLock()
|
|
expired := time.Since(session.LastActivity) > s.ttl
|
|
session.mu.RUnlock()
|
|
|
|
if expired {
|
|
return nil
|
|
}
|
|
|
|
return session
|
|
}
|
|
|
|
// Delete removes a session from the store.
|
|
func (s *SessionStore) Delete(id string) bool {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
session, ok := s.sessions[id]
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
session.Close()
|
|
delete(s.sessions, id)
|
|
return true
|
|
}
|
|
|
|
// Count returns the number of active sessions.
|
|
func (s *SessionStore) Count() int {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return len(s.sessions)
|
|
}
|
|
|
|
// Stop stops the cleanup goroutine and waits for it to finish.
|
|
func (s *SessionStore) Stop() {
|
|
close(s.stopClean)
|
|
<-s.cleanDone
|
|
}
|
|
|
|
// cleanupLoop periodically removes expired sessions.
|
|
func (s *SessionStore) cleanupLoop() {
|
|
defer close(s.cleanDone)
|
|
|
|
ticker := time.NewTicker(s.ttl / 2)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-s.stopClean:
|
|
return
|
|
case <-ticker.C:
|
|
s.cleanup()
|
|
}
|
|
}
|
|
}
|
|
|
|
// cleanup removes expired sessions.
|
|
func (s *SessionStore) cleanup() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for id, session := range s.sessions {
|
|
session.mu.RLock()
|
|
expired := now.Sub(session.LastActivity) > s.ttl
|
|
session.mu.RUnlock()
|
|
|
|
if expired {
|
|
session.Close()
|
|
delete(s.sessions, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
// generateSessionID returns 16 bytes from crypto/rand encoded as a
// 32-character lowercase hex string, or the error from the random source.
func generateSessionID() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|