package mcp

import (
	"bytes"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"
)

// httpTestRequestTimeout bounds every HTTP exchange issued by the tests in
// this file (including reading the response body), so a server that never
// answers fails the test instead of hanging it.
const httpTestRequestTimeout = 10 * time.Second

// testHTTPTransport builds an HTTPTransport from config and mounts its
// handleMCP handler on an httptest server.
//
// A zero config.SessionTTL is replaced with 30 minutes, and an empty
// config.Endpoint is mounted at "/mcp". The test server is closed and the
// transport's session manager stopped via t.Cleanup.
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
	t.Helper()

	// The server gets a nil first argument (presumably the store; confirm
	// against NewServer) and a logger that discards its output.
	server := NewServer(nil, log.New(io.Discard, "", 0))

	if config.SessionTTL == 0 {
		config.SessionTTL = 30 * time.Minute
	}

	transport := NewHTTPTransport(server, config)

	endpoint := config.Endpoint
	if endpoint == "" {
		endpoint = "/mcp"
	}

	mux := http.NewServeMux()
	mux.HandleFunc(endpoint, transport.handleMCP)
	ts := httptest.NewServer(mux)

	t.Cleanup(func() {
		ts.Close()
		transport.sessions.Stop()
	})

	return transport, ts
}

// newHTTPTestInitializeRequest returns the initialize request shared by the
// tests that need a call that works without an existing session.
func newHTTPTestInitializeRequest() Request {
	return Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodInitialize,
		Params:  json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
	}
}

// marshalHTTPTestBody JSON-encodes v and fails the test immediately if
// encoding is not possible.
func marshalHTTPTestBody(t *testing.T, v interface{}) []byte {
	t.Helper()

	body, err := json.Marshal(v)
	if err != nil {
		t.Fatalf("Failed to marshal request body: %v", err)
	}
	return body
}

// sendHTTPTestRequest issues a single HTTP request and returns the response.
//
// A nil body sends no request body. Every entry of headers is set on the
// request. The response body is closed through t.Cleanup; cleanups run in
// last-in-first-out order, so the body is closed before the test server
// registered earlier by testHTTPTransport is shut down.
func sendHTTPTestRequest(t *testing.T, method, url string, body []byte, headers map[string]string) *http.Response {
	t.Helper()

	// Keep the reader a true nil interface when there is no body.
	var reader io.Reader
	if body != nil {
		reader = bytes.NewReader(body)
	}

	req, err := http.NewRequest(method, url, reader)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	for name, value := range headers {
		req.Header.Set(name, value)
	}

	client := &http.Client{Timeout: httpTestRequestTimeout}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	t.Cleanup(func() { resp.Body.Close() })

	return resp
}

// postHTTPTestJSON POSTs body with Content-Type application/json plus any
// extra headers and returns the response.
func postHTTPTestJSON(t *testing.T, url string, body []byte, extra map[string]string) *http.Response {
	t.Helper()

	headers := map[string]string{"Content-Type": "application/json"}
	for name, value := range extra {
		headers[name] = value
	}
	return sendHTTPTestRequest(t, http.MethodPost, url, body, headers)
}

// TestHTTPTransportInitialize checks that initialize answers 200, returns a
// 32-character Mcp-Session-Id header, and carries no JSON-RPC error.
func TestHTTPTransportInitialize(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	body := marshalHTTPTestBody(t, newHTTPTestInitializeRequest())
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", body, map[string]string{
		"Accept": "application/json",
	})

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected 200, got %d", resp.StatusCode)
	}

	// Check session ID header
	sessionID := resp.Header.Get("Mcp-Session-Id")
	if sessionID == "" {
		t.Error("Expected Mcp-Session-Id header")
	}
	if len(sessionID) != 32 {
		t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
	}

	// Check response body
	var initResp Response
	if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if initResp.Error != nil {
		t.Errorf("Initialize failed: %v", initResp.Error)
	}
}

// TestHTTPTransportSessionRequired checks that tools/list without a session
// header is rejected with 400.
func TestHTTPTransportSessionRequired(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	body := marshalHTTPTestBody(t, Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodToolsList,
	})
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", body, nil)

	if resp.StatusCode != http.StatusBadRequest {
		t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
	}
}

// TestHTTPTransportInvalidSession checks that an unknown session ID is
// answered with 404.
func TestHTTPTransportInvalidSession(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	body := marshalHTTPTestBody(t, Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodToolsList,
	})
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", body, map[string]string{
		"Mcp-Session-Id": "invalid-session-id",
	})

	if resp.StatusCode != http.StatusNotFound {
		t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
	}
}

