feat: add Streamable HTTP transport support
Add support for running the MCP server over HTTP with Server-Sent Events (SSE) using the MCP Streamable HTTP specification, alongside the existing STDIO transport. New features: - Transport abstraction with Transport interface - HTTP transport with session management - SSE support for server-initiated notifications - CORS security with configurable allowed origins - Optional TLS support - CLI flags for HTTP configuration (--transport, --http-address, etc.) - NixOS module options for HTTP transport The HTTP transport implements: - POST /mcp: JSON-RPC requests with session management - GET /mcp: SSE stream for server notifications - DELETE /mcp: Session termination - Origin validation (localhost-only by default) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
354
internal/mcp/transport_http.go
Normal file
354
internal/mcp/transport_http.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPConfig configures the HTTP transport.
|
||||
type HTTPConfig struct {
|
||||
Address string // Listen address (e.g., "127.0.0.1:8080")
|
||||
Endpoint string // MCP endpoint path (e.g., "/mcp")
|
||||
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
|
||||
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
||||
TLSCertFile string // TLS certificate file (optional)
|
||||
TLSKeyFile string // TLS key file (optional)
|
||||
}
|
||||
|
||||
// HTTPTransport implements the MCP Streamable HTTP transport.
|
||||
type HTTPTransport struct {
|
||||
server *Server
|
||||
config HTTPConfig
|
||||
sessions *SessionStore
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTP transport.
|
||||
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
||||
if config.Address == "" {
|
||||
config.Address = "127.0.0.1:8080"
|
||||
}
|
||||
if config.Endpoint == "" {
|
||||
config.Endpoint = "/mcp"
|
||||
}
|
||||
if config.SessionTTL == 0 {
|
||||
config.SessionTTL = 30 * time.Minute
|
||||
}
|
||||
|
||||
return &HTTPTransport{
|
||||
server: server,
|
||||
config: config,
|
||||
sessions: NewSessionStore(config.SessionTTL),
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the HTTP server and blocks until the context is cancelled.
|
||||
func (t *HTTPTransport) Run(ctx context.Context) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(t.config.Endpoint, t.handleMCP)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: t.config.Address,
|
||||
Handler: mux,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
|
||||
// Graceful shutdown on context cancellation
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
t.server.logger.Println("Shutting down HTTP server...")
|
||||
t.sessions.Stop()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
t.server.logger.Printf("HTTP server shutdown error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
|
||||
|
||||
var err error
|
||||
if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
|
||||
err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
|
||||
} else {
|
||||
err = httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
if err == http.ErrServerClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleMCP routes requests based on HTTP method.
|
||||
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate Origin header
|
||||
if !t.isOriginAllowed(r) {
|
||||
http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
t.handlePost(w, r)
|
||||
case http.MethodGet:
|
||||
t.handleGet(w, r)
|
||||
case http.MethodDelete:
|
||||
t.handleDelete(w, r)
|
||||
case http.MethodOptions:
|
||||
t.handleOptions(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePost handles JSON-RPC requests.
|
||||
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request to check method
|
||||
var req Request
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
JSONRPC: "2.0",
|
||||
Error: &Error{
|
||||
Code: ParseError,
|
||||
Message: "Parse error",
|
||||
Data: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Handle initialize request - create session
|
||||
if req.Method == MethodInitialize {
|
||||
t.handleInitialize(w, r, &req)
|
||||
return
|
||||
}
|
||||
|
||||
// All other requests require a valid session
|
||||
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session := t.sessions.Get(sessionID)
|
||||
if session == nil {
|
||||
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Update session activity
|
||||
session.Touch()
|
||||
|
||||
// Handle notifications (no response expected)
|
||||
if req.Method == MethodInitialized {
|
||||
session.SetInitialized()
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a notification (no ID)
|
||||
if req.ID == nil {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the request
|
||||
resp := t.server.HandleRequest(r.Context(), &req)
|
||||
if resp == nil {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleInitialize handles the initialize request and creates a new session.
|
||||
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
|
||||
// Create a new session
|
||||
session, err := t.sessions.Create()
|
||||
if err != nil {
|
||||
t.server.logger.Printf("Failed to create session: %v", err)
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Process initialize request
|
||||
resp := t.server.HandleRequest(r.Context(), req)
|
||||
if resp == nil {
|
||||
t.sessions.Delete(session.ID)
|
||||
http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// If initialize failed, clean up session
|
||||
if resp.Error != nil {
|
||||
t.sessions.Delete(session.ID)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Mcp-Session-Id", session.ID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleGet handles SSE stream for server-initiated notifications.
|
||||
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session := t.sessions.Get(sessionID)
|
||||
if session == nil {
|
||||
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if client accepts SSE
|
||||
accept := r.Header.Get("Accept")
|
||||
if !strings.Contains(accept, "text/event-stream") {
|
||||
http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
|
||||
return
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Flush headers
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Stream notifications
|
||||
ctx := r.Context()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case notification, ok := <-session.Notifications():
|
||||
if !ok {
|
||||
// Session closed
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(notification)
|
||||
if err != nil {
|
||||
t.server.logger.Printf("Failed to marshal notification: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Write SSE event
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
// Touch session to keep it alive
|
||||
session.Touch()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleDelete terminates a session.
|
||||
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if t.sessions.Delete(sessionID) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
} else {
|
||||
http.Error(w, "Session not found", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOptions handles CORS preflight requests.
|
||||
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" && t.isOriginAllowed(r) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if the request origin is allowed.
|
||||
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// No Origin header (same-origin request) is always allowed
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If no allowed origins configured, only allow localhost
|
||||
if len(t.config.AllowedOrigins) == 0 {
|
||||
return isLocalhostOrigin(origin)
|
||||
}
|
||||
|
||||
// Check against allowed origins
|
||||
for _, allowed := range t.config.AllowedOrigins {
|
||||
if allowed == "*" || allowed == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLocalhostOrigin checks if the origin is a localhost address.
|
||||
func isLocalhostOrigin(origin string) bool {
|
||||
origin = strings.ToLower(origin)
|
||||
|
||||
// Check for localhost patterns (must be followed by :, /, or end of string)
|
||||
localhostPatterns := []string{
|
||||
"http://localhost",
|
||||
"https://localhost",
|
||||
"http://127.0.0.1",
|
||||
"https://127.0.0.1",
|
||||
"http://[::1]",
|
||||
"https://[::1]",
|
||||
}
|
||||
|
||||
for _, pattern := range localhostPatterns {
|
||||
if origin == pattern {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(origin, pattern+":") || strings.HasPrefix(origin, pattern+"/") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user