feature/streamable-http-transport #1
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
result
|
result
|
||||||
|
*.db
|
||||||
|
|||||||
@@ -19,8 +19,14 @@ type HTTPConfig struct {
|
|||||||
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
||||||
TLSCertFile string // TLS certificate file (optional)
|
TLSCertFile string // TLS certificate file (optional)
|
||||||
TLSKeyFile string // TLS key file (optional)
|
TLSKeyFile string // TLS key file (optional)
|
||||||
|
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultMaxRequestSize is the default maximum request body size (1MB).
|
||||||
|
DefaultMaxRequestSize = 1 << 20 // 1MB
|
||||||
|
)
|
||||||
|
|
||||||
// HTTPTransport implements the MCP Streamable HTTP transport.
|
// HTTPTransport implements the MCP Streamable HTTP transport.
|
||||||
type HTTPTransport struct {
|
type HTTPTransport struct {
|
||||||
server *Server
|
server *Server
|
||||||
@@ -39,6 +45,9 @@ func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
|||||||
if config.SessionTTL == 0 {
|
if config.SessionTTL == 0 {
|
||||||
config.SessionTTL = 30 * time.Minute
|
config.SessionTTL = 30 * time.Minute
|
||||||
}
|
}
|
||||||
|
if config.MaxRequestSize == 0 {
|
||||||
|
config.MaxRequestSize = DefaultMaxRequestSize
|
||||||
|
}
|
||||||
|
|
||||||
return &HTTPTransport{
|
return &HTTPTransport{
|
||||||
server: server,
|
server: server,
|
||||||
@@ -113,9 +122,17 @@ func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// handlePost handles JSON-RPC requests.
|
// handlePost handles JSON-RPC requests.
|
||||||
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
|
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Limit request body size to prevent memory exhaustion attacks
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize)
|
||||||
|
|
||||||
// Read request body
|
// Read request body
|
||||||
body, err := io.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Check if this is a size limit error
|
||||||
|
if err.Error() == "http: request body too large" {
|
||||||
|
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
||||||
|
return
|
||||||
|
}
|
||||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -481,6 +481,59 @@ func TestHTTPTransportOptionsRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
MaxRequestSize: 100, // Very small limit for testing
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a request body larger than the limit
|
||||||
|
largeBody := make([]byte, 200)
|
||||||
|
for i := range largeBody {
|
||||||
|
largeBody[i] = 'x'
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusRequestEntityTooLarge {
|
||||||
|
t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
MaxRequestSize: 10000, // Reasonable limit
|
||||||
|
})
|
||||||
|
|
||||||
|
// Send initialize request (should be well within limit)
|
||||||
|
initReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodInitialize,
|
||||||
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(initReq)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsLocalhostOrigin(t *testing.T) {
|
func TestIsLocalhostOrigin(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
origin string
|
origin string
|
||||||
|
|||||||
Reference in New Issue
Block a user