diff --git a/internal/mcp/session.go b/internal/mcp/session.go index 0f6e238..ee555a6 100644 --- a/internal/mcp/session.go +++ b/internal/mcp/session.go @@ -3,6 +3,7 @@ package mcp import ( "crypto/rand" "encoding/hex" + "fmt" "sync" "time" ) @@ -80,34 +81,57 @@ func (s *Session) Close() { // SessionStore manages active sessions with TTL-based cleanup. type SessionStore struct { - sessions map[string]*Session - ttl time.Duration - mu sync.RWMutex - stopClean chan struct{} - cleanDone chan struct{} + sessions map[string]*Session + ttl time.Duration + maxSessions int + mu sync.RWMutex + stopClean chan struct{} + cleanDone chan struct{} } +// ErrTooManySessions is returned when the session limit is reached. +var ErrTooManySessions = fmt.Errorf("too many active sessions") + +// DefaultMaxSessions is the default maximum number of concurrent sessions. +const DefaultMaxSessions = 10000 + // NewSessionStore creates a new session store with the given TTL. func NewSessionStore(ttl time.Duration) *SessionStore { + return NewSessionStoreWithLimit(ttl, DefaultMaxSessions) +} + +// NewSessionStoreWithLimit creates a new session store with TTL and max session limit. +func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore { + if maxSessions <= 0 { + maxSessions = DefaultMaxSessions + } s := &SessionStore{ - sessions: make(map[string]*Session), - ttl: ttl, - stopClean: make(chan struct{}), - cleanDone: make(chan struct{}), + sessions: make(map[string]*Session), + ttl: ttl, + maxSessions: maxSessions, + stopClean: make(chan struct{}), + cleanDone: make(chan struct{}), } go s.cleanupLoop() return s } // Create creates a new session and adds it to the store. +// Returns ErrTooManySessions if the maximum session limit is reached. func (s *SessionStore) Create() (*Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Check session limit + if len(s.sessions) >= s.maxSessions { + return nil, ErrTooManySessions + } + session, err := NewSession() if err != nil { return nil, err } - s.mu.Lock() - defer s.mu.Unlock() s.sessions[session.ID] = session return session, nil } diff --git a/internal/mcp/session_test.go b/internal/mcp/session_test.go index 3c937fa..2760ab9 100644 --- a/internal/mcp/session_test.go +++ b/internal/mcp/session_test.go @@ -245,6 +245,76 @@ func TestSessionStoreConcurrency(t *testing.T) { wg.Wait() } +func TestSessionStoreMaxSessions(t *testing.T) { + maxSessions := 5 + store := NewSessionStoreWithLimit(30*time.Minute, maxSessions) + defer store.Stop() + + // Create sessions up to limit + for i := 0; i < maxSessions; i++ { + _, err := store.Create() + if err != nil { + t.Fatalf("Failed to create session %d: %v", i, err) + } + } + + if store.Count() != maxSessions { + t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count()) + } + + // Try to create one more - should fail + _, err := store.Create() + if err != ErrTooManySessions { + t.Errorf("Expected ErrTooManySessions, got %v", err) + } + + // Count should still be at max + if store.Count() != maxSessions { + t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count()) + } +} + +func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) { + maxSessions := 3 + store := NewSessionStoreWithLimit(30*time.Minute, maxSessions) + defer store.Stop() + + // Fill up the store + sessions := make([]*Session, maxSessions) + for i := 0; i < maxSessions; i++ { + s, err := store.Create() + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + sessions[i] = s + } + + // Should be full + _, err := store.Create() + if err != ErrTooManySessions { + t.Error("Expected ErrTooManySessions when full") + } + + // Delete one session + store.Delete(sessions[0].ID) + + // Should be able to create again + _, err = store.Create() + if err != nil { + t.Errorf("Should be able to create after deletion: %v", err) + } +} + +func TestSessionStoreDefaultMaxSessions(t *testing.T) { + store := NewSessionStore(30 * time.Minute) + defer store.Stop() + + // Just verify it uses the default (don't create 10000 sessions) + if store.maxSessions != DefaultMaxSessions { + t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions) + } +} + func TestGenerateSessionID(t *testing.T) { ids := make(map[string]bool) diff --git a/internal/mcp/transport_http.go b/internal/mcp/transport_http.go index f35419b..3120573 100644 --- a/internal/mcp/transport_http.go +++ b/internal/mcp/transport_http.go @@ -17,6 +17,7 @@ type HTTPConfig struct { Endpoint string // MCP endpoint path (e.g., "/mcp") AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only) SessionTTL time.Duration // Session TTL (default: 30 minutes) + MaxSessions int // Maximum concurrent sessions (default: 10000) TLSCertFile string // TLS certificate file (optional) TLSKeyFile string // TLS key file (optional) MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB) @@ -55,6 +56,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { if config.SessionTTL == 0 { config.SessionTTL = 30 * time.Minute } + if config.MaxSessions == 0 { + config.MaxSessions = DefaultMaxSessions + } if config.MaxRequestSize == 0 { config.MaxRequestSize = DefaultMaxRequestSize } @@ -74,7 +78,7 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { return &HTTPTransport{ server: server, config: config, - sessions: NewSessionStore(config.SessionTTL), + sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions), } } @@ -231,6 +235,11 @@ func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, // Create a new session session, err := t.sessions.Create() if err != nil { + if err == ErrTooManySessions { + t.server.logger.Printf("Session limit reached") + http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable) + return + } t.server.logger.Printf("Failed to create session: %v", err) http.Error(w, "Failed to create session", http.StatusInternalServerError) return diff --git a/internal/mcp/transport_http_test.go b/internal/mcp/transport_http_test.go index 68e746c..5f01b0b 100644 --- a/internal/mcp/transport_http_test.go +++ b/internal/mcp/transport_http_test.go @@ -582,6 +582,50 @@ func TestHTTPTransportRequestBodyTooLarge(t *testing.T) { } } +func TestHTTPTransportSessionLimitReached(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + MaxSessions: 2, // Very low limit for testing + }) + + initReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodInitialize, + Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`), + } + body, _ := json.Marshal(initReq) + + // Create sessions up to the limit + for i := 0; i < 2; i++ { + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request %d failed: %v", i, err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode) + } + } + + // Third request should fail with 503 + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode) + } +} + func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) { _, ts := testHTTPTransport(t, HTTPConfig{ MaxRequestSize: 10000, // Reasonable limit