feature/streamable-http-transport #1
@@ -3,6 +3,7 @@ package mcp
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -80,34 +81,57 @@ func (s *Session) Close() {
|
||||
|
||||
// SessionStore manages active sessions with TTL-based cleanup.
|
||||
type SessionStore struct {
|
||||
sessions map[string]*Session
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
stopClean chan struct{}
|
||||
cleanDone chan struct{}
|
||||
sessions map[string]*Session
|
||||
ttl time.Duration
|
||||
maxSessions int
|
||||
mu sync.RWMutex
|
||||
stopClean chan struct{}
|
||||
cleanDone chan struct{}
|
||||
}
|
||||
|
||||
// ErrTooManySessions is returned when the session limit is reached.
|
||||
var ErrTooManySessions = fmt.Errorf("too many active sessions")
|
||||
|
||||
// DefaultMaxSessions is the default maximum number of concurrent sessions.
|
||||
const DefaultMaxSessions = 10000
|
||||
|
||||
// NewSessionStore creates a new session store with the given TTL.
|
||||
func NewSessionStore(ttl time.Duration) *SessionStore {
|
||||
return NewSessionStoreWithLimit(ttl, DefaultMaxSessions)
|
||||
}
|
||||
|
||||
// NewSessionStoreWithLimit creates a new session store with TTL and max session limit.
|
||||
func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore {
|
||||
if maxSessions <= 0 {
|
||||
maxSessions = DefaultMaxSessions
|
||||
}
|
||||
s := &SessionStore{
|
||||
sessions: make(map[string]*Session),
|
||||
ttl: ttl,
|
||||
stopClean: make(chan struct{}),
|
||||
cleanDone: make(chan struct{}),
|
||||
sessions: make(map[string]*Session),
|
||||
ttl: ttl,
|
||||
maxSessions: maxSessions,
|
||||
stopClean: make(chan struct{}),
|
||||
cleanDone: make(chan struct{}),
|
||||
}
|
||||
go s.cleanupLoop()
|
||||
return s
|
||||
}
|
||||
|
||||
// Create creates a new session and adds it to the store.
|
||||
// Returns ErrTooManySessions if the maximum session limit is reached.
|
||||
func (s *SessionStore) Create() (*Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Check session limit
|
||||
if len(s.sessions) >= s.maxSessions {
|
||||
return nil, ErrTooManySessions
|
||||
}
|
||||
|
||||
session, err := NewSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[session.ID] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
@@ -245,6 +245,76 @@ func TestSessionStoreConcurrency(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSessionStoreMaxSessions(t *testing.T) {
|
||||
maxSessions := 5
|
||||
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
||||
defer store.Stop()
|
||||
|
||||
// Create sessions up to limit
|
||||
for i := 0; i < maxSessions; i++ {
|
||||
_, err := store.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if store.Count() != maxSessions {
|
||||
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
|
||||
}
|
||||
|
||||
// Try to create one more - should fail
|
||||
_, err := store.Create()
|
||||
if err != ErrTooManySessions {
|
||||
t.Errorf("Expected ErrTooManySessions, got %v", err)
|
||||
}
|
||||
|
||||
// Count should still be at max
|
||||
if store.Count() != maxSessions {
|
||||
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
|
||||
maxSessions := 3
|
||||
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
||||
defer store.Stop()
|
||||
|
||||
// Fill up the store
|
||||
sessions := make([]*Session, maxSessions)
|
||||
for i := 0; i < maxSessions; i++ {
|
||||
s, err := store.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
sessions[i] = s
|
||||
}
|
||||
|
||||
// Should be full
|
||||
_, err := store.Create()
|
||||
if err != ErrTooManySessions {
|
||||
t.Error("Expected ErrTooManySessions when full")
|
||||
}
|
||||
|
||||
// Delete one session
|
||||
store.Delete(sessions[0].ID)
|
||||
|
||||
// Should be able to create again
|
||||
_, err = store.Create()
|
||||
if err != nil {
|
||||
t.Errorf("Should be able to create after deletion: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
|
||||
store := NewSessionStore(30 * time.Minute)
|
||||
defer store.Stop()
|
||||
|
||||
// Just verify it uses the default (don't create 10000 sessions)
|
||||
if store.maxSessions != DefaultMaxSessions {
|
||||
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ type HTTPConfig struct {
|
||||
Endpoint string // MCP endpoint path (e.g., "/mcp")
|
||||
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
|
||||
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
||||
MaxSessions int // Maximum concurrent sessions (default: 10000)
|
||||
TLSCertFile string // TLS certificate file (optional)
|
||||
TLSKeyFile string // TLS key file (optional)
|
||||
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
|
||||
@@ -55,6 +56,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
||||
if config.SessionTTL == 0 {
|
||||
config.SessionTTL = 30 * time.Minute
|
||||
}
|
||||
if config.MaxSessions == 0 {
|
||||
config.MaxSessions = DefaultMaxSessions
|
||||
}
|
||||
if config.MaxRequestSize == 0 {
|
||||
config.MaxRequestSize = DefaultMaxRequestSize
|
||||
}
|
||||
@@ -74,7 +78,7 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
||||
return &HTTPTransport{
|
||||
server: server,
|
||||
config: config,
|
||||
sessions: NewSessionStore(config.SessionTTL),
|
||||
sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,6 +235,11 @@ func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request,
|
||||
// Create a new session
|
||||
session, err := t.sessions.Create()
|
||||
if err != nil {
|
||||
if err == ErrTooManySessions {
|
||||
t.server.logger.Printf("Session limit reached")
|
||||
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
t.server.logger.Printf("Failed to create session: %v", err)
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -582,6 +582,50 @@ func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportSessionLimitReached(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||
MaxSessions: 2, // Very low limit for testing
|
||||
})
|
||||
|
||||
initReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodInitialize,
|
||||
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||
}
|
||||
body, _ := json.Marshal(initReq)
|
||||
|
||||
// Create sessions up to the limit
|
||||
for i := 0; i < 2; i++ {
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request %d failed: %v", i, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// Third request should fail with 503
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||
MaxRequestSize: 10000, // Reasonable limit
|
||||
|
||||
Reference in New Issue
Block a user