This repository has been archived on 2026-03-10. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
labmcp/internal/mcp/transport_http_test.go
Torjus Håkestad 149832e4e5 security: add request body size limit to prevent DoS
Add MaxRequestSize configuration to HTTPConfig with a default of 1MB.
Use http.MaxBytesReader to enforce the limit, returning 413 Request
Entity Too Large when exceeded.

This prevents memory exhaustion attacks where an attacker sends
arbitrarily large request bodies.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:04:11 +01:00

567 lines
15 KiB
Go

package mcp
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// testHTTPTransport builds an HTTPTransport around a fresh Server and serves
// its MCP handler from an httptest.Server.
//
// A zero config.SessionTTL is replaced with 30 minutes, and an empty
// config.Endpoint falls back to "/mcp". The test server is closed and
// transport.sessions is stopped automatically through t.Cleanup.
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
	if config.SessionTTL == 0 {
		config.SessionTTL = 30 * time.Minute
	}

	// The server receives nil as its first argument (presumably the store)
	// and a logger whose output is discarded.
	quietLog := log.New(io.Discard, "", 0)
	transport := NewHTTPTransport(NewServer(nil, quietLog), config)

	path := "/mcp"
	if config.Endpoint != "" {
		path = config.Endpoint
	}

	mux := http.NewServeMux()
	mux.HandleFunc(path, transport.handleMCP)
	ts := httptest.NewServer(mux)

	t.Cleanup(func() {
		ts.Close()
		transport.sessions.Stop()
	})

	return transport, ts
}
// TestHTTPTransportInitialize posts an initialize request without a session
// and expects a 200 response carrying a 32-character Mcp-Session-Id header
// and a JSON-RPC body that reports no error.
func TestHTTPTransportInitialize(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	payload, _ := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodInitialize,
		Params:  json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
	})

	req, _ := http.NewRequest(http.MethodPost, ts.URL+"/mcp", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusOK {
		t.Errorf("Expected 200, got %d", got)
	}

	// The response must hand back the ID of the session it created.
	id := resp.Header.Get("Mcp-Session-Id")
	if id == "" {
		t.Error("Expected Mcp-Session-Id header")
	}
	if n := len(id); n != 32 {
		t.Errorf("Session ID should be 32 chars, got %d", n)
	}

	// The body must be a decodable JSON-RPC response without an error member.
	var decoded Response
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}
	if rpcErr := decoded.Error; rpcErr != nil {
		t.Errorf("Initialize failed: %v", rpcErr)
	}
}
// TestHTTPTransportSessionRequired checks that a tools/list request sent
// without an Mcp-Session-Id header is rejected with 400 Bad Request.
func TestHTTPTransportSessionRequired(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	payload, _ := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodToolsList,
	})

	// http.Post goes through http.DefaultClient and sets only Content-Type,
	// so no session header is attached.
	resp, err := http.Post(ts.URL+"/mcp", "application/json", bytes.NewReader(payload))
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusBadRequest {
		t.Errorf("Expected 400 without session, got %d", got)
	}
}
// TestHTTPTransportInvalidSession checks that a request whose
// Mcp-Session-Id does not match any session is answered with 404 Not Found.
func TestHTTPTransportInvalidSession(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	payload, _ := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodToolsList,
	})

	req, _ := http.NewRequest(http.MethodPost, ts.URL+"/mcp", bytes.NewReader(payload))
	for name, value := range map[string]string{
		"Content-Type":   "application/json",
		"Mcp-Session-Id": "invalid-session-id", // never issued by the transport
	} {
		req.Header.Set(name, value)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusNotFound {
		t.Errorf("Expected 404 for invalid session, got %d", got)
	}
}
// TestHTTPTransportValidSession checks that a tools/list request carrying the
// ID of an existing session is answered with 200 OK.
func TestHTTPTransportValidSession(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Create the session directly on the transport instead of going through
	// initialize. Fail fast on error so the test does not dereference a
	// missing session below.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	body, err := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodToolsList,
	})
	if err != nil {
		t.Fatalf("Failed to marshal request: %v", err)
	}

	req, err := http.NewRequest(http.MethodPost, ts.URL+"/mcp", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
	}
}
// TestHTTPTransportNotificationAccepted sends the initialized notification
// (a request with no ID) on an existing session and expects 202 Accepted,
// after which the session must report IsInitialized.
func TestHTTPTransportNotificationAccepted(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error so the test does not dereference a missing session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Leaving ID unset is what makes this a notification.
	body, err := json.Marshal(Request{
		JSONRPC: "2.0",
		Method:  MethodInitialized,
	})
	if err != nil {
		t.Fatalf("Failed to marshal notification: %v", err)
	}

	req, err := http.NewRequest(http.MethodPost, ts.URL+"/mcp", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusAccepted {
		t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
	}

	// Handling the notification should have flipped the session's state.
	if !session.IsInitialized() {
		t.Error("Session should be marked as initialized")
	}
}
// TestHTTPTransportDeleteSession issues a DELETE for an existing session and
// expects 204 No Content, after which the session can no longer be looked up.
func TestHTTPTransportDeleteSession(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error so the test does not dereference a missing session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	req, err := http.NewRequest(http.MethodDelete, ts.URL+"/mcp", nil)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNoContent {
		t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
	}

	// The session must be gone from the transport after the DELETE.
	if transport.sessions.Get(session.ID) != nil {
		t.Error("Session should be deleted")
	}
}
// TestHTTPTransportDeleteNonexistentSession checks that a DELETE naming an
// unknown session ID is answered with 404 Not Found.
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	target := ts.URL + "/mcp"
	req, _ := http.NewRequest(http.MethodDelete, target, nil)
	req.Header.Set("Mcp-Session-Id", "nonexistent")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusNotFound {
		t.Errorf("Expected 404 for nonexistent session, got %d", got)
	}
}
// TestHTTPTransportOriginValidation is a table-driven check of Origin header
// handling. Each case configures AllowedOrigins, sends an initialize request
// with the given Origin (omitted when empty) and asserts only whether the
// response is 403 Forbidden or not.
func TestHTTPTransportOriginValidation(t *testing.T) {
	type originCase struct {
		name           string
		allowedOrigins []string // nil means no explicit allow-list
		origin         string   // empty means the header is not sent
		expectAllowed  bool
	}

	cases := []originCase{
		{name: "no origin header", origin: "", expectAllowed: true},
		{name: "localhost allowed by default", origin: "http://localhost:3000", expectAllowed: true},
		{name: "127.0.0.1 allowed by default", origin: "http://127.0.0.1:8080", expectAllowed: true},
		{name: "external origin blocked by default", origin: "http://evil.com", expectAllowed: false},
		{name: "explicit allow", allowedOrigins: []string{"http://example.com"}, origin: "http://example.com", expectAllowed: true},
		{name: "explicit allow wildcard", allowedOrigins: []string{"*"}, origin: "http://anything.com", expectAllowed: true},
		{name: "not in allowed list", allowedOrigins: []string{"http://example.com"}, origin: "http://other.com", expectAllowed: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, ts := testHTTPTransport(t, HTTPConfig{AllowedOrigins: tc.allowedOrigins})

			// initialize is used because it does not need an existing session.
			payload, _ := json.Marshal(Request{
				JSONRPC: "2.0",
				ID:      1,
				Method:  MethodInitialize,
				Params:  json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
			})

			req, _ := http.NewRequest(http.MethodPost, ts.URL+"/mcp", bytes.NewReader(payload))
			req.Header.Set("Content-Type", "application/json")
			if tc.origin != "" {
				req.Header.Set("Origin", tc.origin)
			}

			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				t.Fatalf("Request failed: %v", err)
			}
			defer resp.Body.Close()

			forbidden := resp.StatusCode == http.StatusForbidden
			switch {
			case tc.expectAllowed && forbidden:
				t.Error("Expected request to be allowed but was forbidden")
			case !tc.expectAllowed && !forbidden:
				t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
			}
		})
	}
}
// TestHTTPTransportSSERequiresAcceptHeader checks that a GET on a valid
// session without "Accept: text/event-stream" is rejected with
// 406 Not Acceptable.
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error so the test does not dereference a missing session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// The Accept header is deliberately left unset.
	req, err := http.NewRequest(http.MethodGet, ts.URL+"/mcp", nil)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Mcp-Session-Id", session.ID)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusNotAcceptable {
		t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
	}
}
// TestHTTPTransportSSEStream opens the GET event stream for a session, pushes
// one message through session.SendNotification and checks that it arrives as
// an SSE "data: " line whose JSON payload carries the same ID.
func TestHTTPTransportSSEStream(t *testing.T) {
	transport, ts := testHTTPTransport(t, HTTPConfig{})

	// Fail fast on error so the test does not dereference a missing session.
	session, err := transport.sessions.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	req, err := http.NewRequest(http.MethodGet, ts.URL+"/mcp", nil)
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Mcp-Session-Id", session.ID)
	req.Header.Set("Accept", "text/event-stream")

	// Client.Timeout also covers reading the body, so a server that never
	// writes an event fails the test instead of blocking it forever.
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Expected 200, got %d", resp.StatusCode)
	}
	contentType := resp.Header.Get("Content-Type")
	if contentType != "text/event-stream" {
		t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
	}

	// Queue a message for delivery on the open stream.
	notification := &Response{
		JSONRPC: "2.0",
		ID:      42,
		Result:  map[string]string{"test": "data"},
	}
	session.SendNotification(notification)

	// A single Read may return only part of the event, so accumulate chunks
	// until the end of the first line has been seen.
	var raw strings.Builder
	chunk := make([]byte, 1024)
	for !strings.Contains(raw.String(), "\n") {
		n, err := resp.Body.Read(chunk)
		raw.Write(chunk[:n])
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("Failed to read SSE event: %v", err)
		}
	}

	// Only the first line is inspected; it must be the event's data field.
	line := strings.SplitN(raw.String(), "\n", 2)[0]
	if !strings.HasPrefix(line, "data: ") {
		t.Errorf("Expected SSE data event, got: %s", line)
	}

	var received Response
	if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &received); err != nil {
		t.Fatalf("Failed to parse notification JSON: %v", err)
	}

	// encoding/json decodes numbers held in an interface value as float64.
	receivedID, ok := received.ID.(float64)
	if !ok {
		t.Fatalf("Expected numeric ID, got %T", received.ID)
	}
	if int(receivedID) != 42 {
		t.Errorf("Expected notification ID 42, got %v", receivedID)
	}
}
// TestHTTPTransportParseError posts a body that is not JSON and expects the
// failure to be reported in-band: HTTP 200 with a JSON-RPC error whose code
// is ParseError.
func TestHTTPTransportParseError(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	req, err := http.NewRequest(http.MethodPost, ts.URL+"/mcp", strings.NewReader("not json"))
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
	}

	var jsonResp Response
	if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	// Exactly one of these can fire: either the error is missing entirely or
	// it is present with the wrong code.
	switch {
	case jsonResp.Error == nil:
		t.Error("Expected JSON-RPC error for parse error")
	case jsonResp.Error.Code != ParseError:
		t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
	}
}
// TestHTTPTransportMethodNotAllowed checks that a PUT to the MCP endpoint is
// answered with 405 Method Not Allowed.
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{})

	target := ts.URL + "/mcp"
	req, _ := http.NewRequest(http.MethodPut, target, nil)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusMethodNotAllowed {
		t.Errorf("Expected 405, got %d", got)
	}
}
// TestHTTPTransportOptionsRequest sends an OPTIONS request from an allowed
// origin and expects 204 No Content, an Access-Control-Allow-Origin header
// echoing that origin and a non-empty Access-Control-Allow-Methods header.
func TestHTTPTransportOptionsRequest(t *testing.T) {
	const origin = "http://example.com"

	_, ts := testHTTPTransport(t, HTTPConfig{AllowedOrigins: []string{origin}})

	req, _ := http.NewRequest(http.MethodOptions, ts.URL+"/mcp", nil)
	req.Header.Set("Origin", origin)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusNoContent {
		t.Errorf("Expected 204, got %d", got)
	}

	cors := resp.Header
	if cors.Get("Access-Control-Allow-Origin") != origin {
		t.Error("Expected CORS origin header")
	}
	if cors.Get("Access-Control-Allow-Methods") == "" {
		t.Error("Expected CORS methods header")
	}
}
// TestHTTPTransportRequestBodyTooLarge configures a 100-byte MaxRequestSize,
// posts a 200-byte body and expects 413 Request Entity Too Large.
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{
		MaxRequestSize: 100, // Very small limit for testing
	})

	// The content is irrelevant; only the size, twice the limit, matters.
	largeBody := bytes.Repeat([]byte("x"), 200)

	req, err := http.NewRequest(http.MethodPost, ts.URL+"/mcp", bytes.NewReader(largeBody))
	if err != nil {
		t.Fatalf("Failed to build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusRequestEntityTooLarge {
		t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
	}
}
// TestHTTPTransportRequestBodyWithinLimit configures a 10000-byte
// MaxRequestSize and checks that an ordinary initialize request, which is far
// smaller, is still answered with 200 OK.
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
	_, ts := testHTTPTransport(t, HTTPConfig{
		MaxRequestSize: 10000, // Reasonable limit
	})

	payload, _ := json.Marshal(Request{
		JSONRPC: "2.0",
		ID:      1,
		Method:  MethodInitialize,
		Params:  json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
	})

	// http.Post goes through http.DefaultClient and sets only Content-Type.
	resp, err := http.Post(ts.URL+"/mcp", "application/json", bytes.NewReader(payload))
	if err != nil {
		t.Fatalf("Request failed: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusOK {
		t.Errorf("Expected 200 for valid request within limit, got %d", got)
	}
}
// TestIsLocalhostOrigin is a table-driven check of isLocalhostOrigin. It
// covers http and https origins for localhost, 127.0.0.1 and [::1], with and
// without ports, plus several origins that must not be treated as local.
func TestIsLocalhostOrigin(t *testing.T) {
	type originCase struct {
		origin string
		want   bool
	}

	cases := []originCase{
		// Loopback names and addresses.
		{origin: "http://localhost", want: true},
		{origin: "http://localhost:3000", want: true},
		{origin: "https://localhost", want: true},
		{origin: "https://localhost:8443", want: true},
		{origin: "http://127.0.0.1", want: true},
		{origin: "http://127.0.0.1:8080", want: true},
		{origin: "https://127.0.0.1", want: true},
		{origin: "http://[::1]", want: true},
		{origin: "http://[::1]:8080", want: true},
		{origin: "https://[::1]", want: true},
		// Anything else, including a host that merely starts with "localhost".
		{origin: "http://example.com", want: false},
		{origin: "https://example.com", want: false},
		{origin: "http://localhost.evil.com", want: false},
		{origin: "http://192.168.1.1", want: false},
	}

	for _, tc := range cases {
		t.Run(tc.origin, func(t *testing.T) {
			if got := isLocalhostOrigin(tc.origin); got != tc.want {
				t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tc.origin, got, tc.want)
			}
		})
	}
}