security: add maximum session limit to prevent memory exhaustion

Add configurable MaxSessions limit (default: 10000) to SessionStore.
When the limit is reached, new session creation returns ErrTooManySessions
and HTTP transport responds with 503 Service Unavailable.

This prevents attackers from exhausting server memory by creating
unlimited sessions through repeated initialize requests.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-03 22:07:51 +01:00
parent 1565cb5e1b
commit 684baf63da
4 changed files with 159 additions and 12 deletions

View File

@@ -3,6 +3,7 @@ package mcp
import ( import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt"
"sync" "sync"
"time" "time"
) )
@@ -80,34 +81,57 @@ func (s *Session) Close() {
// SessionStore manages active sessions with TTL-based cleanup. // SessionStore manages active sessions with TTL-based cleanup.
type SessionStore struct { type SessionStore struct {
sessions map[string]*Session sessions map[string]*Session
ttl time.Duration ttl time.Duration
mu sync.RWMutex maxSessions int
stopClean chan struct{} mu sync.RWMutex
cleanDone chan struct{} stopClean chan struct{}
cleanDone chan struct{}
} }
// ErrTooManySessions is returned when the session limit is reached.
var ErrTooManySessions = fmt.Errorf("too many active sessions")
// DefaultMaxSessions is the default maximum number of concurrent sessions.
const DefaultMaxSessions = 10000
// NewSessionStore creates a new session store with the given TTL. // NewSessionStore creates a new session store with the given TTL.
func NewSessionStore(ttl time.Duration) *SessionStore { func NewSessionStore(ttl time.Duration) *SessionStore {
return NewSessionStoreWithLimit(ttl, DefaultMaxSessions)
}
// NewSessionStoreWithLimit creates a new session store with TTL and max session limit.
func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore {
if maxSessions <= 0 {
maxSessions = DefaultMaxSessions
}
s := &SessionStore{ s := &SessionStore{
sessions: make(map[string]*Session), sessions: make(map[string]*Session),
ttl: ttl, ttl: ttl,
stopClean: make(chan struct{}), maxSessions: maxSessions,
cleanDone: make(chan struct{}), stopClean: make(chan struct{}),
cleanDone: make(chan struct{}),
} }
go s.cleanupLoop() go s.cleanupLoop()
return s return s
} }
// Create creates a new session and adds it to the store. // Create creates a new session and adds it to the store.
// Returns ErrTooManySessions if the maximum session limit is reached.
func (s *SessionStore) Create() (*Session, error) { func (s *SessionStore) Create() (*Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Check session limit
if len(s.sessions) >= s.maxSessions {
return nil, ErrTooManySessions
}
session, err := NewSession() session, err := NewSession()
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[session.ID] = session s.sessions[session.ID] = session
return session, nil return session, nil
} }

View File

@@ -245,6 +245,76 @@ func TestSessionStoreConcurrency(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestSessionStoreMaxSessions(t *testing.T) {
maxSessions := 5
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
defer store.Stop()
// Create sessions up to limit
for i := 0; i < maxSessions; i++ {
_, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session %d: %v", i, err)
}
}
if store.Count() != maxSessions {
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
}
// Try to create one more - should fail
_, err := store.Create()
if err != ErrTooManySessions {
t.Errorf("Expected ErrTooManySessions, got %v", err)
}
// Count should still be at max
if store.Count() != maxSessions {
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
}
}
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
maxSessions := 3
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
defer store.Stop()
// Fill up the store
sessions := make([]*Session, maxSessions)
for i := 0; i < maxSessions; i++ {
s, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
sessions[i] = s
}
// Should be full
_, err := store.Create()
if err != ErrTooManySessions {
t.Error("Expected ErrTooManySessions when full")
}
// Delete one session
store.Delete(sessions[0].ID)
// Should be able to create again
_, err = store.Create()
if err != nil {
t.Errorf("Should be able to create after deletion: %v", err)
}
}
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
// Just verify it uses the default (don't create 10000 sessions)
if store.maxSessions != DefaultMaxSessions {
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
}
}
func TestGenerateSessionID(t *testing.T) { func TestGenerateSessionID(t *testing.T) {
ids := make(map[string]bool) ids := make(map[string]bool)

View File

@@ -17,6 +17,7 @@ type HTTPConfig struct {
Endpoint string // MCP endpoint path (e.g., "/mcp") Endpoint string // MCP endpoint path (e.g., "/mcp")
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only) AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
SessionTTL time.Duration // Session TTL (default: 30 minutes) SessionTTL time.Duration // Session TTL (default: 30 minutes)
MaxSessions int // Maximum concurrent sessions (default: 10000)
TLSCertFile string // TLS certificate file (optional) TLSCertFile string // TLS certificate file (optional)
TLSKeyFile string // TLS key file (optional) TLSKeyFile string // TLS key file (optional)
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB) MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
@@ -55,6 +56,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
if config.SessionTTL == 0 { if config.SessionTTL == 0 {
config.SessionTTL = 30 * time.Minute config.SessionTTL = 30 * time.Minute
} }
if config.MaxSessions == 0 {
config.MaxSessions = DefaultMaxSessions
}
if config.MaxRequestSize == 0 { if config.MaxRequestSize == 0 {
config.MaxRequestSize = DefaultMaxRequestSize config.MaxRequestSize = DefaultMaxRequestSize
} }
@@ -74,7 +78,7 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
return &HTTPTransport{ return &HTTPTransport{
server: server, server: server,
config: config, config: config,
sessions: NewSessionStore(config.SessionTTL), sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
} }
} }
@@ -231,6 +235,11 @@ func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request,
// Create a new session // Create a new session
session, err := t.sessions.Create() session, err := t.sessions.Create()
if err != nil { if err != nil {
if err == ErrTooManySessions {
t.server.logger.Printf("Session limit reached")
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
return
}
t.server.logger.Printf("Failed to create session: %v", err) t.server.logger.Printf("Failed to create session: %v", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError) http.Error(w, "Failed to create session", http.StatusInternalServerError)
return return

View File

@@ -582,6 +582,50 @@ func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
} }
} }
func TestHTTPTransportSessionLimitReached(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
MaxSessions: 2, // Very low limit for testing
})
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
// Create sessions up to the limit
for i := 0; i < 2; i++ {
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request %d failed: %v", i, err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode)
}
}
// Third request should fail with 503
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode)
}
}
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) { func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{ _, ts := testHTTPTransport(t, HTTPConfig{
MaxRequestSize: 10000, // Reasonable limit MaxRequestSize: 10000, // Reasonable limit