From 149832e4e508d001f1e7d205467ba7b60d4a9d9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Tue, 3 Feb 2026 22:04:11 +0100 Subject: [PATCH] security: add request body size limit to prevent DoS Add MaxRequestSize configuration to HTTPConfig with a default of 1MB. Use http.MaxBytesReader to enforce the limit, returning 413 Request Entity Too Large when exceeded. This prevents memory exhaustion attacks where an attacker sends arbitrarily large request bodies. Co-Authored-By: Claude Opus 4.5 --- .gitignore | 1 + internal/mcp/transport_http.go | 17 +++++++++ internal/mcp/transport_http_test.go | 53 +++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/.gitignore b/.gitignore index b2be92b..22f5618 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ result +*.db diff --git a/internal/mcp/transport_http.go b/internal/mcp/transport_http.go index 5d21ac3..834e4fd 100644 --- a/internal/mcp/transport_http.go +++ b/internal/mcp/transport_http.go @@ -19,8 +19,14 @@ type HTTPConfig struct { SessionTTL time.Duration // Session TTL (default: 30 minutes) TLSCertFile string // TLS certificate file (optional) TLSKeyFile string // TLS key file (optional) + MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB) } +const ( + // DefaultMaxRequestSize is the default maximum request body size (1MB). + DefaultMaxRequestSize = 1 << 20 // 1MB +) + // HTTPTransport implements the MCP Streamable HTTP transport. type HTTPTransport struct { server *Server @@ -39,6 +45,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { if config.SessionTTL == 0 { config.SessionTTL = 30 * time.Minute } + if config.MaxRequestSize == 0 { + config.MaxRequestSize = DefaultMaxRequestSize + } return &HTTPTransport{ server: server, @@ -113,9 +122,17 @@ func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) { // handlePost handles JSON-RPC requests. func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) { + // Limit request body size to prevent memory exhaustion attacks + r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize) + // Read request body body, err := io.ReadAll(r.Body) if err != nil { + // Check if this is a size limit error + if err.Error() == "http: request body too large" { + http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) + return + } http.Error(w, "Failed to read request body", http.StatusBadRequest) return } diff --git a/internal/mcp/transport_http_test.go b/internal/mcp/transport_http_test.go index 06f4655..f9354cf 100644 --- a/internal/mcp/transport_http_test.go +++ b/internal/mcp/transport_http_test.go @@ -481,6 +481,59 @@ func TestHTTPTransportOptionsRequest(t *testing.T) { } } +func TestHTTPTransportRequestBodyTooLarge(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + MaxRequestSize: 100, // Very small limit for testing + }) + + // Create a request body larger than the limit + largeBody := make([]byte, 200) + for i := range largeBody { + largeBody[i] = 'x' + } + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusRequestEntityTooLarge { + t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + MaxRequestSize: 10000, // Reasonable limit + }) + + // Send initialize request (should be well within limit) + initReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodInitialize, + Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`), + } + body, _ := json.Marshal(initReq) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode) + } +} + func TestIsLocalhostOrigin(t *testing.T) { tests := []struct { origin string