This repository has been archived on 2026-03-10. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
labmcp/internal/mcp/transport_http_test.go
Torjus Håkestad 08f8b2cd83 feat: add SSE keepalive messages for connection health
Add configurable SSEKeepAlive interval (default: 15s) that sends SSE
comment lines (`:keepalive`) to maintain connection health.

Benefits:
- Keeps connections alive through proxies/load balancers that close
  idle connections after a timeout
- Detects stale connections earlier (write failures terminate the
  handler)
- Standard SSE pattern: comment lines are ignored by compliant clients

Configuration:
- SSEKeepAlive > 0: send keepalives at specified interval
- SSEKeepAlive = 0: use default (15s)
- SSEKeepAlive < 0: disable keepalives

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:10:58 +01:00

753 lines
21 KiB
Go

package mcp
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// testHTTPTransport creates a transport with a test server
// testHTTPTransport builds an HTTPTransport from config and serves its MCP
// handler on an httptest.Server. Both the test server and the transport's
// session manager are torn down via t.Cleanup, so callers need no teardown.
//
// A zero SessionTTL is replaced with 30 minutes, and an empty Endpoint is
// served at "/mcp".
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
	t.Helper()

	// No store is wired in: NewServer receives nil. NOTE(review): this assumes
	// none of the transport tests reach a code path that uses the store.
	server := NewServer(nil, log.New(io.Discard, "", 0))
	if config.SessionTTL == 0 {
		config.SessionTTL = 30 * time.Minute
	}
	transport := NewHTTPTransport(server, config)

	endpoint := config.Endpoint
	if endpoint == "" {
		endpoint = "/mcp"
	}
	mux := http.NewServeMux()
	mux.HandleFunc(endpoint, transport.handleMCP)
	ts := httptest.NewServer(mux)

	t.Cleanup(func() {
		// ts.Close blocks until outstanding requests complete, so no handler
		// is still using the session manager when it is stopped.
		ts.Close()
		transport.sessions.Stop()
	})
	return transport, ts
}
// TestHTTPTransportInitialize checks that an initialize POST sent without a
// session succeeds. It also checks that the response returns a 32-character
// Mcp-Session-Id header and carries a JSON-RPC response with no error.
func TestHTTPTransportInitialize(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	payload, _ := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodInitialize,
		Params:  json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
	})
	httpReq, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "application/json")

	httpResp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer httpResp.Body.Close()

	if got := httpResp.StatusCode; got != http.StatusOK {
		t.Errorf("Expected 200, got %d", got)
	}

	// The new session's ID is handed back in a response header.
	sid := httpResp.Header.Get("Mcp-Session-Id")
	if sid == "" {
		t.Error("Expected Mcp-Session-Id header")
	}
	if n := len(sid); n != 32 {
		t.Errorf("Session ID should be 32 chars, got %d", n)
	}

	var decoded Response
	if err := json.NewDecoder(httpResp.Body).Decode(&decoded); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if decoded.Error != nil {
		t.Errorf("Initialize failed: %v", decoded.Error)
	}
}
// TestHTTPTransportSessionRequired verifies that a tools/list request sent
// without an Mcp-Session-Id header is rejected with 400 Bad Request.
func TestHTTPTransportSessionRequired(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	payload, _ := json.Marshal(Request{JSONRPC: "2.0", ID: 1, Method: MethodToolsList})
	httpReq, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	// Deliberately no Mcp-Session-Id header.

	httpResp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer httpResp.Body.Close()

	if got := httpResp.StatusCode; got != http.StatusBadRequest {
		t.Errorf("Expected 400 without session, got %d", got)
	}
}
// TestHTTPTransportInvalidSession verifies that a request naming a session ID
// the transport does not know is rejected with 404 Not Found.
func TestHTTPTransportInvalidSession(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	payload, _ := json.Marshal(Request{JSONRPC: "2.0", ID: 1, Method: MethodToolsList})
	httpReq, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Mcp-Session-Id", "invalid-session-id")

	httpResp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer httpResp.Body.Close()

	if got := httpResp.StatusCode; got != http.StatusNotFound {
		t.Errorf("Expected 404 for invalid session, got %d", got)
	}
}
// TestHTTPTransportValidSession verifies that a tools/list request carrying
// the ID of an existing session is accepted with 200 OK.
func TestHTTPTransportValidSession(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Create the session directly on the manager rather than via initialize.
	// Fail fast on error instead of continuing with an unusable session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	listReq := Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodToolsList,
	}
	body, _ := json.Marshal(listReq)
	req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
	}
}
// TestHTTPTransportNotificationAccepted verifies that a JSON-RPC notification
// (a request without an ID) is answered with 202 Accepted. It also verifies
// that an "initialized" notification marks the session as initialized.
func TestHTTPTransportNotificationAccepted(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error instead of continuing with an unusable session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Notification: ID is intentionally left unset.
	notification := Request{
		JSONRPC: "2.0",
		Method:  MethodInitialized,
	}
	body, _ := json.Marshal(notification)
	req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusAccepted {
		t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
	}
	if !session.IsInitialized() {
		t.Error("Session should be marked as initialized")
	}
}
// TestHTTPTransportDeleteSession verifies that DELETE with a valid
// Mcp-Session-Id returns 204 No Content and removes the session from the
// session manager.
func TestHTTPTransportDeleteSession(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error instead of continuing with an unusable session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNoContent {
		t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
	}
	if transport.sessions.Get(session.ID) != nil {
		t.Error("Session should be deleted")
	}
}
// TestHTTPTransportDeleteNonexistentSession verifies that deleting an unknown
// session ID yields 404 Not Found.
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	deleteReq, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
	deleteReq.Header.Set("Mcp-Session-Id", "nonexistent")

	deleteResp, err := http.DefaultClient.Do(deleteReq)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer deleteResp.Body.Close()

	if got := deleteResp.StatusCode; got != http.StatusNotFound {
		t.Errorf("Expected 404 for nonexistent session, got %d", got)
	}
}
func TestHTTPTransportOriginValidation(t *testing.T) {
tests := []struct {
name string
allowedOrigins []string
origin string
expectAllowed bool
}{
{
name: "no origin header",
allowedOrigins: nil,
origin: "",
expectAllowed: true,
},
{
name: "localhost allowed by default",
allowedOrigins: nil,
origin: "http://localhost:3000",
expectAllowed: true,
},
{
name: "127.0.0.1 allowed by default",
allowedOrigins: nil,
origin: "http://127.0.0.1:8080",
expectAllowed: true,
},
{
name: "external origin blocked by default",
allowedOrigins: nil,
origin: "http://evil.com",
expectAllowed: false,
},
{
name: "explicit allow",
allowedOrigins: []string{"http://example.com"},
origin: "http://example.com",
expectAllowed: true,
},
{
name: "explicit allow wildcard",
allowedOrigins: []string{"*"},
origin: "http://anything.com",
expectAllowed: true,
},
{
name: "not in allowed list",
allowedOrigins: []string{"http://example.com"},
origin: "http://other.com",
expectAllowed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
AllowedOrigins: tt.allowedOrigins,
})
// Use initialize since it doesn't require a session
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
t.Error("Expected request to be allowed but was forbidden")
}
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
}
})
}
}
// TestHTTPTransportSSERequiresAcceptHeader verifies that a GET (SSE stream
// request) without "Accept: text/event-stream" is refused with 406 Not
// Acceptable, even when the session is valid.
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error instead of continuing with an unusable session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// GET without Accept: text/event-stream.
	req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNotAcceptable {
		t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
	}
}
// TestHTTPTransportSSEStream opens a GET SSE stream for a session and pushes
// a message through session.SendNotification. It checks that the message
// arrives as a "data: <json>" SSE event carrying the same JSON-RPC ID.
func TestHTTPTransportSSEStream(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error instead of continuing with an unusable session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Open the SSE stream. Client.Timeout also bounds reading the body, so a
	// missing event fails the test instead of blocking until the global
	// `go test` timeout.
	client := &http.Client{Timeout: 5 * time.Second}
	req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
	req.Header.Set("Mcp-Session-Id", session.ID)
	req.Header.Set("Accept", "text/event-stream")
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Expected 200, got %d", resp.StatusCode)
	}
	contentType := resp.Header.Get("Content-Type")
	if contentType != "text/event-stream" {
		t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
	}

	notification := &Response{
		JSONRPC: "2.0",
		ID:      42,
		Result:  map[string]string{"test": "data"},
	}
	session.SendNotification(notification)

	// Read until one complete event (terminated by "\n\n") with a data field
	// arrives. A single Read may return only part of an event, and SSE comment
	// events such as ":keepalive" may legitimately precede the data.
	var (
		pending  string
		jsonData string
		buf      = make([]byte, 1024)
	)
	for jsonData == "" {
		n, readErr := resp.Body.Read(buf)
		pending += string(buf[:n])
		for jsonData == "" {
			event, rest, found := strings.Cut(pending, "\n\n")
			if !found {
				break
			}
			pending = rest
			switch {
			case strings.HasPrefix(event, ":"):
				// SSE comment line; compliant clients ignore it.
			case strings.HasPrefix(event, "data: "):
				jsonData = strings.TrimPrefix(event, "data: ")
			default:
				t.Fatalf("Expected SSE data event, got: %s", event)
			}
		}
		if jsonData == "" && readErr != nil {
			t.Fatalf("Failed to read SSE event: %v (partial data: %q)", readErr, pending)
		}
	}

	var received Response
	if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
		t.Fatalf("Failed to parse notification JSON: %v", err)
	}
	// encoding/json decodes numbers into interface{} as float64.
	receivedID, ok := received.ID.(float64)
	if !ok {
		t.Fatalf("Expected numeric ID, got %T", received.ID)
	}
	if int(receivedID) != 42 {
		t.Errorf("Expected notification ID 42, got %v", receivedID)
	}
}
// TestHTTPTransportSSEKeepalive configures a 50ms keepalive interval and
// checks that an otherwise idle SSE stream receives a ":keepalive" comment.
func TestHTTPTransportSSEKeepalive(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{
		SSEKeepAlive: 50 * time.Millisecond, // Short interval for testing
	})

	// Fail fast on error instead of continuing with an unusable session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
	req.Header.Set("Mcp-Session-Id", session.ID)
	req.Header.Set("Accept", "text/event-stream")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Expected 200, got %d", resp.StatusCode)
	}

	// Read in the background so the wait below can be bounded. Chunks are
	// accumulated because one Read may return only part of the comment line.
	// The channel is buffered so the goroutine never blocks on send after a
	// timeout; it exits once the deferred Body.Close makes Read fail.
	type readResult struct {
		data string
		err  error
	}
	results := make(chan readResult, 1)
	go func() {
		var received strings.Builder
		buf := make([]byte, 256)
		for {
			n, err := resp.Body.Read(buf)
			received.Write(buf[:n])
			if err != nil || strings.Contains(received.String(), ":keepalive") {
				results <- readResult{data: received.String(), err: err}
				return
			}
		}
	}()

	// The first keepalive is due after ~50ms. The generous deadline only
	// guards against a hang and keeps the test stable on slow or -race runs.
	// It is still far below the 15s default, so a pass shows the configured
	// interval took effect.
	select {
	case r := <-results:
		if strings.Contains(r.data, ":keepalive") {
			return
		}
		if r.err != nil && r.err != io.EOF {
			t.Fatalf("Read error: %v", r.err)
		}
		t.Errorf("Expected keepalive comment, got: %q", r.data)
	case <-time.After(2 * time.Second):
		t.Error("Timeout waiting for keepalive")
	}
}
func TestHTTPTransportSSEKeepaliveDisabled(t *testing.T) {
server := NewServer(nil, log.New(io.Discard, "", 0))
config := HTTPConfig{
SSEKeepAlive: -1, // Explicitly disabled
}
transport := NewHTTPTransport(server, config)
defer transport.sessions.Stop()
// When SSEKeepAlive is negative, it should remain negative (disabled)
if transport.config.SSEKeepAlive != -1 {
t.Errorf("Expected SSEKeepAlive to remain -1 (disabled), got %v", transport.config.SSEKeepAlive)
}
}
// TestHTTPTransportParseError verifies that a body which is not JSON is
// answered with HTTP 200 and a JSON-RPC error object whose code is ParseError.
// The error travels in the JSON-RPC layer, not as an HTTP status.
func TestHTTPTransportParseError(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	httpReq, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
	httpReq.Header.Set("Content-Type", "application/json")

	httpResp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer httpResp.Body.Close()

	if got := httpResp.StatusCode; got != http.StatusOK {
		t.Errorf("Expected 200 (with JSON-RPC error), got %d", got)
	}

	var rpc Response
	if err := json.NewDecoder(httpResp.Body).Decode(&rpc); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	switch {
	case rpc.Error == nil:
		t.Error("Expected JSON-RPC error for parse error")
	case rpc.Error.Code != ParseError:
		t.Errorf("Expected parse error code %d, got %d", ParseError, rpc.Error.Code)
	}
}
// TestHTTPTransportMethodNotAllowed verifies that an HTTP method the MCP
// endpoint does not handle (PUT) gets 405 Method Not Allowed.
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	putReq, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
	putResp, err := http.DefaultClient.Do(putReq)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer putResp.Body.Close()

	if got := putResp.StatusCode; got != http.StatusMethodNotAllowed {
		t.Errorf("Expected 405, got %d", got)
	}
}
// TestHTTPTransportOptionsRequest verifies CORS preflight handling. An
// OPTIONS request from an allowed origin gets 204 No Content, an
// Access-Control-Allow-Origin header echoing that origin, and a non-empty
// Access-Control-Allow-Methods header.
func TestHTTPTransportOptionsRequest(t *testing.T) {
	const origin = "http://example.com"
	_, ts := testHTTPTransport(t, HTTPConfig{
		AllowedOrigins: []string{origin},
	})

	preflight, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
	preflight.Header.Set("Origin", origin)

	resp, err := http.DefaultClient.Do(preflight)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusNoContent {
		t.Errorf("Expected 204, got %d", got)
	}
	if resp.Header.Get("Access-Control-Allow-Origin") != origin {
		t.Error("Expected CORS origin header")
	}
	if len(resp.Header.Get("Access-Control-Allow-Methods")) == 0 {
		t.Error("Expected CORS methods header")
	}
}
func TestHTTPTransportDefaultConfig(t *testing.T) {
server := NewServer(nil, log.New(io.Discard, "", 0))
transport := NewHTTPTransport(server, HTTPConfig{})
// Verify defaults are applied
if transport.config.Address != "127.0.0.1:8080" {
t.Errorf("Expected default address 127.0.0.1:8080, got %s", transport.config.Address)
}
if transport.config.Endpoint != "/mcp" {
t.Errorf("Expected default endpoint /mcp, got %s", transport.config.Endpoint)
}
if transport.config.SessionTTL != 30*time.Minute {
t.Errorf("Expected default session TTL 30m, got %v", transport.config.SessionTTL)
}
if transport.config.MaxRequestSize != DefaultMaxRequestSize {
t.Errorf("Expected default max request size %d, got %d", DefaultMaxRequestSize, transport.config.MaxRequestSize)
}
if transport.config.ReadTimeout != DefaultReadTimeout {
t.Errorf("Expected default read timeout %v, got %v", DefaultReadTimeout, transport.config.ReadTimeout)
}
if transport.config.WriteTimeout != DefaultWriteTimeout {
t.Errorf("Expected default write timeout %v, got %v", DefaultWriteTimeout, transport.config.WriteTimeout)
}
if transport.config.IdleTimeout != DefaultIdleTimeout {
t.Errorf("Expected default idle timeout %v, got %v", DefaultIdleTimeout, transport.config.IdleTimeout)
}
if transport.config.ReadHeaderTimeout != DefaultReadHeaderTimeout {
t.Errorf("Expected default read header timeout %v, got %v", DefaultReadHeaderTimeout, transport.config.ReadHeaderTimeout)
}
if transport.config.SSEKeepAlive != DefaultSSEKeepAlive {
t.Errorf("Expected default SSE keepalive %v, got %v", DefaultSSEKeepAlive, transport.config.SSEKeepAlive)
}
transport.sessions.Stop()
}
// TestHTTPTransportCustomConfig checks that NewHTTPTransport preserves every
// explicitly set (non-zero) HTTPConfig field instead of overwriting it with a
// default. This includes a positive SSEKeepAlive interval.
func TestHTTPTransportCustomConfig(t *testing.T) {
	server := NewServer(nil, log.New(io.Discard, "", 0))
	config := HTTPConfig{
		Address:           "0.0.0.0:9090",
		Endpoint:          "/api/mcp",
		SessionTTL:        1 * time.Hour,
		MaxRequestSize:    5 << 20, // 5MB
		ReadTimeout:       60 * time.Second,
		WriteTimeout:      60 * time.Second,
		IdleTimeout:       300 * time.Second,
		ReadHeaderTimeout: 20 * time.Second,
		SSEKeepAlive:      45 * time.Second,
	}
	transport := NewHTTPTransport(server, config)
	// Deferred so the session manager is stopped even if an assertion below
	// aborts the test.
	defer transport.sessions.Stop()

	if transport.config.Address != "0.0.0.0:9090" {
		t.Errorf("Expected custom address, got %s", transport.config.Address)
	}
	if transport.config.Endpoint != "/api/mcp" {
		t.Errorf("Expected custom endpoint, got %s", transport.config.Endpoint)
	}
	if transport.config.SessionTTL != 1*time.Hour {
		t.Errorf("Expected custom session TTL, got %v", transport.config.SessionTTL)
	}
	if transport.config.MaxRequestSize != 5<<20 {
		t.Errorf("Expected custom max request size, got %d", transport.config.MaxRequestSize)
	}
	if transport.config.ReadTimeout != 60*time.Second {
		t.Errorf("Expected custom read timeout, got %v", transport.config.ReadTimeout)
	}
	if transport.config.WriteTimeout != 60*time.Second {
		t.Errorf("Expected custom write timeout, got %v", transport.config.WriteTimeout)
	}
	if transport.config.IdleTimeout != 300*time.Second {
		t.Errorf("Expected custom idle timeout, got %v", transport.config.IdleTimeout)
	}
	if transport.config.ReadHeaderTimeout != 20*time.Second {
		t.Errorf("Expected custom read header timeout, got %v", transport.config.ReadHeaderTimeout)
	}
	if transport.config.SSEKeepAlive != 45*time.Second {
		t.Errorf("Expected custom SSE keepalive, got %v", transport.config.SSEKeepAlive)
	}
}
// TestHTTPTransportRequestBodyTooLarge verifies that a POST body larger than
// MaxRequestSize is rejected with 413 Request Entity Too Large.
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{
		MaxRequestSize: 100, // Very small limit for testing
	})

	// 200 bytes, twice the limit. The payload is not valid JSON; the test
	// expects the size limit (413) to be reported rather than a parse error.
	largeBody := bytes.Repeat([]byte("x"), 200)

	req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody))
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusRequestEntityTooLarge {
		t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
	}
}
// TestHTTPTransportSessionLimitReached verifies that with MaxSessions set to
// 2, two initialize requests succeed. A third, which would need another
// session, is refused with 503 Service Unavailable.
func TestHTTPTransportSessionLimitReached(t *testing.T) {
	const limit = 2
	_, ts := testHTTPTransport(t, HTTPConfig{
		MaxSessions: limit, // Very low limit for testing
	})

	payload, _ := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodInitialize,
		Params:  json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
	})

	// post sends one initialize request; each successful one creates a session.
	post := func() (*http.Response, error) {
		httpReq, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(payload))
		httpReq.Header.Set("Content-Type", "application/json")
		return http.DefaultClient.Do(httpReq)
	}

	// Fill the session manager up to the limit.
	for i := 0; i < limit; i++ {
		httpResp, err := post()
		if err != nil {
			t.Fatalf("Request %d failed: %v", i, err)
		}
		httpResp.Body.Close()
		if got := httpResp.StatusCode; got != http.StatusOK {
			t.Errorf("Request %d: expected 200, got %d", i, got)
		}
	}

	// One more must be refused.
	overflow, err := post()
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer overflow.Body.Close()

	if got := overflow.StatusCode; got != http.StatusServiceUnavailable {
		t.Errorf("Expected 503 when session limit reached, got %d", got)
	}
}
// TestHTTPTransportRequestBodyWithinLimit is the counterpart to the
// too-large test. An ordinary initialize request under a 10000-byte
// MaxRequestSize must be accepted with 200 OK.
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{
		MaxRequestSize: 10000, // Reasonable limit
	})

	payload, _ := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodInitialize,
		Params:  json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
	})
	httpReq, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(payload))
	httpReq.Header.Set("Content-Type", "application/json")

	httpResp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer httpResp.Body.Close()

	if got := httpResp.StatusCode; got != http.StatusOK {
		t.Errorf("Expected 200 for valid request within limit, got %d", got)
	}
}
// TestIsLocalhostOrigin checks which Origin values isLocalhostOrigin treats
// as local. Local origins are localhost, 127.0.0.1 and [::1], over http or
// https, with or without a port. Look-alike hosts such as localhost.evil.com
// and other addresses are not local.
func TestIsLocalhostOrigin(t *testing.T) {
	cases := []struct {
		in   string
		want bool
	}{
		// localhost
		{"http://localhost", true},
		{"http://localhost:3000", true},
		{"https://localhost", true},
		{"https://localhost:8443", true},
		// IPv4 loopback
		{"http://127.0.0.1", true},
		{"http://127.0.0.1:8080", true},
		{"https://127.0.0.1", true},
		// IPv6 loopback
		{"http://[::1]", true},
		{"http://[::1]:8080", true},
		{"https://[::1]", true},
		// non-local
		{"http://example.com", false},
		{"https://example.com", false},
		{"http://localhost.evil.com", false},
		{"http://192.168.1.1", false},
	}

	for _, tc := range cases {
		t.Run(tc.in, func(t *testing.T) {
			if got := isLocalhostOrigin(tc.in); got != tc.want {
				t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tc.in, got, tc.want)
			}
		})
	}
}