diff --git a/.gitignore b/.gitignore index b2be92b..22f5618 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ result +*.db diff --git a/internal/mcp/transport_http.go b/internal/mcp/transport_http.go index 5d21ac3..834e4fd 100644 --- a/internal/mcp/transport_http.go +++ b/internal/mcp/transport_http.go @@ -19,8 +19,14 @@ type HTTPConfig struct { SessionTTL time.Duration // Session TTL (default: 30 minutes) TLSCertFile string // TLS certificate file (optional) TLSKeyFile string // TLS key file (optional) + MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB) } +const ( + // DefaultMaxRequestSize is the default maximum request body size (1MB). + DefaultMaxRequestSize = 1 << 20 // 1MB +) + // HTTPTransport implements the MCP Streamable HTTP transport. type HTTPTransport struct { server *Server @@ -39,6 +45,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { if config.SessionTTL == 0 { config.SessionTTL = 30 * time.Minute } + if config.MaxRequestSize == 0 { + config.MaxRequestSize = DefaultMaxRequestSize + } return &HTTPTransport{ server: server, @@ -113,9 +122,17 @@ func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) { // handlePost handles JSON-RPC requests. func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) { + // Limit request body size to prevent memory exhaustion attacks + r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize) + // Read request body body, err := io.ReadAll(r.Body) if err != nil { + // Check if this is a size limit error + if err.Error() == "http: request body too large" { + http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge) + return + } http.Error(w, "Failed to read request body", http.StatusBadRequest) return } diff --git a/internal/mcp/transport_http_test.go b/internal/mcp/transport_http_test.go index 06f4655..f9354cf 100644 --- a/internal/mcp/transport_http_test.go +++ b/internal/mcp/transport_http_test.go @@ -481,6 +481,59 @@ func TestHTTPTransportOptionsRequest(t *testing.T) { } } +func TestHTTPTransportRequestBodyTooLarge(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + MaxRequestSize: 100, // Very small limit for testing + }) + + // Create a request body larger than the limit + largeBody := make([]byte, 200) + for i := range largeBody { + largeBody[i] = 'x' + } + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusRequestEntityTooLarge { + t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + MaxRequestSize: 10000, // Reasonable limit + }) + + // Send initialize request (should be well within limit) + initReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodInitialize, + Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`), + } + body, _ := json.Marshal(initReq) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode) + } +} + func TestIsLocalhostOrigin(t *testing.T) { tests := []struct { origin string