Add configurable MaxSessions limit (default: 10000) to SessionStore. When the limit is reached, new session creation returns ErrTooManySessions and HTTP transport responds with 503 Service Unavailable. This prevents attackers from exhausting server memory by creating unlimited sessions through repeated initialize requests. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
415 lines
11 KiB
Go
415 lines
11 KiB
Go
package mcp
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"
)
|
|
|
|
// HTTPConfig holds the settings for an HTTPTransport. NewHTTPTransport
// substitutes the noted default for every field left at its zero value;
// AllowedOrigins and the TLS files have no default.
type HTTPConfig struct {
	// Address is the listen address (default: "127.0.0.1:8080").
	Address string

	// Endpoint is the URL path that serves MCP (default: "/mcp").
	Endpoint string

	// AllowedOrigins lists the Origin header values accepted for CORS. An
	// entry of "*" accepts any origin; an empty list accepts localhost only.
	AllowedOrigins []string

	// SessionTTL is the session time-to-live (default: 30 minutes).
	SessionTTL time.Duration

	// MaxSessions caps the number of concurrent sessions (default: 10000).
	MaxSessions int

	// TLSCertFile and TLSKeyFile are the TLS certificate and key files
	// (optional); HTTPS is served when both are set.
	TLSCertFile string
	TLSKeyFile  string

	// MaxRequestSize is the largest accepted request body in bytes
	// (default: 1MB).
	MaxRequestSize int64

	// ReadTimeout is the HTTP server read timeout (default: 30s).
	ReadTimeout time.Duration

	// WriteTimeout is the HTTP server write timeout (default: 30s).
	WriteTimeout time.Duration

	// IdleTimeout is the HTTP server idle timeout (default: 120s).
	IdleTimeout time.Duration

	// ReadHeaderTimeout is the HTTP server read header timeout (default: 10s).
	ReadHeaderTimeout time.Duration
}
|
|
|
|
const (
|
|
// DefaultMaxRequestSize is the default maximum request body size (1MB).
|
|
DefaultMaxRequestSize = 1 << 20 // 1MB
|
|
|
|
// Default HTTP server timeouts
|
|
DefaultReadTimeout = 30 * time.Second
|
|
DefaultWriteTimeout = 30 * time.Second
|
|
DefaultIdleTimeout = 120 * time.Second
|
|
DefaultReadHeaderTimeout = 10 * time.Second
|
|
)
|
|
|
|
// HTTPTransport implements the MCP Streamable HTTP transport.
|
|
type HTTPTransport struct {
|
|
server *Server
|
|
config HTTPConfig
|
|
sessions *SessionStore
|
|
}
|
|
|
|
// NewHTTPTransport creates a new HTTP transport.
|
|
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
|
if config.Address == "" {
|
|
config.Address = "127.0.0.1:8080"
|
|
}
|
|
if config.Endpoint == "" {
|
|
config.Endpoint = "/mcp"
|
|
}
|
|
if config.SessionTTL == 0 {
|
|
config.SessionTTL = 30 * time.Minute
|
|
}
|
|
if config.MaxSessions == 0 {
|
|
config.MaxSessions = DefaultMaxSessions
|
|
}
|
|
if config.MaxRequestSize == 0 {
|
|
config.MaxRequestSize = DefaultMaxRequestSize
|
|
}
|
|
if config.ReadTimeout == 0 {
|
|
config.ReadTimeout = DefaultReadTimeout
|
|
}
|
|
if config.WriteTimeout == 0 {
|
|
config.WriteTimeout = DefaultWriteTimeout
|
|
}
|
|
if config.IdleTimeout == 0 {
|
|
config.IdleTimeout = DefaultIdleTimeout
|
|
}
|
|
if config.ReadHeaderTimeout == 0 {
|
|
config.ReadHeaderTimeout = DefaultReadHeaderTimeout
|
|
}
|
|
|
|
return &HTTPTransport{
|
|
server: server,
|
|
config: config,
|
|
sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
|
|
}
|
|
}
|
|
|
|
// Run starts the HTTP server and blocks until the context is cancelled.
|
|
func (t *HTTPTransport) Run(ctx context.Context) error {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(t.config.Endpoint, t.handleMCP)
|
|
|
|
httpServer := &http.Server{
|
|
Addr: t.config.Address,
|
|
Handler: mux,
|
|
ReadTimeout: t.config.ReadTimeout,
|
|
WriteTimeout: t.config.WriteTimeout,
|
|
IdleTimeout: t.config.IdleTimeout,
|
|
ReadHeaderTimeout: t.config.ReadHeaderTimeout,
|
|
BaseContext: func(l net.Listener) context.Context {
|
|
return ctx
|
|
},
|
|
}
|
|
|
|
// Graceful shutdown on context cancellation
|
|
go func() {
|
|
<-ctx.Done()
|
|
t.server.logger.Println("Shutting down HTTP server...")
|
|
t.sessions.Stop()
|
|
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
|
t.server.logger.Printf("HTTP server shutdown error: %v", err)
|
|
}
|
|
}()
|
|
|
|
t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
|
|
|
|
var err error
|
|
if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
|
|
err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
|
|
} else {
|
|
err = httpServer.ListenAndServe()
|
|
}
|
|
|
|
if err == http.ErrServerClosed {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// handleMCP routes requests based on HTTP method.
|
|
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
|
|
// Validate Origin header
|
|
if !t.isOriginAllowed(r) {
|
|
http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
switch r.Method {
|
|
case http.MethodPost:
|
|
t.handlePost(w, r)
|
|
case http.MethodGet:
|
|
t.handleGet(w, r)
|
|
case http.MethodDelete:
|
|
t.handleDelete(w, r)
|
|
case http.MethodOptions:
|
|
t.handleOptions(w, r)
|
|
default:
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
}
|
|
|
|
// handlePost handles JSON-RPC requests.
|
|
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
|
|
// Limit request body size to prevent memory exhaustion attacks
|
|
r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize)
|
|
|
|
// Read request body
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
// Check if this is a size limit error
|
|
if err.Error() == "http: request body too large" {
|
|
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
|
return
|
|
}
|
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Parse request to check method
|
|
var req Request
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(Response{
|
|
JSONRPC: "2.0",
|
|
Error: &Error{
|
|
Code: ParseError,
|
|
Message: "Parse error",
|
|
Data: err.Error(),
|
|
},
|
|
})
|
|
return
|
|
}
|
|
|
|
// Handle initialize request - create session
|
|
if req.Method == MethodInitialize {
|
|
t.handleInitialize(w, r, &req)
|
|
return
|
|
}
|
|
|
|
// All other requests require a valid session
|
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
session := t.sessions.Get(sessionID)
|
|
if session == nil {
|
|
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// Update session activity
|
|
session.Touch()
|
|
|
|
// Handle notifications (no response expected)
|
|
if req.Method == MethodInitialized {
|
|
session.SetInitialized()
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
// Check if this is a notification (no ID)
|
|
if req.ID == nil {
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
// Process the request
|
|
resp := t.server.HandleRequest(r.Context(), &req)
|
|
if resp == nil {
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
// handleInitialize handles the initialize request and creates a new session.
|
|
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
|
|
// Create a new session
|
|
session, err := t.sessions.Create()
|
|
if err != nil {
|
|
if err == ErrTooManySessions {
|
|
t.server.logger.Printf("Session limit reached")
|
|
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
t.server.logger.Printf("Failed to create session: %v", err)
|
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Process initialize request
|
|
resp := t.server.HandleRequest(r.Context(), req)
|
|
if resp == nil {
|
|
t.sessions.Delete(session.ID)
|
|
http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// If initialize failed, clean up session
|
|
if resp.Error != nil {
|
|
t.sessions.Delete(session.ID)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Mcp-Session-Id", session.ID)
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
// handleGet handles SSE stream for server-initiated notifications.
|
|
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
|
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
session := t.sessions.Get(sessionID)
|
|
if session == nil {
|
|
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// Check if client accepts SSE
|
|
accept := r.Header.Get("Accept")
|
|
if !strings.Contains(accept, "text/event-stream") {
|
|
http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
|
|
return
|
|
}
|
|
|
|
// Set SSE headers
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Flush headers
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
|
|
// Use ResponseController to manage write deadlines for long-lived SSE connections
|
|
rc := http.NewResponseController(w)
|
|
|
|
// Stream notifications
|
|
ctx := r.Context()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case notification, ok := <-session.Notifications():
|
|
if !ok {
|
|
// Session closed
|
|
return
|
|
}
|
|
|
|
// Extend write deadline before each write
|
|
if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
|
t.server.logger.Printf("Failed to set write deadline: %v", err)
|
|
}
|
|
|
|
data, err := json.Marshal(notification)
|
|
if err != nil {
|
|
t.server.logger.Printf("Failed to marshal notification: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Write SSE event
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
flusher.Flush()
|
|
|
|
// Touch session to keep it alive
|
|
session.Touch()
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleDelete terminates a session.
|
|
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
|
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if t.sessions.Delete(sessionID) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
} else {
|
|
http.Error(w, "Session not found", http.StatusNotFound)
|
|
}
|
|
}
|
|
|
|
// handleOptions handles CORS preflight requests.
|
|
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
|
|
origin := r.Header.Get("Origin")
|
|
if origin != "" && t.isOriginAllowed(r) {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
|
|
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
|
|
w.Header().Set("Access-Control-Max-Age", "86400")
|
|
}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// isOriginAllowed checks if the request origin is allowed.
|
|
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
|
|
// No Origin header (same-origin request) is always allowed
|
|
if origin == "" {
|
|
return true
|
|
}
|
|
|
|
// If no allowed origins configured, only allow localhost
|
|
if len(t.config.AllowedOrigins) == 0 {
|
|
return isLocalhostOrigin(origin)
|
|
}
|
|
|
|
// Check against allowed origins
|
|
for _, allowed := range t.config.AllowedOrigins {
|
|
if allowed == "*" || allowed == origin {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// isLocalhostOrigin reports whether origin, an Origin header value such as
// "http://localhost:3000", names an http or https origin on this machine.
// The host must be "localhost" or a loopback IP address (127.0.0.0/8, ::1).
// Matching is case-insensitive.
//
// The value is parsed as a URL rather than prefix-matched. Prefix matching
// would accept "http://localhost:@evil.example", whose actual host is
// evil.example. Origins carrying user info and values that do not parse
// (e.g. a non-numeric port) are rejected.
func isLocalhostOrigin(origin string) bool {
	u, err := url.Parse(strings.ToLower(origin))
	if err != nil || u.User != nil {
		return false
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		return false
	}

	// Hostname strips the port and the brackets around IPv6 literals.
	host := u.Hostname()
	if host == "localhost" {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}
|