security: add maximum session limit to prevent memory exhaustion
Add configurable MaxSessions limit (default: 10000) to SessionStore. When the limit is reached, new session creation returns ErrTooManySessions and HTTP transport responds with 503 Service Unavailable. This prevents attackers from exhausting server memory by creating unlimited sessions through repeated initialize requests. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@ package mcp
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -82,16 +83,32 @@ func (s *Session) Close() {
|
|||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
sessions map[string]*Session
|
sessions map[string]*Session
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
|
maxSessions int
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
stopClean chan struct{}
|
stopClean chan struct{}
|
||||||
cleanDone chan struct{}
|
cleanDone chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrTooManySessions is returned when the session limit is reached.
|
||||||
|
var ErrTooManySessions = fmt.Errorf("too many active sessions")
|
||||||
|
|
||||||
|
// DefaultMaxSessions is the default maximum number of concurrent sessions.
|
||||||
|
const DefaultMaxSessions = 10000
|
||||||
|
|
||||||
// NewSessionStore creates a new session store with the given TTL.
|
// NewSessionStore creates a new session store with the given TTL.
|
||||||
func NewSessionStore(ttl time.Duration) *SessionStore {
|
func NewSessionStore(ttl time.Duration) *SessionStore {
|
||||||
|
return NewSessionStoreWithLimit(ttl, DefaultMaxSessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionStoreWithLimit creates a new session store with TTL and max session limit.
|
||||||
|
func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore {
|
||||||
|
if maxSessions <= 0 {
|
||||||
|
maxSessions = DefaultMaxSessions
|
||||||
|
}
|
||||||
s := &SessionStore{
|
s := &SessionStore{
|
||||||
sessions: make(map[string]*Session),
|
sessions: make(map[string]*Session),
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
|
maxSessions: maxSessions,
|
||||||
stopClean: make(chan struct{}),
|
stopClean: make(chan struct{}),
|
||||||
cleanDone: make(chan struct{}),
|
cleanDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
@@ -100,14 +117,21 @@ func NewSessionStore(ttl time.Duration) *SessionStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new session and adds it to the store.
|
// Create creates a new session and adds it to the store.
|
||||||
|
// Returns ErrTooManySessions if the maximum session limit is reached.
|
||||||
func (s *SessionStore) Create() (*Session, error) {
|
func (s *SessionStore) Create() (*Session, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Check session limit
|
||||||
|
if len(s.sessions) >= s.maxSessions {
|
||||||
|
return nil, ErrTooManySessions
|
||||||
|
}
|
||||||
|
|
||||||
session, err := NewSession()
|
session, err := NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.sessions[session.ID] = session
|
s.sessions[session.ID] = session
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -245,6 +245,76 @@ func TestSessionStoreConcurrency(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreMaxSessions(t *testing.T) {
|
||||||
|
maxSessions := 5
|
||||||
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Create sessions up to limit
|
||||||
|
for i := 0; i < maxSessions; i++ {
|
||||||
|
_, err := store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create session %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Count() != maxSessions {
|
||||||
|
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to create one more - should fail
|
||||||
|
_, err := store.Create()
|
||||||
|
if err != ErrTooManySessions {
|
||||||
|
t.Errorf("Expected ErrTooManySessions, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count should still be at max
|
||||||
|
if store.Count() != maxSessions {
|
||||||
|
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
|
||||||
|
maxSessions := 3
|
||||||
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Fill up the store
|
||||||
|
sessions := make([]*Session, maxSessions)
|
||||||
|
for i := 0; i < maxSessions; i++ {
|
||||||
|
s, err := store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create session: %v", err)
|
||||||
|
}
|
||||||
|
sessions[i] = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be full
|
||||||
|
_, err := store.Create()
|
||||||
|
if err != ErrTooManySessions {
|
||||||
|
t.Error("Expected ErrTooManySessions when full")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete one session
|
||||||
|
store.Delete(sessions[0].ID)
|
||||||
|
|
||||||
|
// Should be able to create again
|
||||||
|
_, err = store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should be able to create after deletion: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
|
||||||
|
store := NewSessionStore(30 * time.Minute)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Just verify it uses the default (don't create 10000 sessions)
|
||||||
|
if store.maxSessions != DefaultMaxSessions {
|
||||||
|
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateSessionID(t *testing.T) {
|
func TestGenerateSessionID(t *testing.T) {
|
||||||
ids := make(map[string]bool)
|
ids := make(map[string]bool)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type HTTPConfig struct {
|
|||||||
Endpoint string // MCP endpoint path (e.g., "/mcp")
|
Endpoint string // MCP endpoint path (e.g., "/mcp")
|
||||||
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
|
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
|
||||||
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
||||||
|
MaxSessions int // Maximum concurrent sessions (default: 10000)
|
||||||
TLSCertFile string // TLS certificate file (optional)
|
TLSCertFile string // TLS certificate file (optional)
|
||||||
TLSKeyFile string // TLS key file (optional)
|
TLSKeyFile string // TLS key file (optional)
|
||||||
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
|
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
|
||||||
@@ -55,6 +56,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
|||||||
if config.SessionTTL == 0 {
|
if config.SessionTTL == 0 {
|
||||||
config.SessionTTL = 30 * time.Minute
|
config.SessionTTL = 30 * time.Minute
|
||||||
}
|
}
|
||||||
|
if config.MaxSessions == 0 {
|
||||||
|
config.MaxSessions = DefaultMaxSessions
|
||||||
|
}
|
||||||
if config.MaxRequestSize == 0 {
|
if config.MaxRequestSize == 0 {
|
||||||
config.MaxRequestSize = DefaultMaxRequestSize
|
config.MaxRequestSize = DefaultMaxRequestSize
|
||||||
}
|
}
|
||||||
@@ -74,7 +78,7 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
|||||||
return &HTTPTransport{
|
return &HTTPTransport{
|
||||||
server: server,
|
server: server,
|
||||||
config: config,
|
config: config,
|
||||||
sessions: NewSessionStore(config.SessionTTL),
|
sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,6 +235,11 @@ func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request,
|
|||||||
// Create a new session
|
// Create a new session
|
||||||
session, err := t.sessions.Create()
|
session, err := t.sessions.Create()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == ErrTooManySessions {
|
||||||
|
t.server.logger.Printf("Session limit reached")
|
||||||
|
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
t.server.logger.Printf("Failed to create session: %v", err)
|
t.server.logger.Printf("Failed to create session: %v", err)
|
||||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -582,6 +582,50 @@ func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportSessionLimitReached(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
MaxSessions: 2, // Very low limit for testing
|
||||||
|
})
|
||||||
|
|
||||||
|
initReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodInitialize,
|
||||||
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(initReq)
|
||||||
|
|
||||||
|
// Create sessions up to the limit
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request %d failed: %v", i, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third request should fail with 503
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
|
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
|
||||||
_, ts := testHTTPTransport(t, HTTPConfig{
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
MaxRequestSize: 10000, // Reasonable limit
|
MaxRequestSize: 10000, // Reasonable limit
|
||||||
|
|||||||
Reference in New Issue
Block a user