package mcp

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"
)

// HTTPConfig configures the HTTP transport. NewHTTPTransport replaces
// zero-valued address, endpoint, session, size, and timeout fields with the
// defaults noted below; AllowedOrigins and the TLS files have no default.
type HTTPConfig struct {
	Address           string        // Listen address (e.g., "127.0.0.1:8080"; default "127.0.0.1:8080")
	Endpoint          string        // MCP endpoint path (e.g., "/mcp"; default "/mcp")
	AllowedOrigins    []string      // Allowed Origin headers for CORS (empty = localhost only, "*" = any origin)
	SessionTTL        time.Duration // Session TTL (default: 30 minutes)
	MaxSessions       int           // Maximum concurrent sessions (default: DefaultMaxSessions, 10000)
	TLSCertFile       string        // TLS certificate file (optional; TLS is used only when both cert and key are set)
	TLSKeyFile        string        // TLS key file (optional; see TLSCertFile)
	MaxRequestSize    int64         // Maximum POST request body size in bytes (default: 1MB)
	ReadTimeout       time.Duration // HTTP server read timeout (default: 30s)
	WriteTimeout      time.Duration // HTTP server write timeout (default: 30s)
	IdleTimeout       time.Duration // HTTP server idle timeout (default: 120s)
	ReadHeaderTimeout time.Duration // HTTP server read header timeout (default: 10s)
	SSEKeepAlive      time.Duration // SSE keepalive interval (0 selects the 15s default; a negative value disables keepalives)
}

const (
	// DefaultMaxRequestSize is the default maximum request body size (1MB).
	DefaultMaxRequestSize = 1 << 20 // 1MB

	// Default HTTP server timeouts, applied by NewHTTPTransport when the
	// corresponding HTTPConfig field is zero.
	DefaultReadTimeout       = 30 * time.Second
	DefaultWriteTimeout      = 30 * time.Second
	DefaultIdleTimeout       = 120 * time.Second
	DefaultReadHeaderTimeout = 10 * time.Second

	// DefaultSSEKeepAlive is the default interval for SSE keepalive messages.
	// These are sent as SSE comments to keep the connection alive through
	// proxies and load balancers, and to detect stale connections.
	DefaultSSEKeepAlive = 15 * time.Second
)

// HTTPTransport implements the MCP Streamable HTTP transport: JSON-RPC
// messages arrive via POST, server notifications are streamed to clients via
// GET (Server-Sent Events), and DELETE terminates a session.
type HTTPTransport struct {
	server   *Server       // handles decoded JSON-RPC requests
	config   HTTPConfig    // configuration with defaults already applied
	sessions *SessionStore // active sessions, looked up by the Mcp-Session-Id header
}

