Add a configurable SSEKeepAlive interval (default: 15s) that sends SSE comment lines (`:keepalive`) to maintain connection health.

Benefits:
- Keeps connections alive through proxies and load balancers that time out idle connections
- Detects stale connections earlier (write failures terminate the handler)
- Follows the standard SSE pattern: comment lines are ignored by compliant clients

Configuration:
- SSEKeepAlive > 0: send keepalives at the specified interval
- SSEKeepAlive = 0: use the default (15s)
- SSEKeepAlive < 0: disable keepalives

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
449 lines
12 KiB
Go
449 lines
12 KiB
Go
package mcp
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"
)
|
|
|
|
// HTTPConfig configures the HTTP transport.
//
// NewHTTPTransport replaces every zero-valued field except AllowedOrigins,
// TLSCertFile and TLSKeyFile with the default noted next to it. For
// SSEKeepAlive, 0 selects the default rather than disabling keepalives; use
// a negative interval to turn them off.
type HTTPConfig struct {
	Address           string        // Listen address (default: "127.0.0.1:8080")
	Endpoint          string        // MCP endpoint path (default: "/mcp")
	AllowedOrigins    []string      // Allowed Origin header values; "*" allows any (empty = localhost only)
	SessionTTL        time.Duration // Session TTL (default: 30 minutes)
	MaxSessions       int           // Maximum concurrent sessions (default: DefaultMaxSessions)
	TLSCertFile       string        // TLS certificate file; TLS is enabled when TLSKeyFile is also set
	TLSKeyFile        string        // TLS key file; TLS is enabled when TLSCertFile is also set
	MaxRequestSize    int64         // Maximum POST body size in bytes (default: 1MB)
	ReadTimeout       time.Duration // HTTP server read timeout (default: 30s)
	WriteTimeout      time.Duration // HTTP server write timeout (default: 30s)
	IdleTimeout       time.Duration // HTTP server idle timeout (default: 120s)
	ReadHeaderTimeout time.Duration // HTTP server read header timeout (default: 10s)
	SSEKeepAlive      time.Duration // SSE keepalive interval (0 = default 15s; negative disables)
}
|
|
|
|
// Defaults that NewHTTPTransport substitutes for zero-valued HTTPConfig fields.
const (
	// DefaultMaxRequestSize caps the size of a POSTed request body (1 MiB).
	DefaultMaxRequestSize = 1024 * 1024

	// DefaultReadTimeout is the default http.Server ReadTimeout.
	DefaultReadTimeout = 30 * time.Second
	// DefaultWriteTimeout is the default http.Server WriteTimeout.
	DefaultWriteTimeout = 30 * time.Second
	// DefaultIdleTimeout is the default http.Server IdleTimeout.
	DefaultIdleTimeout = 2 * time.Minute
	// DefaultReadHeaderTimeout is the default http.Server ReadHeaderTimeout.
	DefaultReadHeaderTimeout = 10 * time.Second

	// DefaultSSEKeepAlive is the default gap between ":keepalive" SSE comment
	// lines on an open event stream. Clients ignore comments, but the
	// periodic traffic keeps proxies and load balancers from closing an idle
	// connection and lets a dead connection show up as a failed write.
	DefaultSSEKeepAlive = 15 * time.Second
)
|
|
|
|
// HTTPTransport implements the MCP Streamable HTTP transport.
|
|
type HTTPTransport struct {
|
|
server *Server
|
|
config HTTPConfig
|
|
sessions *SessionStore
|
|
}
|
|
|
|
// NewHTTPTransport creates a new HTTP transport.
|
|
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
|
if config.Address == "" {
|
|
config.Address = "127.0.0.1:8080"
|
|
}
|
|
if config.Endpoint == "" {
|
|
config.Endpoint = "/mcp"
|
|
}
|
|
if config.SessionTTL == 0 {
|
|
config.SessionTTL = 30 * time.Minute
|
|
}
|
|
if config.MaxSessions == 0 {
|
|
config.MaxSessions = DefaultMaxSessions
|
|
}
|
|
if config.MaxRequestSize == 0 {
|
|
config.MaxRequestSize = DefaultMaxRequestSize
|
|
}
|
|
if config.ReadTimeout == 0 {
|
|
config.ReadTimeout = DefaultReadTimeout
|
|
}
|
|
if config.WriteTimeout == 0 {
|
|
config.WriteTimeout = DefaultWriteTimeout
|
|
}
|
|
if config.IdleTimeout == 0 {
|
|
config.IdleTimeout = DefaultIdleTimeout
|
|
}
|
|
if config.ReadHeaderTimeout == 0 {
|
|
config.ReadHeaderTimeout = DefaultReadHeaderTimeout
|
|
}
|
|
// SSEKeepAlive: 0 means use default, negative means disabled
|
|
if config.SSEKeepAlive == 0 {
|
|
config.SSEKeepAlive = DefaultSSEKeepAlive
|
|
}
|
|
|
|
return &HTTPTransport{
|
|
server: server,
|
|
config: config,
|
|
sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
|
|
}
|
|
}
|
|
|
|
// Run starts the HTTP server and blocks until the context is cancelled.
|
|
func (t *HTTPTransport) Run(ctx context.Context) error {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(t.config.Endpoint, t.handleMCP)
|
|
|
|
httpServer := &http.Server{
|
|
Addr: t.config.Address,
|
|
Handler: mux,
|
|
ReadTimeout: t.config.ReadTimeout,
|
|
WriteTimeout: t.config.WriteTimeout,
|
|
IdleTimeout: t.config.IdleTimeout,
|
|
ReadHeaderTimeout: t.config.ReadHeaderTimeout,
|
|
BaseContext: func(l net.Listener) context.Context {
|
|
return ctx
|
|
},
|
|
}
|
|
|
|
// Graceful shutdown on context cancellation
|
|
go func() {
|
|
<-ctx.Done()
|
|
t.server.logger.Println("Shutting down HTTP server...")
|
|
t.sessions.Stop()
|
|
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
|
t.server.logger.Printf("HTTP server shutdown error: %v", err)
|
|
}
|
|
}()
|
|
|
|
t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
|
|
|
|
var err error
|
|
if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
|
|
err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
|
|
} else {
|
|
err = httpServer.ListenAndServe()
|
|
}
|
|
|
|
if err == http.ErrServerClosed {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// handleMCP routes requests based on HTTP method.
|
|
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
|
|
// Validate Origin header
|
|
if !t.isOriginAllowed(r) {
|
|
http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
switch r.Method {
|
|
case http.MethodPost:
|
|
t.handlePost(w, r)
|
|
case http.MethodGet:
|
|
t.handleGet(w, r)
|
|
case http.MethodDelete:
|
|
t.handleDelete(w, r)
|
|
case http.MethodOptions:
|
|
t.handleOptions(w, r)
|
|
default:
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
}
|
|
}
|
|
|
|
// handlePost handles JSON-RPC requests.
|
|
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
|
|
// Limit request body size to prevent memory exhaustion attacks
|
|
r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize)
|
|
|
|
// Read request body
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
// Check if this is a size limit error
|
|
if err.Error() == "http: request body too large" {
|
|
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
|
return
|
|
}
|
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Parse request to check method
|
|
var req Request
|
|
if err := json.Unmarshal(body, &req); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(Response{
|
|
JSONRPC: "2.0",
|
|
Error: &Error{
|
|
Code: ParseError,
|
|
Message: "Parse error",
|
|
Data: err.Error(),
|
|
},
|
|
})
|
|
return
|
|
}
|
|
|
|
// Handle initialize request - create session
|
|
if req.Method == MethodInitialize {
|
|
t.handleInitialize(w, r, &req)
|
|
return
|
|
}
|
|
|
|
// All other requests require a valid session
|
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
session := t.sessions.Get(sessionID)
|
|
if session == nil {
|
|
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// Update session activity
|
|
session.Touch()
|
|
|
|
// Handle notifications (no response expected)
|
|
if req.Method == MethodInitialized {
|
|
session.SetInitialized()
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
// Check if this is a notification (no ID)
|
|
if req.ID == nil {
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
// Process the request
|
|
resp := t.server.HandleRequest(r.Context(), &req)
|
|
if resp == nil {
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
// handleInitialize handles the initialize request and creates a new session.
|
|
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
|
|
// Create a new session
|
|
session, err := t.sessions.Create()
|
|
if err != nil {
|
|
if err == ErrTooManySessions {
|
|
t.server.logger.Printf("Session limit reached")
|
|
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
t.server.logger.Printf("Failed to create session: %v", err)
|
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Process initialize request
|
|
resp := t.server.HandleRequest(r.Context(), req)
|
|
if resp == nil {
|
|
t.sessions.Delete(session.ID)
|
|
http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// If initialize failed, clean up session
|
|
if resp.Error != nil {
|
|
t.sessions.Delete(session.ID)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Mcp-Session-Id", session.ID)
|
|
w.WriteHeader(http.StatusOK)
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
// handleGet handles SSE stream for server-initiated notifications.
|
|
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
|
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
session := t.sessions.Get(sessionID)
|
|
if session == nil {
|
|
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
// Check if client accepts SSE
|
|
accept := r.Header.Get("Accept")
|
|
if !strings.Contains(accept, "text/event-stream") {
|
|
http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
|
|
return
|
|
}
|
|
|
|
// Set SSE headers
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Flush headers
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
|
|
// Use ResponseController to manage write deadlines for long-lived SSE connections
|
|
rc := http.NewResponseController(w)
|
|
|
|
// Set up keepalive ticker if enabled
|
|
var keepaliveTicker *time.Ticker
|
|
var keepaliveChan <-chan time.Time
|
|
if t.config.SSEKeepAlive > 0 {
|
|
keepaliveTicker = time.NewTicker(t.config.SSEKeepAlive)
|
|
keepaliveChan = keepaliveTicker.C
|
|
defer keepaliveTicker.Stop()
|
|
}
|
|
|
|
// Stream notifications
|
|
ctx := r.Context()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
|
|
case <-keepaliveChan:
|
|
// Send SSE comment as keepalive (ignored by clients)
|
|
if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
|
t.server.logger.Printf("Failed to set write deadline: %v", err)
|
|
}
|
|
if _, err := fmt.Fprintf(w, ":keepalive\n\n"); err != nil {
|
|
// Write failed, connection likely closed
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
|
|
case notification, ok := <-session.Notifications():
|
|
if !ok {
|
|
// Session closed
|
|
return
|
|
}
|
|
|
|
// Extend write deadline before each write
|
|
if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
|
t.server.logger.Printf("Failed to set write deadline: %v", err)
|
|
}
|
|
|
|
data, err := json.Marshal(notification)
|
|
if err != nil {
|
|
t.server.logger.Printf("Failed to marshal notification: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Write SSE event
|
|
if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
|
|
// Write failed, connection likely closed
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
|
|
// Touch session to keep it alive
|
|
session.Touch()
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleDelete terminates a session.
|
|
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
|
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if t.sessions.Delete(sessionID) {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
} else {
|
|
http.Error(w, "Session not found", http.StatusNotFound)
|
|
}
|
|
}
|
|
|
|
// handleOptions handles CORS preflight requests.
|
|
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
|
|
origin := r.Header.Get("Origin")
|
|
if origin != "" && t.isOriginAllowed(r) {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
|
|
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
|
|
w.Header().Set("Access-Control-Max-Age", "86400")
|
|
}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
// isOriginAllowed checks if the request origin is allowed.
|
|
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
|
|
// No Origin header (same-origin request) is always allowed
|
|
if origin == "" {
|
|
return true
|
|
}
|
|
|
|
// If no allowed origins configured, only allow localhost
|
|
if len(t.config.AllowedOrigins) == 0 {
|
|
return isLocalhostOrigin(origin)
|
|
}
|
|
|
|
// Check against allowed origins
|
|
for _, allowed := range t.config.AllowedOrigins {
|
|
if allowed == "*" || allowed == origin {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// isLocalhostOrigin reports whether origin names a loopback host (localhost,
// 127.0.0.1 or [::1]) over http or https.
//
// The comparison ignores case. The host must either be the whole remainder
// of the origin or be followed by ':' (a port) or '/'. This rejects
// look-alikes such as "http://localhost.example.com" or "http://127.0.0.10".
func isLocalhostOrigin(origin string) bool {
	lower := strings.ToLower(origin)

	for _, scheme := range [...]string{"http://", "https://"} {
		if !strings.HasPrefix(lower, scheme) {
			continue
		}
		rest := lower[len(scheme):]

		for _, host := range [...]string{"localhost", "127.0.0.1", "[::1]"} {
			if !strings.HasPrefix(rest, host) {
				continue
			}
			if tail := rest[len(host):]; tail == "" || tail[0] == ':' || tail[0] == '/' {
				return true
			}
		}
	}
	return false
}
|