feat: add Streamable HTTP transport support
Add support for running the MCP server over HTTP with Server-Sent Events (SSE) using the MCP Streamable HTTP specification, alongside the existing STDIO transport. New features: - Transport abstraction with Transport interface - HTTP transport with session management - SSE support for server-initiated notifications - CORS security with configurable allowed origins - Optional TLS support - CLI flags for HTTP configuration (--transport, --http-address, etc.) - NixOS module options for HTTP transport The HTTP transport implements: - POST /mcp: JSON-RPC requests with session management - GET /mcp: SSE stream for server notifications - DELETE /mcp: Session termination - Origin validation (localhost-only by default) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
513
internal/mcp/transport_http_test.go
Normal file
513
internal/mcp/transport_http_test.go
Normal file
@@ -0,0 +1,513 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testHTTPTransport creates a transport with a test server
|
||||
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
|
||||
// Use a mock store
|
||||
server := NewServer(nil, log.New(io.Discard, "", 0))
|
||||
|
||||
if config.SessionTTL == 0 {
|
||||
config.SessionTTL = 30 * time.Minute
|
||||
}
|
||||
transport := NewHTTPTransport(server, config)
|
||||
|
||||
// Create test server
|
||||
mux := http.NewServeMux()
|
||||
endpoint := config.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/mcp"
|
||||
}
|
||||
mux.HandleFunc(endpoint, transport.handleMCP)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
t.Cleanup(func() {
|
||||
ts.Close()
|
||||
transport.sessions.Stop()
|
||||
})
|
||||
|
||||
return transport, ts
|
||||
}
|
||||
|
||||
func TestHTTPTransportInitialize(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send initialize request
|
||||
initReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodInitialize,
|
||||
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||
}
|
||||
body, _ := json.Marshal(initReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check session ID header
|
||||
sessionID := resp.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
t.Error("Expected Mcp-Session-Id header")
|
||||
}
|
||||
if len(sessionID) != 32 {
|
||||
t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var initResp Response
|
||||
if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if initResp.Error != nil {
|
||||
t.Errorf("Initialize failed: %v", initResp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportSessionRequired(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send tools/list without session
|
||||
listReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodToolsList,
|
||||
}
|
||||
body, _ := json.Marshal(listReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportInvalidSession(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send request with invalid session
|
||||
listReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodToolsList,
|
||||
}
|
||||
body, _ := json.Marshal(listReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Mcp-Session-Id", "invalid-session-id")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportValidSession(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Create session manually
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Send tools/list with valid session
|
||||
listReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodToolsList,
|
||||
}
|
||||
body, _ := json.Marshal(listReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportNotificationAccepted(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Send notification (no ID)
|
||||
notification := Request{
|
||||
JSONRPC: "2.0",
|
||||
Method: MethodInitialized,
|
||||
}
|
||||
body, _ := json.Marshal(notification)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Verify session is marked as initialized
|
||||
if !session.IsInitialized() {
|
||||
t.Error("Session should be marked as initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportDeleteSession(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Delete session
|
||||
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Verify session is gone
|
||||
if transport.sessions.Get(session.ID) != nil {
|
||||
t.Error("Session should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", "nonexistent")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportOriginValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedOrigins []string
|
||||
origin string
|
||||
expectAllowed bool
|
||||
}{
|
||||
{
|
||||
name: "no origin header",
|
||||
allowedOrigins: nil,
|
||||
origin: "",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "localhost allowed by default",
|
||||
allowedOrigins: nil,
|
||||
origin: "http://localhost:3000",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 allowed by default",
|
||||
allowedOrigins: nil,
|
||||
origin: "http://127.0.0.1:8080",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "external origin blocked by default",
|
||||
allowedOrigins: nil,
|
||||
origin: "http://evil.com",
|
||||
expectAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "explicit allow",
|
||||
allowedOrigins: []string{"http://example.com"},
|
||||
origin: "http://example.com",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "explicit allow wildcard",
|
||||
allowedOrigins: []string{"*"},
|
||||
origin: "http://anything.com",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "not in allowed list",
|
||||
allowedOrigins: []string{"http://example.com"},
|
||||
origin: "http://other.com",
|
||||
expectAllowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||
AllowedOrigins: tt.allowedOrigins,
|
||||
})
|
||||
|
||||
// Use initialize since it doesn't require a session
|
||||
initReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodInitialize,
|
||||
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||
}
|
||||
body, _ := json.Marshal(initReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.origin != "" {
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
|
||||
t.Error("Expected request to be allowed but was forbidden")
|
||||
}
|
||||
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// GET without Accept: text/event-stream
|
||||
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotAcceptable {
|
||||
t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportSSEStream(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Start SSE stream in goroutine
|
||||
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "text/event-stream" {
|
||||
t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
notification := &Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: 42,
|
||||
Result: map[string]string{"test": "data"},
|
||||
}
|
||||
session.SendNotification(notification)
|
||||
|
||||
// Read the SSE event
|
||||
buf := make([]byte, 1024)
|
||||
n, err := resp.Body.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Failed to read SSE event: %v", err)
|
||||
}
|
||||
|
||||
data := string(buf[:n])
|
||||
if !strings.HasPrefix(data, "data: ") {
|
||||
t.Errorf("Expected SSE data event, got: %s", data)
|
||||
}
|
||||
|
||||
// Parse the JSON from the SSE event
|
||||
jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ")
|
||||
var received Response
|
||||
if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
|
||||
t.Fatalf("Failed to parse notification JSON: %v", err)
|
||||
}
|
||||
|
||||
// JSON unmarshal converts numbers to float64, so compare as float64
|
||||
receivedID, ok := received.ID.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected numeric ID, got %T", received.ID)
|
||||
}
|
||||
if int(receivedID) != 42 {
|
||||
t.Errorf("Expected notification ID 42, got %v", receivedID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportParseError(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send invalid JSON
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var jsonResp Response
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if jsonResp.Error == nil {
|
||||
t.Error("Expected JSON-RPC error for parse error")
|
||||
}
|
||||
if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
|
||||
t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected 405, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportOptionsRequest(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("Expected 204, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Error("Expected CORS origin header")
|
||||
}
|
||||
if resp.Header.Get("Access-Control-Allow-Methods") == "" {
|
||||
t.Error("Expected CORS methods header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalhostOrigin(t *testing.T) {
|
||||
tests := []struct {
|
||||
origin string
|
||||
expected bool
|
||||
}{
|
||||
{"http://localhost", true},
|
||||
{"http://localhost:3000", true},
|
||||
{"https://localhost", true},
|
||||
{"https://localhost:8443", true},
|
||||
{"http://127.0.0.1", true},
|
||||
{"http://127.0.0.1:8080", true},
|
||||
{"https://127.0.0.1", true},
|
||||
{"http://[::1]", true},
|
||||
{"http://[::1]:8080", true},
|
||||
{"https://[::1]", true},
|
||||
{"http://example.com", false},
|
||||
{"https://example.com", false},
|
||||
{"http://localhost.evil.com", false},
|
||||
{"http://192.168.1.1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.origin, func(t *testing.T) {
|
||||
result := isLocalhostOrigin(tt.origin)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user