// TestHTTPTransportValidSession checks that tools/list succeeds with 200 when
// sent with the ID of a session created directly on the session manager.
func TestHTTPTransportValidSession(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Create session manually
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	body := marshalHTTPTestBody(t, Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodToolsList,
	})
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", body, map[string]string{
		"Mcp-Session-Id": session.ID,
	})

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
	}
}

// TestHTTPTransportNotificationAccepted checks that a notification (a request
// without an ID) is answered with 202 and that the initialized notification
// marks the session as initialized.
func TestHTTPTransportNotificationAccepted(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Send notification (no ID)
	body := marshalHTTPTestBody(t, Request{
		JSONRPC: "2.0",
		Method:  MethodInitialized,
	})
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", body, map[string]string{
		"Mcp-Session-Id": session.ID,
	})

	if resp.StatusCode != http.StatusAccepted {
		t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
	}

	// Verify session is marked as initialized
	if !session.IsInitialized() {
		t.Error("Session should be marked as initialized")
	}
}

// TestHTTPTransportDeleteSession checks that DELETE with a valid session ID
// answers 204 and removes the session from the session manager.
func TestHTTPTransportDeleteSession(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	resp := sendHTTPTestRequest(t, http.MethodDelete, ts.URL+"/mcp", nil, map[string]string{
		"Mcp-Session-Id": session.ID,
	})

	if resp.StatusCode != http.StatusNoContent {
		t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
	}

	// Verify session is gone
	if transport.sessions.Get(session.ID) != nil {
		t.Error("Session should be deleted")
	}
}

// TestHTTPTransportDeleteNonexistentSession checks that DELETE with an
// unknown session ID answers 404.
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	resp := sendHTTPTestRequest(t, http.MethodDelete, ts.URL+"/mcp", nil, map[string]string{
		"Mcp-Session-Id": "nonexistent",
	})

	if resp.StatusCode != http.StatusNotFound {
		t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
	}
}

// TestHTTPTransportOriginValidation checks which Origin header values are
// answered with 403, for the default configuration and for explicit
// AllowedOrigins lists.
func TestHTTPTransportOriginValidation(t *testing.T) {
	tests := []struct {
		name           string
		allowedOrigins []string
		origin         string
		expectAllowed  bool
	}{
		{
			name:           "no origin header",
			allowedOrigins: nil,
			origin:         "",
			expectAllowed:  true,
		},
		{
			name:           "localhost allowed by default",
			allowedOrigins: nil,
			origin:         "http://localhost:3000",
			expectAllowed:  true,
		},
		{
			name:           "127.0.0.1 allowed by default",
			allowedOrigins: nil,
			origin:         "http://127.0.0.1:8080",
			expectAllowed:  true,
		},
		{
			name:           "external origin blocked by default",
			allowedOrigins: nil,
			origin:         "http://evil.com",
			expectAllowed:  false,
		},
		{
			name:           "explicit allow",
			allowedOrigins: []string{"http://example.com"},
			origin:         "http://example.com",
			expectAllowed:  true,
		},
		{
			name:           "explicit allow wildcard",
			allowedOrigins: []string{"*"},
			origin:         "http://anything.com",
			expectAllowed:  true,
		},
		{
			name:           "not in allowed list",
			allowedOrigins: []string{"http://example.com"},
			origin:         "http://other.com",
			expectAllowed:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, ts := testHTTPTransport(t, HTTPConfig{
				AllowedOrigins: tt.allowedOrigins,
			})

			// Use initialize since it doesn't require a session
			var extra map[string]string
			if tt.origin != "" {
				extra = map[string]string{"Origin": tt.origin}
			}
			body := marshalHTTPTestBody(t, newHTTPTestInitializeRequest())
			resp := postHTTPTestJSON(t, ts.URL+"/mcp", body, extra)

			forbidden := resp.StatusCode == http.StatusForbidden
			if tt.expectAllowed && forbidden {
				t.Error("Expected request to be allowed but was forbidden")
			}
			if !tt.expectAllowed && !forbidden {
				t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
			}
		})
	}
}

// TestHTTPTransportSSERequiresAcceptHeader checks that GET without
// "Accept: text/event-stream" is answered with 406.
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// GET without Accept: text/event-stream
	resp := sendHTTPTestRequest(t, http.MethodGet, ts.URL+"/mcp", nil, map[string]string{
		"Mcp-Session-Id": session.ID,
	})

	if resp.StatusCode != http.StatusNotAcceptable {
		t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
	}
}

