This repository has been archived on 2026-03-10. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
labmcp/internal/mcp/transport_http.go
Torjus Håkestad 08f8b2cd83 feat: add SSE keepalive messages for connection health
Add configurable SSEKeepAlive interval (default: 15s) that sends SSE
comment lines (`:keepalive`) to maintain connection health.

Benefits:
- Keeps connections alive through proxies/load balancers that time out
  idle connections
- Detects stale connections earlier (write failures terminate the
  handler)
- Standard SSE pattern - comments are ignored by compliant clients

Configuration:
- SSEKeepAlive > 0: send keepalives at specified interval
- SSEKeepAlive = 0: use default (15s)
- SSEKeepAlive < 0: disable keepalives

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:10:58 +01:00

449 lines
12 KiB
Go

package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
)
// HTTPConfig configures the HTTP transport.
//
// Except for AllowedOrigins and the TLS file paths, any field left at its
// zero value is replaced with a default by NewHTTPTransport.
type HTTPConfig struct {
	Address           string        // Listen address (default: "127.0.0.1:8080")
	Endpoint          string        // MCP endpoint path (default: "/mcp")
	AllowedOrigins    []string      // Allowed Origin headers for CORS (empty = localhost only, "*" = any origin)
	SessionTTL        time.Duration // Session TTL (default: 30 minutes)
	MaxSessions       int           // Maximum concurrent sessions (default: DefaultMaxSessions)
	TLSCertFile       string        // TLS certificate file; TLS is enabled only when both TLSCertFile and TLSKeyFile are set
	TLSKeyFile        string        // TLS key file; see TLSCertFile
	MaxRequestSize    int64         // Maximum request body size in bytes (default: 1MB)
	ReadTimeout       time.Duration // HTTP server read timeout (default: 30s)
	WriteTimeout      time.Duration // HTTP server write timeout (default: 30s)
	IdleTimeout       time.Duration // HTTP server idle timeout (default: 120s)
	ReadHeaderTimeout time.Duration // HTTP server read header timeout (default: 10s)
	SSEKeepAlive      time.Duration // SSE keepalive interval (0 = default of 15s, negative = disabled)
}
// Defaults applied by NewHTTPTransport when the corresponding HTTPConfig
// field is left at its zero value.
const (
	// DefaultMaxRequestSize caps request bodies at 1 MiB.
	DefaultMaxRequestSize = 1024 * 1024

	// HTTP server timeouts.
	DefaultReadTimeout       = time.Second * 30
	DefaultWriteTimeout      = time.Second * 30
	DefaultIdleTimeout       = 2 * time.Minute
	DefaultReadHeaderTimeout = time.Second * 10

	// DefaultSSEKeepAlive is how often an otherwise idle SSE stream gets a
	// comment line. The traffic keeps proxies and load balancers from
	// dropping the connection, and a failed write reveals a dead peer.
	DefaultSSEKeepAlive = time.Second * 15
)
// HTTPTransport implements the MCP Streamable HTTP transport.
//
// It serves the configured endpoint and tracks client sessions, which are
// identified by the Mcp-Session-Id request header.
type HTTPTransport struct {
server *Server // handles JSON-RPC requests and supplies the logger
config HTTPConfig // transport settings, with zero values replaced by defaults in NewHTTPTransport
sessions *SessionStore // active sessions, created with config.SessionTTL and config.MaxSessions
}
// NewHTTPTransport builds an HTTPTransport for server.
//
// Every zero-valued field of config is filled with its default before the
// session store is created. A negative SSEKeepAlive is kept as-is, which
// leaves keepalives disabled.
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
	config.Address = httpConfigDefault(config.Address, "127.0.0.1:8080")
	config.Endpoint = httpConfigDefault(config.Endpoint, "/mcp")
	config.SessionTTL = httpConfigDefault(config.SessionTTL, 30*time.Minute)
	config.MaxSessions = httpConfigDefault(config.MaxSessions, DefaultMaxSessions)
	config.MaxRequestSize = httpConfigDefault(config.MaxRequestSize, DefaultMaxRequestSize)
	config.ReadTimeout = httpConfigDefault(config.ReadTimeout, DefaultReadTimeout)
	config.WriteTimeout = httpConfigDefault(config.WriteTimeout, DefaultWriteTimeout)
	config.IdleTimeout = httpConfigDefault(config.IdleTimeout, DefaultIdleTimeout)
	config.ReadHeaderTimeout = httpConfigDefault(config.ReadHeaderTimeout, DefaultReadHeaderTimeout)
	// Only an exact zero selects the default; negative means "disabled".
	config.SSEKeepAlive = httpConfigDefault(config.SSEKeepAlive, DefaultSSEKeepAlive)

	store := NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions)
	return &HTTPTransport{server: server, config: config, sessions: store}
}