// NewHTTPTransport creates a new HTTP transport for server, filling in
// defaults for any zero-valued address, endpoint, session, size, and timeout
// settings in config.
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { if config.Address == "" { config.Address = "127.0.0.1:8080" } if config.Endpoint == "" { config.Endpoint = "/mcp" } if config.SessionTTL == 0 { config.SessionTTL = 30 * time.Minute } if config.MaxSessions == 0 { config.MaxSessions = DefaultMaxSessions } if config.MaxRequestSize == 0 { config.MaxRequestSize = DefaultMaxRequestSize } if config.ReadTimeout == 0 { config.ReadTimeout = DefaultReadTimeout } if config.WriteTimeout == 0 { config.WriteTimeout = DefaultWriteTimeout } if config.IdleTimeout == 0 { config.IdleTimeout = DefaultIdleTimeout } if config.ReadHeaderTimeout == 0 { config.ReadHeaderTimeout = DefaultReadHeaderTimeout } // SSEKeepAlive: 0 means use default, negative means disabled if config.SSEKeepAlive == 0 { config.SSEKeepAlive = DefaultSSEKeepAlive } return &HTTPTransport{ server: server, config: config, sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions), } } // Run starts the HTTP server and blocks until the context is cancelled. 
func (t *HTTPTransport) Run(ctx context.Context) error { mux := http.NewServeMux() mux.HandleFunc(t.config.Endpoint, t.handleMCP) httpServer := &http.Server{ Addr: t.config.Address, Handler: mux, ReadTimeout: t.config.ReadTimeout, WriteTimeout: t.config.WriteTimeout, IdleTimeout: t.config.IdleTimeout, ReadHeaderTimeout: t.config.ReadHeaderTimeout, BaseContext: func(l net.Listener) context.Context { return ctx }, } // Graceful shutdown on context cancellation go func() { <-ctx.Done() t.server.logger.Println("Shutting down HTTP server...") t.sessions.Stop() shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := httpServer.Shutdown(shutdownCtx); err != nil { t.server.logger.Printf("HTTP server shutdown error: %v", err) } }() t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint) var err error if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" { err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile) } else { err = httpServer.ListenAndServe() } if err == http.ErrServerClosed { return nil } return err } // handleMCP routes requests based on HTTP method. func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) { // Validate Origin header if !t.isOriginAllowed(r) { http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden) return } switch r.Method { case http.MethodPost: t.handlePost(w, r) case http.MethodGet: t.handleGet(w, r) case http.MethodDelete: t.handleDelete(w, r) case http.MethodOptions: t.handleOptions(w, r) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } // handlePost handles JSON-RPC requests. 
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) { // Limit request body size to prevent memory exhaustion attacks r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize) // Read request body body, err := io.ReadAll(r.Body) if err != nil { // Check if this is a size limit error if err.Error() == "http: request body too large" { http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) return } http.Error(w, "Failed to read request body", http.StatusBadRequest) return } // Parse request to check method var req Request if err := json.Unmarshal(body, &req); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) //nolint:errcheck // response already being written, can't handle encode error json.NewEncoder(w).Encode(Response{ JSONRPC: "2.0", Error: &Error{ Code: ParseError, Message: "Parse error", Data: err.Error(), }, }) return } // Handle initialize request - create session if req.Method == MethodInitialize { t.handleInitialize(w, r, &req) return } // All other requests require a valid session sessionID := r.Header.Get("Mcp-Session-Id") if sessionID == "" { http.Error(w, "Session ID required", http.StatusBadRequest) return } session := t.sessions.Get(sessionID) if session == nil { http.Error(w, "Invalid or expired session", http.StatusNotFound) return } // Update session activity session.Touch() // Handle notifications (no response expected) if req.Method == MethodInitialized { session.SetInitialized() w.WriteHeader(http.StatusAccepted) return } // Check if this is a notification (no ID) if req.ID == nil { w.WriteHeader(http.StatusAccepted) return } // Process the request resp := t.server.HandleRequest(r.Context(), &req) if resp == nil { w.WriteHeader(http.StatusAccepted) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(resp) //nolint:errcheck // response already being written } // handleInitialize handles the 
initialize request and creates a new session. func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) { // Create a new session session, err := t.sessions.Create() if err != nil { if err == ErrTooManySessions { t.server.logger.Printf("Session limit reached") http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable) return } t.server.logger.Printf("Failed to create session: %v", err) http.Error(w, "Failed to create session", http.StatusInternalServerError) return } // Process initialize request resp := t.server.HandleRequest(r.Context(), req) if resp == nil { t.sessions.Delete(session.ID) http.Error(w, "Initialize failed: no response", http.StatusInternalServerError) return } // If initialize failed, clean up session if resp.Error != nil { t.sessions.Delete(session.ID) } w.Header().Set("Content-Type", "application/json") w.Header().Set("Mcp-Session-Id", session.ID) w.WriteHeader(http.StatusOK) _ = json.NewEncoder(w).Encode(resp) //nolint:errcheck // response already being written } // handleGet handles SSE stream for server-initiated notifications. 
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) { sessionID := r.Header.Get("Mcp-Session-Id") if sessionID == "" { http.Error(w, "Session ID required", http.StatusBadRequest) return } session := t.sessions.Get(sessionID) if session == nil { http.Error(w, "Invalid or expired session", http.StatusNotFound) return } // Check if client accepts SSE accept := r.Header.Get("Accept") if !strings.Contains(accept, "text/event-stream") { http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable) return } // Set SSE headers w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.WriteHeader(http.StatusOK) // Flush headers flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming not supported", http.StatusInternalServerError) return } flusher.Flush() // Use ResponseController to manage write deadlines for long-lived SSE connections rc := http.NewResponseController(w) // Set up keepalive ticker if enabled var keepaliveTicker *time.Ticker var keepaliveChan <-chan time.Time if t.config.SSEKeepAlive > 0 { keepaliveTicker = time.NewTicker(t.config.SSEKeepAlive) keepaliveChan = keepaliveTicker.C defer keepaliveTicker.Stop() } // Stream notifications ctx := r.Context() for { select { case <-ctx.Done(): return case <-keepaliveChan: // Send SSE comment as keepalive (ignored by clients) if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil { t.server.logger.Printf("Failed to set write deadline: %v", err) } if _, err := fmt.Fprintf(w, ":keepalive\n\n"); err != nil { // Write failed, connection likely closed return } flusher.Flush() case notification, ok := <-session.Notifications(): if !ok { // Session closed return } // Extend write deadline before each write if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil { t.server.logger.Printf("Failed to set write deadline: %v", err) } data, err := 
json.Marshal(notification) if err != nil { t.server.logger.Printf("Failed to marshal notification: %v", err) continue } // Write SSE event if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { // Write failed, connection likely closed return } flusher.Flush() // Touch session to keep it alive session.Touch() } } } // handleDelete terminates a session. func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) { sessionID := r.Header.Get("Mcp-Session-Id") if sessionID == "" { http.Error(w, "Session ID required", http.StatusBadRequest) return } if t.sessions.Delete(sessionID) { w.WriteHeader(http.StatusNoContent) } else { http.Error(w, "Session not found", http.StatusNotFound) } } // handleOptions handles CORS preflight requests. func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin != "" && t.isOriginAllowed(r) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id") w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id") w.Header().Set("Access-Control-Max-Age", "86400") } w.WriteHeader(http.StatusNoContent) } // isOriginAllowed checks if the request origin is allowed. func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool { origin := r.Header.Get("Origin") // No Origin header (same-origin request) is always allowed if origin == "" { return true } // If no allowed origins configured, only allow localhost if len(t.config.AllowedOrigins) == 0 { return isLocalhostOrigin(origin) } // Check against allowed origins for _, allowed := range t.config.AllowedOrigins { if allowed == "*" || allowed == origin { return true } } return false } // isLocalhostOrigin checks if the origin is a localhost address. 
func isLocalhostOrigin(origin string) bool {
	lowered := strings.ToLower(origin)

	// An origin qualifies when it is <scheme><host> optionally followed by
	// ":" (a port) or "/"; anything else after the host, such as
	// "localhost.evil.com", is rejected.
	for _, scheme := range [...]string{"http://", "https://"} {
		hostPart, hasScheme := strings.CutPrefix(lowered, scheme)
		if !hasScheme {
			continue
		}
		for _, host := range [...]string{"localhost", "127.0.0.1", "[::1]"} {
			tail, hasHost := strings.CutPrefix(hostPart, host)
			if !hasHost {
				continue
			}
			if tail == "" || tail[0] == ':' || tail[0] == '/' {
				return true
			}
		}
	}
	return false
}