From 684baf63da1c92178ae5b5c2d9f0faad1eb93f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Tue, 3 Feb 2026 22:07:51 +0100 Subject: [PATCH] security: add maximum session limit to prevent memory exhaustion Add configurable MaxSessions limit (default: 10000) to SessionStore. When the limit is reached, new session creation returns ErrTooManySessions and HTTP transport responds with 503 Service Unavailable. This prevents attackers from exhausting server memory by creating unlimited sessions through repeated initialize requests. Co-Authored-By: Claude Opus 4.5 --- internal/mcp/session.go | 46 ++++++++++++++----- internal/mcp/session_test.go | 70 +++++++++++++++++++++++++++++ internal/mcp/transport_http.go | 11 ++++- internal/mcp/transport_http_test.go | 44 ++++++++++++++++++ 4 files changed, 159 insertions(+), 12 deletions(-) diff --git a/internal/mcp/session.go b/internal/mcp/session.go index 0f6e238..ee555a6 100644 --- a/internal/mcp/session.go +++ b/internal/mcp/session.go @@ -3,6 +3,7 @@ package mcp import ( "crypto/rand" "encoding/hex" + "fmt" "sync" "time" ) @@ -80,34 +81,57 @@ func (s *Session) Close() { // SessionStore manages active sessions with TTL-based cleanup. type SessionStore struct { - sessions map[string]*Session - ttl time.Duration - mu sync.RWMutex - stopClean chan struct{} - cleanDone chan struct{} + sessions map[string]*Session + ttl time.Duration + maxSessions int + mu sync.RWMutex + stopClean chan struct{} + cleanDone chan struct{} } +// ErrTooManySessions is returned when the session limit is reached. +var ErrTooManySessions = fmt.Errorf("too many active sessions") + +// DefaultMaxSessions is the default maximum number of concurrent sessions. +const DefaultMaxSessions = 10000 + // NewSessionStore creates a new session store with the given TTL. func NewSessionStore(ttl time.Duration) *SessionStore { + return NewSessionStoreWithLimit(ttl, DefaultMaxSessions) +} + +// NewSessionStoreWithLimit creates a new session store with TTL and max session limit. +func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore { + if maxSessions <= 0 { + maxSessions = DefaultMaxSessions + } s := &SessionStore{ - sessions: make(map[string]*Session), - ttl: ttl, - stopClean: make(chan struct{}), - cleanDone: make(chan struct{}), + sessions: make(map[string]*Session), + ttl: ttl, + maxSessions: maxSessions, + stopClean: make(chan struct{}), + cleanDone: make(chan struct{}), } go s.cleanupLoop() return s } // Create creates a new session and adds it to the store. +// Returns ErrTooManySessions if the maximum session limit is reached. func (s *SessionStore) Create() (*Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Check session limit + if len(s.sessions) >= s.maxSessions { + return nil, ErrTooManySessions + } + session, err := NewSession() if err != nil { return nil, err } - s.mu.Lock() - defer s.mu.Unlock() s.sessions[session.ID] = session return session, nil } diff --git a/internal/mcp/session_test.go b/internal/mcp/session_test.go index 3c937fa..2760ab9 100644 --- a/internal/mcp/session_test.go +++ b/internal/mcp/session_test.go @@ -245,6 +245,76 @@ func TestSessionStoreConcurrency(t *testing.T) { wg.Wait() } +func TestSessionStoreMaxSessions(t *testing.T) { + maxSessions := 5 + store := NewSessionStoreWithLimit(30*time.Minute, maxSessions) + defer store.Stop() + + // Create sessions up to limit + for i := 0; i < maxSessions; i++ { + _, err := store.Create() + if err != nil { + t.Fatalf("Failed to create session %d: %v", i, err) + } + } + + if store.Count() != maxSessions { + t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count()) + } + + // Try to create one more - should fail + _, err := store.Create() + if err != ErrTooManySessions { + t.Errorf("Expected ErrTooManySessions, got %v", err) + } + + // Count should still be at max + if store.Count() != maxSessions { + t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count()) + } +} + +func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) { + maxSessions := 3 + store := NewSessionStoreWithLimit(30*time.Minute, maxSessions) + defer store.Stop() + + // Fill up the store + sessions := make([]*Session, maxSessions) + for i := 0; i < maxSessions; i++ { + s, err := store.Create() + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + sessions[i] = s + } + + // Should be full + _, err := store.Create() + if err != ErrTooManySessions { + t.Error("Expected ErrTooManySessions when full") + } + + // Delete one session + store.Delete(sessions[0].ID) + + // Should be able to create again + _, err = store.Create() + if err != nil { + t.Errorf("Should be able to create after deletion: %v", err) + } +} + +func TestSessionStoreDefaultMaxSessions(t *testing.T) { + store := NewSessionStore(30 * time.Minute) + defer store.Stop() + + // Just verify it uses the default (don't create 10000 sessions) + if store.maxSessions != DefaultMaxSessions { + t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions) + } +} + func TestGenerateSessionID(t *testing.T) { ids := make(map[string]bool) diff --git a/internal/mcp/transport_http.go b/internal/mcp/transport_http.go index f35419b..3120573 100644 --- a/internal/mcp/transport_http.go +++ b/internal/mcp/transport_http.go @@ -17,6 +17,7 @@ type HTTPConfig struct { Endpoint string // MCP endpoint path (e.g., "/mcp") AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only) SessionTTL time.Duration // Session TTL (default: 30 minutes) + MaxSessions int // Maximum concurrent sessions (default: 10000) TLSCertFile string // TLS certificate file (optional) TLSKeyFile string // TLS key file (optional) MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB) @@ -55,6 +56,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { if config.SessionTTL == 0 { config.SessionTTL = 30 * time.Minute } + if config.MaxSessions == 0 { + config.MaxSessions = DefaultMaxSessions + } if config.MaxRequestSize == 0 { config.MaxRequestSize = DefaultMaxRequestSize } @@ -74,7 +78,7 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { return &HTTPTransport{ server: server, config: config, - sessions: NewSessionStore(config.SessionTTL), + sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions), } } @@ -231,6 +235,11 @@ func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, // Create a new session session, err := t.sessions.Create() if err != nil { + if err == ErrTooManySessions { + t.server.logger.Printf("Session limit reached") + http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable) + return + } t.server.logger.Printf("Failed to create session: %v", err) http.Error(w, "Failed to create session", http.StatusInternalServerError) return diff --git a/internal/mcp/transport_http_test.go b/internal/mcp/transport_http_test.go index 68e746c..5f01b0b 100644 --- a/internal/mcp/transport_http_test.go +++ b/internal/mcp/transport_http_test.go @@ -582,6 +582,50 @@ func TestHTTPTransportRequestBodyTooLarge(t *testing.T) { } } +func TestHTTPTransportSessionLimitReached(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + MaxSessions: 2, // Very low limit for testing + }) + + initReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodInitialize, + Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`), + } + body, _ := json.Marshal(initReq) + + // Create sessions up to the limit + for i := 0; i < 2; i++ { + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request %d failed: %v", i, err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode) + } + } + + // Third request should fail with 503 + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusServiceUnavailable { + t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode) + } +} + func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) { _, ts := testHTTPTransport(t, HTTPConfig{ MaxRequestSize: 10000, // Reasonable limit