// TestHTTPTransportSSEStream opens the SSE stream for a session, pushes one
// message through session.SendNotification, and checks that it arrives as a
// "data: " event carrying the same ID.
func TestHTTPTransportSSEStream(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	resp := sendHTTPTestRequest(t, http.MethodGet, ts.URL+"/mcp", nil, map[string]string{
		"Mcp-Session-Id": session.ID,
		"Accept":         "text/event-stream",
	})

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Expected 200, got %d", resp.StatusCode)
	}

	contentType := resp.Header.Get("Content-Type")
	if contentType != "text/event-stream" {
		t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
	}

	// Send a notification
	notification := &Response{
		JSONRPC: "2.0",
		ID:      42,
		Result:  map[string]string{"test": "data"},
	}
	session.SendNotification(notification)

	// Read until the blank line that terminates an SSE event. A single Read
	// may return only part of the event, so accumulate chunks; the client
	// timeout turns a missing event into a read error instead of a hang.
	var stream strings.Builder
	chunk := make([]byte, 1024)
	for !strings.Contains(stream.String(), "\n\n") {
		n, err := resp.Body.Read(chunk)
		stream.Write(chunk[:n])
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("Failed to read SSE event: %v", err)
		}
	}

	// Keep only the first event in case more bytes followed the terminator.
	data := stream.String()
	if end := strings.Index(data, "\n\n"); end >= 0 {
		data = data[:end]
	}

	if !strings.HasPrefix(data, "data: ") {
		t.Errorf("Expected SSE data event, got: %s", data)
	}

	// Parse the JSON from the SSE event
	jsonData := strings.TrimPrefix(data, "data: ")
	var received Response
	if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
		t.Fatalf("Failed to parse notification JSON: %v", err)
	}

	// JSON unmarshal converts numbers to float64, so compare as float64
	receivedID, ok := received.ID.(float64)
	if !ok {
		t.Fatalf("Expected numeric ID, got %T", received.ID)
	}
	if int(receivedID) != 42 {
		t.Errorf("Expected notification ID 42, got %v", receivedID)
	}
}

// TestHTTPTransportParseError checks that a body that is not JSON is answered
// with HTTP 200 carrying a JSON-RPC error whose code is ParseError.
func TestHTTPTransportParseError(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	// Send invalid JSON
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", []byte("not json"), nil)

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
	}

	var jsonResp Response
	if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if jsonResp.Error == nil {
		t.Error("Expected JSON-RPC error for parse error")
	}
	if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
		t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
	}
}

// TestHTTPTransportMethodNotAllowed checks that PUT is answered with 405.
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	resp := sendHTTPTestRequest(t, http.MethodPut, ts.URL+"/mcp", nil, nil)

	if resp.StatusCode != http.StatusMethodNotAllowed {
		t.Errorf("Expected 405, got %d", resp.StatusCode)
	}
}

// TestHTTPTransportOptionsRequest checks that OPTIONS from an allowed origin
// is answered with 204 and the CORS origin and methods headers.
func TestHTTPTransportOptionsRequest(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{
		AllowedOrigins: []string{"http://example.com"},
	})

	resp := sendHTTPTestRequest(t, http.MethodOptions, ts.URL+"/mcp", nil, map[string]string{
		"Origin": "http://example.com",
	})

	if resp.StatusCode != http.StatusNoContent {
		t.Errorf("Expected 204, got %d", resp.StatusCode)
	}

	if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
		t.Error("Expected CORS origin header")
	}
	if resp.Header.Get("Access-Control-Allow-Methods") == "" {
		t.Error("Expected CORS methods header")
	}
}

// TestHTTPTransportDefaultConfig checks the values NewHTTPTransport stores in
// its config when given a zero HTTPConfig.
func TestHTTPTransportDefaultConfig(t *testing.T) {
	server := NewServer(nil, log.New(io.Discard, "", 0))
	transport := NewHTTPTransport(server, HTTPConfig{})
	// Registered up front so the session manager is stopped however the test ends.
	t.Cleanup(func() { transport.sessions.Stop() })

	cfg := transport.config

	// Verify defaults are applied
	if cfg.Address != "127.0.0.1:8080" {
		t.Errorf("Expected default address 127.0.0.1:8080, got %s", cfg.Address)
	}
	if cfg.Endpoint != "/mcp" {
		t.Errorf("Expected default endpoint /mcp, got %s", cfg.Endpoint)
	}
	if cfg.SessionTTL != 30*time.Minute {
		t.Errorf("Expected default session TTL 30m, got %v", cfg.SessionTTL)
	}
	if cfg.MaxRequestSize != DefaultMaxRequestSize {
		t.Errorf("Expected default max request size %d, got %d", DefaultMaxRequestSize, cfg.MaxRequestSize)
	}
	if cfg.ReadTimeout != DefaultReadTimeout {
		t.Errorf("Expected default read timeout %v, got %v", DefaultReadTimeout, cfg.ReadTimeout)
	}
	if cfg.WriteTimeout != DefaultWriteTimeout {
		t.Errorf("Expected default write timeout %v, got %v", DefaultWriteTimeout, cfg.WriteTimeout)
	}
	if cfg.IdleTimeout != DefaultIdleTimeout {
		t.Errorf("Expected default idle timeout %v, got %v", DefaultIdleTimeout, cfg.IdleTimeout)
	}
	if cfg.ReadHeaderTimeout != DefaultReadHeaderTimeout {
		t.Errorf("Expected default read header timeout %v, got %v", DefaultReadHeaderTimeout, cfg.ReadHeaderTimeout)
	}
}