// httpConfigDefault returns fallback when v is the zero value of T, and v
// otherwise.
func httpConfigDefault[T comparable](v, fallback T) T {
	var zero T
	if v == zero {
		return fallback
	}
	return v
}
// Run starts the HTTP server and blocks until ctx is cancelled and the
// server has finished shutting down, or until the listener fails.
//
// On cancellation the session store is stopped and the server gets up to
// 10 seconds to drain in-flight requests. Run returns only after that
// shutdown completes.
//
// TLS is used when both TLSCertFile and TLSKeyFile are set. Setting just one
// of them is reported as an error instead of silently serving plaintext HTTP.
func (t *HTTPTransport) Run(ctx context.Context) error {
	if (t.config.TLSCertFile == "") != (t.config.TLSKeyFile == "") {
		return fmt.Errorf("mcp: TLSCertFile and TLSKeyFile must both be set to enable TLS")
	}

	mux := http.NewServeMux()
	mux.HandleFunc(t.config.Endpoint, t.handleMCP)

	// serveCtx ends when the caller cancels ctx or when Run returns because
	// the listener failed, so the cleanup goroutine never outlives Run.
	serveCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	httpServer := &http.Server{
		Addr:              t.config.Address,
		Handler:           mux,
		ReadTimeout:       t.config.ReadTimeout,
		WriteTimeout:      t.config.WriteTimeout,
		IdleTimeout:       t.config.IdleTimeout,
		ReadHeaderTimeout: t.config.ReadHeaderTimeout,
		// Request contexts derive from serveCtx, so long-lived handlers
		// (SSE streams) return once shutdown begins.
		BaseContext: func(net.Listener) context.Context {
			return serveCtx
		},
	}

	// Graceful shutdown once serveCtx is done.
	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		<-serveCtx.Done()
		t.server.logger.Println("Shutting down HTTP server...")
		t.sessions.Stop()
		shutdownCtx, stop := context.WithTimeout(context.Background(), 10*time.Second)
		defer stop()
		if err := httpServer.Shutdown(shutdownCtx); err != nil {
			t.server.logger.Printf("HTTP server shutdown error: %v", err)
		}
	}()

	t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
	var err error
	if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
		err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
	} else {
		err = httpServer.ListenAndServe()
	}

	// ListenAndServe returns ErrServerClosed as soon as Shutdown starts. Wait
	// for the shutdown goroutine so in-flight requests finish before Run
	// returns. On a listener failure, cancel triggers the same cleanup.
	cancel()
	<-shutdownDone
	if err == http.ErrServerClosed {
		return nil
	}
	return err
}
// handleMCP is the single entry point for the MCP endpoint.
//
// It rejects requests from disallowed origins, adds CORS response headers
// for allowed cross-origin callers, and dispatches on the HTTP method.
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
	// The response depends on the Origin header, so shared caches must not
	// reuse it across origins.
	w.Header().Add("Vary", "Origin")

	// Validate Origin header.
	if !t.isOriginAllowed(r) {
		http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
		return
	}

	// Browsers enforce CORS on the actual response, not just the preflight.
	// Without these headers a cross-origin client could neither read the
	// body nor see the Mcp-Session-Id header issued by initialize.
	if origin := r.Header.Get("Origin"); origin != "" {
		w.Header().Set("Access-Control-Allow-Origin", origin)
		w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
	}

	switch r.Method {
	case http.MethodPost:
		t.handlePost(w, r)
	case http.MethodGet:
		t.handleGet(w, r)
	case http.MethodDelete:
		t.handleDelete(w, r)
	case http.MethodOptions:
		t.handleOptions(w, r)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// handlePost handles a JSON-RPC message sent by the client.
//
// An initialize request is accepted without a session and creates one.
// Every other message must carry a valid Mcp-Session-Id header. The
// initialized notification, other notifications (no ID), and requests that
// produce no response are acknowledged with 202 Accepted. Requests with a
// response get it back as application/json.
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
	// Limit request body size to prevent memory exhaustion attacks.
	r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize)

	body, err := io.ReadAll(r.Body)
	if err != nil {
		// MaxBytesReader reports an overrun as *http.MaxBytesError, and
		// io.ReadAll returns it unchanged. Matching on the type is stable;
		// comparing err.Error() against message text is not.
		if _, tooLarge := err.(*http.MaxBytesError); tooLarge {
			http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
			return
		}
		http.Error(w, "Failed to read request body", http.StatusBadRequest)
		return
	}

	// Parse the message to decide how to route it. A malformed body is
	// reported as a JSON-RPC parse error object with HTTP 200.
	var req Request
	if err := json.Unmarshal(body, &req); err != nil {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		encErr := json.NewEncoder(w).Encode(Response{
			JSONRPC: "2.0",
			Error: &Error{
				Code:    ParseError,
				Message: "Parse error",
				Data:    err.Error(),
			},
		})
		if encErr != nil {
			t.server.logger.Printf("Failed to write parse error response: %v", encErr)
		}
		return
	}

	// Handle initialize request - create session.
	if req.Method == MethodInitialize {
		t.handleInitialize(w, r, &req)
		return
	}

	// All other requests require a valid session.
	sessionID := r.Header.Get("Mcp-Session-Id")
	if sessionID == "" {
		http.Error(w, "Session ID required", http.StatusBadRequest)
		return
	}
	session := t.sessions.Get(sessionID)
	if session == nil {
		http.Error(w, "Invalid or expired session", http.StatusNotFound)
		return
	}

	// Update session activity.
	session.Touch()

	// The initialized notification completes the handshake; no response body.
	if req.Method == MethodInitialized {
		session.SetInitialized()
		w.WriteHeader(http.StatusAccepted)
		return
	}

	// A message without an ID is a notification: acknowledge and stop.
	if req.ID == nil {
		w.WriteHeader(http.StatusAccepted)
		return
	}

	resp := t.server.HandleRequest(r.Context(), &req)
	if resp == nil {
		w.WriteHeader(http.StatusAccepted)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		t.server.logger.Printf("Failed to write response: %v", err)
	}
}
// handleInitialize handles the initialize request and creates a new session.
//
// The session is created before the request is processed. It is deleted
// again if processing yields no response or an error response. Only a
// successful initialize advertises the session ID in the Mcp-Session-Id
// response header.
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
	session, err := t.sessions.Create()
	if err != nil {
		if err == ErrTooManySessions {
			t.server.logger.Printf("Session limit reached")
			http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
			return
		}
		t.server.logger.Printf("Failed to create session: %v", err)
		http.Error(w, "Failed to create session", http.StatusInternalServerError)
		return
	}

	resp := t.server.HandleRequest(r.Context(), req)
	if resp == nil {
		t.sessions.Delete(session.ID)
		http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if resp.Error != nil {
		// Initialize failed: discard the session and do not hand its (now
		// invalid) ID to the client.
		t.sessions.Delete(session.ID)
	} else {
		w.Header().Set("Mcp-Session-Id", session.ID)
	}
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		t.server.logger.Printf("Failed to write initialize response: %v", err)
	}
}
// handleGet opens a Server-Sent Events stream that delivers server-initiated
// notifications for the session named by the Mcp-Session-Id header.
//
// The stream stays open until the request context is cancelled, the
// session's notification channel is closed, or a write fails. When
// SSEKeepAlive is positive, an SSE comment line is sent at that interval.
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
	// sseWriteWindow bounds each individual write. The deadline is re-armed
	// before every write so the server-wide WriteTimeout does not end
	// long-lived streams.
	const sseWriteWindow = 30 * time.Second

	sessionID := r.Header.Get("Mcp-Session-Id")
	if sessionID == "" {
		http.Error(w, "Session ID required", http.StatusBadRequest)
		return
	}
	session := t.sessions.Get(sessionID)
	if session == nil {
		http.Error(w, "Invalid or expired session", http.StatusNotFound)
		return
	}

	// Check if client accepts SSE.
	if !strings.Contains(r.Header.Get("Accept"), "text/event-stream") {
		http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
		return
	}

	// Confirm streaming support before committing to a 200 response. Once
	// WriteHeader has run, http.Error can no longer report the failure.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	// Set SSE headers and flush them so the client sees the stream open.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.WriteHeader(http.StatusOK)
	flusher.Flush()

	rc := http.NewResponseController(w)
	deadlineWarned := false
	armWriteDeadline := func() {
		err := rc.SetWriteDeadline(time.Now().Add(sseWriteWindow))
		if err != nil && !deadlineWarned {
			// Log once per stream. A writer that cannot set deadlines would
			// otherwise emit the same message on every keepalive tick.
			deadlineWarned = true
			t.server.logger.Printf("Failed to set write deadline: %v", err)
		}
	}

	// A nil channel never fires in select, so keepalives stay off unless
	// SSEKeepAlive is positive.
	var keepalive <-chan time.Time
	if t.config.SSEKeepAlive > 0 {
		ticker := time.NewTicker(t.config.SSEKeepAlive)
		defer ticker.Stop()
		keepalive = ticker.C
	}

	ctx := r.Context()
	for {
		select {
		case <-ctx.Done():
			return
		case <-keepalive:
			// SSE comment lines are ignored by clients but exercise the connection.
			armWriteDeadline()
			if _, err := io.WriteString(w, ":keepalive\n\n"); err != nil {
				return // write failed; the connection is most likely closed
			}
			flusher.Flush()
		case notification, ok := <-session.Notifications():
			if !ok {
				return // session closed
			}
			armWriteDeadline()
			data, err := json.Marshal(notification)
			if err != nil {
				t.server.logger.Printf("Failed to marshal notification: %v", err)
				continue
			}
			// json.Marshal escapes newlines, so the payload fits on a single
			// "data:" line.
			if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
				return // write failed; the connection is most likely closed
			}
			flusher.Flush()
			// Delivering a notification counts as session activity.
			session.Touch()
		}
	}
}
// handleDelete ends the session named by the Mcp-Session-Id header.
//
// It responds with 400 if the header is missing, 404 if no such session
// exists, and 204 once the session has been removed.
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
	id := r.Header.Get("Mcp-Session-Id")
	switch {
	case id == "":
		http.Error(w, "Session ID required", http.StatusBadRequest)
	case !t.sessions.Delete(id):
		http.Error(w, "Session not found", http.StatusNotFound)
	default:
		w.WriteHeader(http.StatusNoContent)
	}
}
// handleOptions answers CORS preflight requests.
//
// For a non-empty, allowed Origin it advertises the permitted methods and
// request headers and lets the browser cache the answer. The status is 204
// No Content in every case.
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
	if origin := r.Header.Get("Origin"); origin != "" && t.isOriginAllowed(r) {
		h := w.Header()
		h.Set("Access-Control-Allow-Origin", origin)
		h.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
		h.Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
		h.Set("Access-Control-Max-Age", "86400") // preflight cache lifetime in seconds (24h)
	}
	w.WriteHeader(http.StatusNoContent)
}
// isOriginAllowed reports whether the request's Origin header is acceptable.
//
// The rules are applied in order:
//   - A request without an Origin header is always accepted.
//   - With no AllowedOrigins configured, only localhost origins are accepted.
//   - Otherwise the origin must match an entry exactly, or an entry must be "*".
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	switch {
	case origin == "":
		return true
	case len(t.config.AllowedOrigins) == 0:
		return isLocalhostOrigin(origin)
	}
	for _, candidate := range t.config.AllowedOrigins {
		if candidate == origin || candidate == "*" {
			return true
		}
	}
	return false
}
// isLocalhostOrigin reports whether origin names a loopback host (localhost,
// 127.0.0.1, or [::1]) over http or https, optionally with a port.
//
// The comparison is case-insensitive, and anything after the authority (a
// path) is ignored. When a port is present it must be numeric. This rejects
// look-alikes such as "http://localhost:80.evil.example" or
// "http://localhost:80@evil.example", which a plain prefix match on
// "http://localhost:" would accept.
func isLocalhostOrigin(origin string) bool {
	origin = strings.ToLower(origin)

	scheme, rest, ok := strings.Cut(origin, "://")
	if !ok || (scheme != "http" && scheme != "https") {
		return false
	}

	// The authority ends at the first "/", if any.
	authority, _, _ := strings.Cut(rest, "/")

	// Split off a trailing ":port", ignoring colons inside an IPv6 literal
	// such as "[::1]".
	host, port := authority, ""
	if i := strings.LastIndexByte(authority, ':'); i > strings.LastIndexByte(authority, ']') {
		host, port = authority[:i], authority[i+1:]
	}
	for _, c := range port {
		if c < '0' || c > '9' {
			return false
		}
	}

	switch host {
	case "localhost", "127.0.0.1", "[::1]":
		return true
	default:
		return false
	}
}