diff --git a/internal/mcp/transport_http.go b/internal/mcp/transport_http.go index 834e4fd..f35419b 100644 --- a/internal/mcp/transport_http.go +++ b/internal/mcp/transport_http.go @@ -13,18 +13,28 @@ import ( // HTTPConfig configures the HTTP transport. type HTTPConfig struct { - Address string // Listen address (e.g., "127.0.0.1:8080") - Endpoint string // MCP endpoint path (e.g., "/mcp") - AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only) - SessionTTL time.Duration // Session TTL (default: 30 minutes) - TLSCertFile string // TLS certificate file (optional) - TLSKeyFile string // TLS key file (optional) - MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB) + Address string // Listen address (e.g., "127.0.0.1:8080") + Endpoint string // MCP endpoint path (e.g., "/mcp") + AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only) + SessionTTL time.Duration // Session TTL (default: 30 minutes) + TLSCertFile string // TLS certificate file (optional) + TLSKeyFile string // TLS key file (optional) + MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB) + ReadTimeout time.Duration // HTTP server read timeout (default: 30s) + WriteTimeout time.Duration // HTTP server write timeout (default: 30s) + IdleTimeout time.Duration // HTTP server idle timeout (default: 120s) + ReadHeaderTimeout time.Duration // HTTP server read header timeout (default: 10s) } const ( // DefaultMaxRequestSize is the default maximum request body size (1MB). DefaultMaxRequestSize = 1 << 20 // 1MB + + // Default HTTP server timeouts + DefaultReadTimeout = 30 * time.Second + DefaultWriteTimeout = 30 * time.Second + DefaultIdleTimeout = 120 * time.Second + DefaultReadHeaderTimeout = 10 * time.Second ) // HTTPTransport implements the MCP Streamable HTTP transport. @@ -48,6 +58,18 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { if config.MaxRequestSize == 0 { config.MaxRequestSize = DefaultMaxRequestSize } + if config.ReadTimeout == 0 { + config.ReadTimeout = DefaultReadTimeout + } + if config.WriteTimeout == 0 { + config.WriteTimeout = DefaultWriteTimeout + } + if config.IdleTimeout == 0 { + config.IdleTimeout = DefaultIdleTimeout + } + if config.ReadHeaderTimeout == 0 { + config.ReadHeaderTimeout = DefaultReadHeaderTimeout + } return &HTTPTransport{ server: server, @@ -62,8 +84,12 @@ func (t *HTTPTransport) Run(ctx context.Context) error { mux.HandleFunc(t.config.Endpoint, t.handleMCP) httpServer := &http.Server{ - Addr: t.config.Address, - Handler: mux, + Addr: t.config.Address, + Handler: mux, + ReadTimeout: t.config.ReadTimeout, + WriteTimeout: t.config.WriteTimeout, + IdleTimeout: t.config.IdleTimeout, + ReadHeaderTimeout: t.config.ReadHeaderTimeout, BaseContext: func(l net.Listener) context.Context { return ctx }, @@ -264,6 +290,9 @@ func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) { } flusher.Flush() + // Use ResponseController to manage write deadlines for long-lived SSE connections + rc := http.NewResponseController(w) + // Stream notifications ctx := r.Context() for { @@ -276,6 +305,11 @@ func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) { return } + // Extend write deadline before each write + if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil { + t.server.logger.Printf("Failed to set write deadline: %v", err) + } + data, err := json.Marshal(notification) if err != nil { t.server.logger.Printf("Failed to marshal notification: %v", err) diff --git a/internal/mcp/transport_http_test.go b/internal/mcp/transport_http_test.go index f9354cf..68e746c 100644 --- a/internal/mcp/transport_http_test.go +++ b/internal/mcp/transport_http_test.go @@ -481,6 +481,82 @@ func TestHTTPTransportOptionsRequest(t *testing.T) { } } +func TestHTTPTransportDefaultConfig(t *testing.T) { + server := NewServer(nil, log.New(io.Discard, "", 0)) + transport := NewHTTPTransport(server, HTTPConfig{}) + + // Verify defaults are applied + if transport.config.Address != "127.0.0.1:8080" { + t.Errorf("Expected default address 127.0.0.1:8080, got %s", transport.config.Address) + } + if transport.config.Endpoint != "/mcp" { + t.Errorf("Expected default endpoint /mcp, got %s", transport.config.Endpoint) + } + if transport.config.SessionTTL != 30*time.Minute { + t.Errorf("Expected default session TTL 30m, got %v", transport.config.SessionTTL) + } + if transport.config.MaxRequestSize != DefaultMaxRequestSize { + t.Errorf("Expected default max request size %d, got %d", DefaultMaxRequestSize, transport.config.MaxRequestSize) + } + if transport.config.ReadTimeout != DefaultReadTimeout { + t.Errorf("Expected default read timeout %v, got %v", DefaultReadTimeout, transport.config.ReadTimeout) + } + if transport.config.WriteTimeout != DefaultWriteTimeout { + t.Errorf("Expected default write timeout %v, got %v", DefaultWriteTimeout, transport.config.WriteTimeout) + } + if transport.config.IdleTimeout != DefaultIdleTimeout { + t.Errorf("Expected default idle timeout %v, got %v", DefaultIdleTimeout, transport.config.IdleTimeout) + } + if transport.config.ReadHeaderTimeout != DefaultReadHeaderTimeout { + t.Errorf("Expected default read header timeout %v, got %v", DefaultReadHeaderTimeout, transport.config.ReadHeaderTimeout) + } + + transport.sessions.Stop() +} + +func TestHTTPTransportCustomConfig(t *testing.T) { + server := NewServer(nil, log.New(io.Discard, "", 0)) + config := HTTPConfig{ + Address: "0.0.0.0:9090", + Endpoint: "/api/mcp", + SessionTTL: 1 * time.Hour, + MaxRequestSize: 5 << 20, // 5MB + ReadTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 300 * time.Second, + ReadHeaderTimeout: 20 * time.Second, + } + transport := NewHTTPTransport(server, config) + + // Verify custom values are preserved + if transport.config.Address != "0.0.0.0:9090" { + t.Errorf("Expected custom address, got %s", transport.config.Address) + } + if transport.config.Endpoint != "/api/mcp" { + t.Errorf("Expected custom endpoint, got %s", transport.config.Endpoint) + } + if transport.config.SessionTTL != 1*time.Hour { + t.Errorf("Expected custom session TTL, got %v", transport.config.SessionTTL) + } + if transport.config.MaxRequestSize != 5<<20 { + t.Errorf("Expected custom max request size, got %d", transport.config.MaxRequestSize) + } + if transport.config.ReadTimeout != 60*time.Second { + t.Errorf("Expected custom read timeout, got %v", transport.config.ReadTimeout) + } + if transport.config.WriteTimeout != 60*time.Second { + t.Errorf("Expected custom write timeout, got %v", transport.config.WriteTimeout) + } + if transport.config.IdleTimeout != 300*time.Second { + t.Errorf("Expected custom idle timeout, got %v", transport.config.IdleTimeout) + } + if transport.config.ReadHeaderTimeout != 20*time.Second { + t.Errorf("Expected custom read header timeout, got %v", transport.config.ReadHeaderTimeout) + } + + transport.sessions.Stop() +} + func TestHTTPTransportRequestBodyTooLarge(t *testing.T) { _, ts := testHTTPTransport(t, HTTPConfig{ MaxRequestSize: 100, // Very small limit for testing