feature/streamable-http-transport #1

Merged
torjus merged 8 commits from feature/streamable-http-transport into master 2026-02-03 21:23:39 +00:00
4 changed files with 159 additions and 12 deletions
Showing only changes of commit 684baf63da - Show all commits

View File

@@ -3,6 +3,7 @@ package mcp
import (
"crypto/rand"
"encoding/hex"
"fmt"
"sync"
"time"
)
@@ -80,34 +81,57 @@ func (s *Session) Close() {
// SessionStore manages active sessions with TTL-based cleanup.
type SessionStore struct {
sessions map[string]*Session
ttl time.Duration
mu sync.RWMutex
stopClean chan struct{}
cleanDone chan struct{}
sessions map[string]*Session
ttl time.Duration
maxSessions int
mu sync.RWMutex
stopClean chan struct{}
cleanDone chan struct{}
}
// ErrTooManySessions is returned when the session limit is reached.
var ErrTooManySessions = fmt.Errorf("too many active sessions")
// DefaultMaxSessions is the default maximum number of concurrent sessions.
const DefaultMaxSessions = 10000
// NewSessionStore creates a new session store with the given TTL.
func NewSessionStore(ttl time.Duration) *SessionStore {
return NewSessionStoreWithLimit(ttl, DefaultMaxSessions)
}
// NewSessionStoreWithLimit creates a new session store with TTL and max session limit.
func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore {
if maxSessions <= 0 {
maxSessions = DefaultMaxSessions
}
s := &SessionStore{
sessions: make(map[string]*Session),
ttl: ttl,
stopClean: make(chan struct{}),
cleanDone: make(chan struct{}),
sessions: make(map[string]*Session),
ttl: ttl,
maxSessions: maxSessions,
stopClean: make(chan struct{}),
cleanDone: make(chan struct{}),
}
go s.cleanupLoop()
return s
}
// Create creates a new session and adds it to the store.
// Returns ErrTooManySessions if the maximum session limit is reached.
func (s *SessionStore) Create() (*Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Check session limit
if len(s.sessions) >= s.maxSessions {
return nil, ErrTooManySessions
}
session, err := NewSession()
if err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[session.ID] = session
return session, nil
}

View File

@@ -245,6 +245,76 @@ func TestSessionStoreConcurrency(t *testing.T) {
wg.Wait()
}
func TestSessionStoreMaxSessions(t *testing.T) {
maxSessions := 5
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
defer store.Stop()
// Create sessions up to limit
for i := 0; i < maxSessions; i++ {
_, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session %d: %v", i, err)
}
}
if store.Count() != maxSessions {
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
}
// Try to create one more - should fail
_, err := store.Create()
if err != ErrTooManySessions {
t.Errorf("Expected ErrTooManySessions, got %v", err)
}
// Count should still be at max
if store.Count() != maxSessions {
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
}
}
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
maxSessions := 3
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
defer store.Stop()
// Fill up the store
sessions := make([]*Session, maxSessions)
for i := 0; i < maxSessions; i++ {
s, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
sessions[i] = s
}
// Should be full
_, err := store.Create()
if err != ErrTooManySessions {
t.Error("Expected ErrTooManySessions when full")
}
// Delete one session
store.Delete(sessions[0].ID)
// Should be able to create again
_, err = store.Create()
if err != nil {
t.Errorf("Should be able to create after deletion: %v", err)
}
}
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
// Just verify it uses the default (don't create 10000 sessions)
if store.maxSessions != DefaultMaxSessions {
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
}
}
func TestGenerateSessionID(t *testing.T) {
ids := make(map[string]bool)

View File

@@ -17,6 +17,7 @@ type HTTPConfig struct {
Endpoint string // MCP endpoint path (e.g., "/mcp")
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
SessionTTL time.Duration // Session TTL (default: 30 minutes)
MaxSessions int // Maximum concurrent sessions (default: 10000)
TLSCertFile string // TLS certificate file (optional)
TLSKeyFile string // TLS key file (optional)
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
@@ -55,6 +56,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
if config.SessionTTL == 0 {
config.SessionTTL = 30 * time.Minute
}
if config.MaxSessions == 0 {
config.MaxSessions = DefaultMaxSessions
}
if config.MaxRequestSize == 0 {
config.MaxRequestSize = DefaultMaxRequestSize
}
@@ -74,7 +78,7 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
return &HTTPTransport{
server: server,
config: config,
sessions: NewSessionStore(config.SessionTTL),
sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
}
}
@@ -231,6 +235,11 @@ func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request,
// Create a new session
session, err := t.sessions.Create()
if err != nil {
if err == ErrTooManySessions {
t.server.logger.Printf("Session limit reached")
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
return
}
t.server.logger.Printf("Failed to create session: %v", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return

View File

@@ -582,6 +582,50 @@ func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
}
}
func TestHTTPTransportSessionLimitReached(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
MaxSessions: 2, // Very low limit for testing
})
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
// Create sessions up to the limit
for i := 0; i < 2; i++ {
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request %d failed: %v", i, err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode)
}
}
// Third request should fail with 503
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode)
}
}
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
MaxRequestSize: 10000, // Reasonable limit