Add MaxRequestSize configuration to HTTPConfig with a default of 1MB. Use http.MaxBytesReader to enforce the limit, returning 413 Request Entity Too Large when exceeded. This prevents memory exhaustion attacks where an attacker sends arbitrarily large request bodies. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
567 lines
15 KiB
Go
567 lines
15 KiB
Go
package mcp
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// testHTTPTransport creates a transport with a test server
|
|
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
|
|
// Use a mock store
|
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
|
|
|
if config.SessionTTL == 0 {
|
|
config.SessionTTL = 30 * time.Minute
|
|
}
|
|
transport := NewHTTPTransport(server, config)
|
|
|
|
// Create test server
|
|
mux := http.NewServeMux()
|
|
endpoint := config.Endpoint
|
|
if endpoint == "" {
|
|
endpoint = "/mcp"
|
|
}
|
|
mux.HandleFunc(endpoint, transport.handleMCP)
|
|
|
|
ts := httptest.NewServer(mux)
|
|
t.Cleanup(func() {
|
|
ts.Close()
|
|
transport.sessions.Stop()
|
|
})
|
|
|
|
return transport, ts
|
|
}
|
|
|
|
func TestHTTPTransportInitialize(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send initialize request
|
|
initReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodInitialize,
|
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
|
}
|
|
body, _ := json.Marshal(initReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Check session ID header
|
|
sessionID := resp.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
t.Error("Expected Mcp-Session-Id header")
|
|
}
|
|
if len(sessionID) != 32 {
|
|
t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
|
|
}
|
|
|
|
// Check response body
|
|
var initResp Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
|
|
t.Fatalf("Failed to decode response: %v", err)
|
|
}
|
|
|
|
if initResp.Error != nil {
|
|
t.Errorf("Initialize failed: %v", initResp.Error)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportSessionRequired(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send tools/list without session
|
|
listReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodToolsList,
|
|
}
|
|
body, _ := json.Marshal(listReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportInvalidSession(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send request with invalid session
|
|
listReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodToolsList,
|
|
}
|
|
body, _ := json.Marshal(listReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Mcp-Session-Id", "invalid-session-id")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportValidSession(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Create session manually
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Send tools/list with valid session
|
|
listReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodToolsList,
|
|
}
|
|
body, _ := json.Marshal(listReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportNotificationAccepted(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Send notification (no ID)
|
|
notification := Request{
|
|
JSONRPC: "2.0",
|
|
Method: MethodInitialized,
|
|
}
|
|
body, _ := json.Marshal(notification)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusAccepted {
|
|
t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Verify session is marked as initialized
|
|
if !session.IsInitialized() {
|
|
t.Error("Session should be marked as initialized")
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportDeleteSession(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Delete session
|
|
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Verify session is gone
|
|
if transport.sessions.Get(session.ID) != nil {
|
|
t.Error("Session should be deleted")
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", "nonexistent")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportOriginValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
allowedOrigins []string
|
|
origin string
|
|
expectAllowed bool
|
|
}{
|
|
{
|
|
name: "no origin header",
|
|
allowedOrigins: nil,
|
|
origin: "",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "localhost allowed by default",
|
|
allowedOrigins: nil,
|
|
origin: "http://localhost:3000",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "127.0.0.1 allowed by default",
|
|
allowedOrigins: nil,
|
|
origin: "http://127.0.0.1:8080",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "external origin blocked by default",
|
|
allowedOrigins: nil,
|
|
origin: "http://evil.com",
|
|
expectAllowed: false,
|
|
},
|
|
{
|
|
name: "explicit allow",
|
|
allowedOrigins: []string{"http://example.com"},
|
|
origin: "http://example.com",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "explicit allow wildcard",
|
|
allowedOrigins: []string{"*"},
|
|
origin: "http://anything.com",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "not in allowed list",
|
|
allowedOrigins: []string{"http://example.com"},
|
|
origin: "http://other.com",
|
|
expectAllowed: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
AllowedOrigins: tt.allowedOrigins,
|
|
})
|
|
|
|
// Use initialize since it doesn't require a session
|
|
initReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodInitialize,
|
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
|
}
|
|
body, _ := json.Marshal(initReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if tt.origin != "" {
|
|
req.Header.Set("Origin", tt.origin)
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
|
|
t.Error("Expected request to be allowed but was forbidden")
|
|
}
|
|
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// GET without Accept: text/event-stream
|
|
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNotAcceptable {
|
|
t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportSSEStream(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Start SSE stream in goroutine
|
|
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("Expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if contentType != "text/event-stream" {
|
|
t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
|
|
}
|
|
|
|
// Send a notification
|
|
notification := &Response{
|
|
JSONRPC: "2.0",
|
|
ID: 42,
|
|
Result: map[string]string{"test": "data"},
|
|
}
|
|
session.SendNotification(notification)
|
|
|
|
// Read the SSE event
|
|
buf := make([]byte, 1024)
|
|
n, err := resp.Body.Read(buf)
|
|
if err != nil && err != io.EOF {
|
|
t.Fatalf("Failed to read SSE event: %v", err)
|
|
}
|
|
|
|
data := string(buf[:n])
|
|
if !strings.HasPrefix(data, "data: ") {
|
|
t.Errorf("Expected SSE data event, got: %s", data)
|
|
}
|
|
|
|
// Parse the JSON from the SSE event
|
|
jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ")
|
|
var received Response
|
|
if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
|
|
t.Fatalf("Failed to parse notification JSON: %v", err)
|
|
}
|
|
|
|
// JSON unmarshal converts numbers to float64, so compare as float64
|
|
receivedID, ok := received.ID.(float64)
|
|
if !ok {
|
|
t.Fatalf("Expected numeric ID, got %T", received.ID)
|
|
}
|
|
if int(receivedID) != 42 {
|
|
t.Errorf("Expected notification ID 42, got %v", receivedID)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportParseError(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send invalid JSON
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
|
|
}
|
|
|
|
var jsonResp Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
|
|
t.Fatalf("Failed to decode response: %v", err)
|
|
}
|
|
|
|
if jsonResp.Error == nil {
|
|
t.Error("Expected JSON-RPC error for parse error")
|
|
}
|
|
if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
|
|
t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusMethodNotAllowed {
|
|
t.Errorf("Expected 405, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportOptionsRequest(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
AllowedOrigins: []string{"http://example.com"},
|
|
})
|
|
|
|
req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
t.Errorf("Expected 204, got %d", resp.StatusCode)
|
|
}
|
|
|
|
if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
|
|
t.Error("Expected CORS origin header")
|
|
}
|
|
if resp.Header.Get("Access-Control-Allow-Methods") == "" {
|
|
t.Error("Expected CORS methods header")
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
MaxRequestSize: 100, // Very small limit for testing
|
|
})
|
|
|
|
// Create a request body larger than the limit
|
|
largeBody := make([]byte, 200)
|
|
for i := range largeBody {
|
|
largeBody[i] = 'x'
|
|
}
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusRequestEntityTooLarge {
|
|
t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
MaxRequestSize: 10000, // Reasonable limit
|
|
})
|
|
|
|
// Send initialize request (should be well within limit)
|
|
initReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodInitialize,
|
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
|
}
|
|
body, _ := json.Marshal(initReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestIsLocalhostOrigin(t *testing.T) {
|
|
tests := []struct {
|
|
origin string
|
|
expected bool
|
|
}{
|
|
{"http://localhost", true},
|
|
{"http://localhost:3000", true},
|
|
{"https://localhost", true},
|
|
{"https://localhost:8443", true},
|
|
{"http://127.0.0.1", true},
|
|
{"http://127.0.0.1:8080", true},
|
|
{"https://127.0.0.1", true},
|
|
{"http://[::1]", true},
|
|
{"http://[::1]:8080", true},
|
|
{"https://[::1]", true},
|
|
{"http://example.com", false},
|
|
{"https://example.com", false},
|
|
{"http://localhost.evil.com", false},
|
|
{"http://192.168.1.1", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.origin, func(t *testing.T) {
|
|
result := isLocalhostOrigin(tt.origin)
|
|
if result != tt.expected {
|
|
t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|