// TestHTTPTransportCustomConfig checks that NewHTTPTransport keeps explicitly
// configured values instead of replacing them with defaults.
func TestHTTPTransportCustomConfig(t *testing.T) {
	server := NewServer(nil, log.New(io.Discard, "", 0))

	config := HTTPConfig{
		Address:           "0.0.0.0:9090",
		Endpoint:          "/api/mcp",
		SessionTTL:        1 * time.Hour,
		MaxRequestSize:    5 << 20, // 5MB
		ReadTimeout:       60 * time.Second,
		WriteTimeout:      60 * time.Second,
		IdleTimeout:       300 * time.Second,
		ReadHeaderTimeout: 20 * time.Second,
	}

	transport := NewHTTPTransport(server, config)
	// Registered up front so the session manager is stopped however the test ends.
	t.Cleanup(func() { transport.sessions.Stop() })

	cfg := transport.config

	// Verify custom values are preserved
	if cfg.Address != "0.0.0.0:9090" {
		t.Errorf("Expected custom address, got %s", cfg.Address)
	}
	if cfg.Endpoint != "/api/mcp" {
		t.Errorf("Expected custom endpoint, got %s", cfg.Endpoint)
	}
	if cfg.SessionTTL != 1*time.Hour {
		t.Errorf("Expected custom session TTL, got %v", cfg.SessionTTL)
	}
	if cfg.MaxRequestSize != 5<<20 {
		t.Errorf("Expected custom max request size, got %d", cfg.MaxRequestSize)
	}
	if cfg.ReadTimeout != 60*time.Second {
		t.Errorf("Expected custom read timeout, got %v", cfg.ReadTimeout)
	}
	if cfg.WriteTimeout != 60*time.Second {
		t.Errorf("Expected custom write timeout, got %v", cfg.WriteTimeout)
	}
	if cfg.IdleTimeout != 300*time.Second {
		t.Errorf("Expected custom idle timeout, got %v", cfg.IdleTimeout)
	}
	if cfg.ReadHeaderTimeout != 20*time.Second {
		t.Errorf("Expected custom read header timeout, got %v", cfg.ReadHeaderTimeout)
	}
}

// TestHTTPTransportRequestBodyTooLarge checks that a 200-byte body is
// answered with 413 when MaxRequestSize is 100.
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{
		MaxRequestSize: 100, // Very small limit for testing
	})

	// Create a request body larger than the limit
	largeBody := bytes.Repeat([]byte("x"), 200)
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", largeBody, nil)

	if resp.StatusCode != http.StatusRequestEntityTooLarge {
		t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
	}
}

// TestHTTPTransportRequestBodyWithinLimit checks that an initialize request
// is answered with 200 when MaxRequestSize is 10000.
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{
		MaxRequestSize: 10000, // Reasonable limit
	})

	// Send initialize request (should be well within limit)
	body := marshalHTTPTestBody(t, newHTTPTestInitializeRequest())
	resp := postHTTPTestJSON(t, ts.URL+"/mcp", body, nil)

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode)
	}
}

// TestIsLocalhostOrigin checks isLocalhostOrigin against loopback origins
// (localhost, 127.0.0.1, [::1], with and without ports) and against
// non-loopback origins, including a hostname that merely starts with
// "localhost".
func TestIsLocalhostOrigin(t *testing.T) {
	tests := []struct {
		origin   string
		expected bool
	}{
		{"http://localhost", true},
		{"http://localhost:3000", true},
		{"https://localhost", true},
		{"https://localhost:8443", true},
		{"http://127.0.0.1", true},
		{"http://127.0.0.1:8080", true},
		{"https://127.0.0.1", true},
		{"http://[::1]", true},
		{"http://[::1]:8080", true},
		{"https://[::1]", true},
		{"http://example.com", false},
		{"https://example.com", false},
		{"http://localhost.evil.com", false},
		{"http://192.168.1.1", false},
	}

	for _, tt := range tests {
		t.Run(tt.origin, func(t *testing.T) {
			result := isLocalhostOrigin(tt.origin)
			if result != tt.expected {
				t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
			}
		})
	}
}