diff --git a/internal/mcp/transport_http.go b/internal/mcp/transport_http.go index 3120573..802191f 100644 --- a/internal/mcp/transport_http.go +++ b/internal/mcp/transport_http.go @@ -25,6 +25,7 @@ type HTTPConfig struct { WriteTimeout time.Duration // HTTP server write timeout (default: 30s) IdleTimeout time.Duration // HTTP server idle timeout (default: 120s) ReadHeaderTimeout time.Duration // HTTP server read header timeout (default: 10s) + SSEKeepAlive time.Duration // SSE keepalive interval (default: 15s, 0 to disable) } const ( @@ -36,6 +37,11 @@ const ( DefaultWriteTimeout = 30 * time.Second DefaultIdleTimeout = 120 * time.Second DefaultReadHeaderTimeout = 10 * time.Second + + // DefaultSSEKeepAlive is the default interval for SSE keepalive messages. + // These are sent as SSE comments to keep the connection alive through + // proxies and load balancers, and to detect stale connections. + DefaultSSEKeepAlive = 15 * time.Second ) // HTTPTransport implements the MCP Streamable HTTP transport. @@ -74,6 +80,10 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { if config.ReadHeaderTimeout == 0 { config.ReadHeaderTimeout = DefaultReadHeaderTimeout } + // SSEKeepAlive: 0 means use default, negative means disabled + if config.SSEKeepAlive == 0 { + config.SSEKeepAlive = DefaultSSEKeepAlive + } return &HTTPTransport{ server: server, @@ -302,12 +312,33 @@ func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) { // Use ResponseController to manage write deadlines for long-lived SSE connections rc := http.NewResponseController(w) + // Set up keepalive ticker if enabled + var keepaliveTicker *time.Ticker + var keepaliveChan <-chan time.Time + if t.config.SSEKeepAlive > 0 { + keepaliveTicker = time.NewTicker(t.config.SSEKeepAlive) + keepaliveChan = keepaliveTicker.C + defer keepaliveTicker.Stop() + } + // Stream notifications ctx := r.Context() for { select { case <-ctx.Done(): return + + case <-keepaliveChan: + // Send SSE comment as keepalive (ignored by clients) + if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil { + t.server.logger.Printf("Failed to set write deadline: %v", err) + } + if _, err := fmt.Fprintf(w, ":keepalive\n\n"); err != nil { + // Write failed, connection likely closed + return + } + flusher.Flush() + case notification, ok := <-session.Notifications(): if !ok { // Session closed @@ -326,7 +357,10 @@ func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) { } // Write SSE event - fmt.Fprintf(w, "data: %s\n\n", data) + if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { + // Write failed, connection likely closed + return + } flusher.Flush() // Touch session to keep it alive diff --git a/internal/mcp/transport_http_test.go b/internal/mcp/transport_http_test.go index 5f01b0b..9f5f304 100644 --- a/internal/mcp/transport_http_test.go +++ b/internal/mcp/transport_http_test.go @@ -409,6 +409,69 @@ func TestHTTPTransportSSEStream(t *testing.T) { } } +func TestHTTPTransportSSEKeepalive(t *testing.T) { + transport, ts := testHTTPTransport(t, HTTPConfig{ + SSEKeepAlive: 50 * time.Millisecond, // Short interval for testing + }) + + session, _ := transport.sessions.Create() + + // Start SSE stream + req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil) + req.Header.Set("Mcp-Session-Id", session.ID) + req.Header.Set("Accept", "text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200, got %d", resp.StatusCode) + } + + // Read with timeout - should receive keepalive within 100ms + buf := make([]byte, 256) + done := make(chan struct{}) + var readData string + var readErr error + + go func() { + n, err := resp.Body.Read(buf) + readData = string(buf[:n]) + readErr = err + close(done) + }() + + select { + case <-done: + if readErr != nil && readErr.Error() != "EOF" { + t.Fatalf("Read error: %v", readErr) + } + // Should receive SSE comment keepalive + if !strings.Contains(readData, ":keepalive") { + t.Errorf("Expected keepalive comment, got: %q", readData) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timeout waiting for keepalive") + } +} + +func TestHTTPTransportSSEKeepaliveDisabled(t *testing.T) { + server := NewServer(nil, log.New(io.Discard, "", 0)) + config := HTTPConfig{ + SSEKeepAlive: -1, // Explicitly disabled + } + transport := NewHTTPTransport(server, config) + defer transport.sessions.Stop() + + // When SSEKeepAlive is negative, it should remain negative (disabled) + if transport.config.SSEKeepAlive != -1 { + t.Errorf("Expected SSEKeepAlive to remain -1 (disabled), got %v", transport.config.SSEKeepAlive) + } +} + func TestHTTPTransportParseError(t *testing.T) { _, ts := testHTTPTransport(t, HTTPConfig{}) @@ -510,6 +573,9 @@ func TestHTTPTransportDefaultConfig(t *testing.T) { if transport.config.ReadHeaderTimeout != DefaultReadHeaderTimeout { t.Errorf("Expected default read header timeout %v, got %v", DefaultReadHeaderTimeout, transport.config.ReadHeaderTimeout) } + if transport.config.SSEKeepAlive != DefaultSSEKeepAlive { + t.Errorf("Expected default SSE keepalive %v, got %v", DefaultSSEKeepAlive, transport.config.SSEKeepAlive) + } transport.sessions.Stop() }