Add configurable MaxSessions limit (default: 10000) to SessionStore. When the limit is reached, new session creation returns ErrTooManySessions and HTTP transport responds with 503 Service Unavailable. This prevents attackers from exhausting server memory by creating unlimited sessions through repeated initialize requests. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
687 lines
19 KiB
Go
687 lines
19 KiB
Go
package mcp
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// testHTTPTransport creates a transport with a test server
|
|
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
|
|
// Use a mock store
|
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
|
|
|
if config.SessionTTL == 0 {
|
|
config.SessionTTL = 30 * time.Minute
|
|
}
|
|
transport := NewHTTPTransport(server, config)
|
|
|
|
// Create test server
|
|
mux := http.NewServeMux()
|
|
endpoint := config.Endpoint
|
|
if endpoint == "" {
|
|
endpoint = "/mcp"
|
|
}
|
|
mux.HandleFunc(endpoint, transport.handleMCP)
|
|
|
|
ts := httptest.NewServer(mux)
|
|
t.Cleanup(func() {
|
|
ts.Close()
|
|
transport.sessions.Stop()
|
|
})
|
|
|
|
return transport, ts
|
|
}
|
|
|
|
func TestHTTPTransportInitialize(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send initialize request
|
|
initReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodInitialize,
|
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
|
}
|
|
body, _ := json.Marshal(initReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Check session ID header
|
|
sessionID := resp.Header.Get("Mcp-Session-Id")
|
|
if sessionID == "" {
|
|
t.Error("Expected Mcp-Session-Id header")
|
|
}
|
|
if len(sessionID) != 32 {
|
|
t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
|
|
}
|
|
|
|
// Check response body
|
|
var initResp Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
|
|
t.Fatalf("Failed to decode response: %v", err)
|
|
}
|
|
|
|
if initResp.Error != nil {
|
|
t.Errorf("Initialize failed: %v", initResp.Error)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportSessionRequired(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send tools/list without session
|
|
listReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodToolsList,
|
|
}
|
|
body, _ := json.Marshal(listReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusBadRequest {
|
|
t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportInvalidSession(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send request with invalid session
|
|
listReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodToolsList,
|
|
}
|
|
body, _ := json.Marshal(listReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Mcp-Session-Id", "invalid-session-id")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportValidSession(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Create session manually
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Send tools/list with valid session
|
|
listReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodToolsList,
|
|
}
|
|
body, _ := json.Marshal(listReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportNotificationAccepted(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Send notification (no ID)
|
|
notification := Request{
|
|
JSONRPC: "2.0",
|
|
Method: MethodInitialized,
|
|
}
|
|
body, _ := json.Marshal(notification)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusAccepted {
|
|
t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Verify session is marked as initialized
|
|
if !session.IsInitialized() {
|
|
t.Error("Session should be marked as initialized")
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportDeleteSession(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Delete session
|
|
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
|
|
}
|
|
|
|
// Verify session is gone
|
|
if transport.sessions.Get(session.ID) != nil {
|
|
t.Error("Session should be deleted")
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", "nonexistent")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportOriginValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
allowedOrigins []string
|
|
origin string
|
|
expectAllowed bool
|
|
}{
|
|
{
|
|
name: "no origin header",
|
|
allowedOrigins: nil,
|
|
origin: "",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "localhost allowed by default",
|
|
allowedOrigins: nil,
|
|
origin: "http://localhost:3000",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "127.0.0.1 allowed by default",
|
|
allowedOrigins: nil,
|
|
origin: "http://127.0.0.1:8080",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "external origin blocked by default",
|
|
allowedOrigins: nil,
|
|
origin: "http://evil.com",
|
|
expectAllowed: false,
|
|
},
|
|
{
|
|
name: "explicit allow",
|
|
allowedOrigins: []string{"http://example.com"},
|
|
origin: "http://example.com",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "explicit allow wildcard",
|
|
allowedOrigins: []string{"*"},
|
|
origin: "http://anything.com",
|
|
expectAllowed: true,
|
|
},
|
|
{
|
|
name: "not in allowed list",
|
|
allowedOrigins: []string{"http://example.com"},
|
|
origin: "http://other.com",
|
|
expectAllowed: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
AllowedOrigins: tt.allowedOrigins,
|
|
})
|
|
|
|
// Use initialize since it doesn't require a session
|
|
initReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodInitialize,
|
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
|
}
|
|
body, _ := json.Marshal(initReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if tt.origin != "" {
|
|
req.Header.Set("Origin", tt.origin)
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
|
|
t.Error("Expected request to be allowed but was forbidden")
|
|
}
|
|
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
|
|
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// GET without Accept: text/event-stream
|
|
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNotAcceptable {
|
|
t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportSSEStream(t *testing.T) {
|
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
session, _ := transport.sessions.Create()
|
|
|
|
// Start SSE stream in goroutine
|
|
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("Expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if contentType != "text/event-stream" {
|
|
t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
|
|
}
|
|
|
|
// Send a notification
|
|
notification := &Response{
|
|
JSONRPC: "2.0",
|
|
ID: 42,
|
|
Result: map[string]string{"test": "data"},
|
|
}
|
|
session.SendNotification(notification)
|
|
|
|
// Read the SSE event
|
|
buf := make([]byte, 1024)
|
|
n, err := resp.Body.Read(buf)
|
|
if err != nil && err != io.EOF {
|
|
t.Fatalf("Failed to read SSE event: %v", err)
|
|
}
|
|
|
|
data := string(buf[:n])
|
|
if !strings.HasPrefix(data, "data: ") {
|
|
t.Errorf("Expected SSE data event, got: %s", data)
|
|
}
|
|
|
|
// Parse the JSON from the SSE event
|
|
jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ")
|
|
var received Response
|
|
if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
|
|
t.Fatalf("Failed to parse notification JSON: %v", err)
|
|
}
|
|
|
|
// JSON unmarshal converts numbers to float64, so compare as float64
|
|
receivedID, ok := received.ID.(float64)
|
|
if !ok {
|
|
t.Fatalf("Expected numeric ID, got %T", received.ID)
|
|
}
|
|
if int(receivedID) != 42 {
|
|
t.Errorf("Expected notification ID 42, got %v", receivedID)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportParseError(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
// Send invalid JSON
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
|
|
}
|
|
|
|
var jsonResp Response
|
|
if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
|
|
t.Fatalf("Failed to decode response: %v", err)
|
|
}
|
|
|
|
if jsonResp.Error == nil {
|
|
t.Error("Expected JSON-RPC error for parse error")
|
|
}
|
|
if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
|
|
t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
|
|
|
req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusMethodNotAllowed {
|
|
t.Errorf("Expected 405, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportOptionsRequest(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
AllowedOrigins: []string{"http://example.com"},
|
|
})
|
|
|
|
req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusNoContent {
|
|
t.Errorf("Expected 204, got %d", resp.StatusCode)
|
|
}
|
|
|
|
if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
|
|
t.Error("Expected CORS origin header")
|
|
}
|
|
if resp.Header.Get("Access-Control-Allow-Methods") == "" {
|
|
t.Error("Expected CORS methods header")
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportDefaultConfig(t *testing.T) {
|
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
|
transport := NewHTTPTransport(server, HTTPConfig{})
|
|
|
|
// Verify defaults are applied
|
|
if transport.config.Address != "127.0.0.1:8080" {
|
|
t.Errorf("Expected default address 127.0.0.1:8080, got %s", transport.config.Address)
|
|
}
|
|
if transport.config.Endpoint != "/mcp" {
|
|
t.Errorf("Expected default endpoint /mcp, got %s", transport.config.Endpoint)
|
|
}
|
|
if transport.config.SessionTTL != 30*time.Minute {
|
|
t.Errorf("Expected default session TTL 30m, got %v", transport.config.SessionTTL)
|
|
}
|
|
if transport.config.MaxRequestSize != DefaultMaxRequestSize {
|
|
t.Errorf("Expected default max request size %d, got %d", DefaultMaxRequestSize, transport.config.MaxRequestSize)
|
|
}
|
|
if transport.config.ReadTimeout != DefaultReadTimeout {
|
|
t.Errorf("Expected default read timeout %v, got %v", DefaultReadTimeout, transport.config.ReadTimeout)
|
|
}
|
|
if transport.config.WriteTimeout != DefaultWriteTimeout {
|
|
t.Errorf("Expected default write timeout %v, got %v", DefaultWriteTimeout, transport.config.WriteTimeout)
|
|
}
|
|
if transport.config.IdleTimeout != DefaultIdleTimeout {
|
|
t.Errorf("Expected default idle timeout %v, got %v", DefaultIdleTimeout, transport.config.IdleTimeout)
|
|
}
|
|
if transport.config.ReadHeaderTimeout != DefaultReadHeaderTimeout {
|
|
t.Errorf("Expected default read header timeout %v, got %v", DefaultReadHeaderTimeout, transport.config.ReadHeaderTimeout)
|
|
}
|
|
|
|
transport.sessions.Stop()
|
|
}
|
|
|
|
func TestHTTPTransportCustomConfig(t *testing.T) {
|
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
|
config := HTTPConfig{
|
|
Address: "0.0.0.0:9090",
|
|
Endpoint: "/api/mcp",
|
|
SessionTTL: 1 * time.Hour,
|
|
MaxRequestSize: 5 << 20, // 5MB
|
|
ReadTimeout: 60 * time.Second,
|
|
WriteTimeout: 60 * time.Second,
|
|
IdleTimeout: 300 * time.Second,
|
|
ReadHeaderTimeout: 20 * time.Second,
|
|
}
|
|
transport := NewHTTPTransport(server, config)
|
|
|
|
// Verify custom values are preserved
|
|
if transport.config.Address != "0.0.0.0:9090" {
|
|
t.Errorf("Expected custom address, got %s", transport.config.Address)
|
|
}
|
|
if transport.config.Endpoint != "/api/mcp" {
|
|
t.Errorf("Expected custom endpoint, got %s", transport.config.Endpoint)
|
|
}
|
|
if transport.config.SessionTTL != 1*time.Hour {
|
|
t.Errorf("Expected custom session TTL, got %v", transport.config.SessionTTL)
|
|
}
|
|
if transport.config.MaxRequestSize != 5<<20 {
|
|
t.Errorf("Expected custom max request size, got %d", transport.config.MaxRequestSize)
|
|
}
|
|
if transport.config.ReadTimeout != 60*time.Second {
|
|
t.Errorf("Expected custom read timeout, got %v", transport.config.ReadTimeout)
|
|
}
|
|
if transport.config.WriteTimeout != 60*time.Second {
|
|
t.Errorf("Expected custom write timeout, got %v", transport.config.WriteTimeout)
|
|
}
|
|
if transport.config.IdleTimeout != 300*time.Second {
|
|
t.Errorf("Expected custom idle timeout, got %v", transport.config.IdleTimeout)
|
|
}
|
|
if transport.config.ReadHeaderTimeout != 20*time.Second {
|
|
t.Errorf("Expected custom read header timeout, got %v", transport.config.ReadHeaderTimeout)
|
|
}
|
|
|
|
transport.sessions.Stop()
|
|
}
|
|
|
|
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
MaxRequestSize: 100, // Very small limit for testing
|
|
})
|
|
|
|
// Create a request body larger than the limit
|
|
largeBody := make([]byte, 200)
|
|
for i := range largeBody {
|
|
largeBody[i] = 'x'
|
|
}
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusRequestEntityTooLarge {
|
|
t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportSessionLimitReached(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
MaxSessions: 2, // Very low limit for testing
|
|
})
|
|
|
|
initReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodInitialize,
|
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
|
}
|
|
body, _ := json.Marshal(initReq)
|
|
|
|
// Create sessions up to the limit
|
|
for i := 0; i < 2; i++ {
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request %d failed: %v", i, err)
|
|
}
|
|
resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
// Third request should fail with 503
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
|
t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
|
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
|
MaxRequestSize: 10000, // Reasonable limit
|
|
})
|
|
|
|
// Send initialize request (should be well within limit)
|
|
initReq := Request{
|
|
JSONRPC: "2.0",
|
|
ID: 1,
|
|
Method: MethodInitialize,
|
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
|
}
|
|
body, _ := json.Marshal(initReq)
|
|
|
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("Request failed: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestIsLocalhostOrigin(t *testing.T) {
|
|
tests := []struct {
|
|
origin string
|
|
expected bool
|
|
}{
|
|
{"http://localhost", true},
|
|
{"http://localhost:3000", true},
|
|
{"https://localhost", true},
|
|
{"https://localhost:8443", true},
|
|
{"http://127.0.0.1", true},
|
|
{"http://127.0.0.1:8080", true},
|
|
{"https://127.0.0.1", true},
|
|
{"http://[::1]", true},
|
|
{"http://[::1]:8080", true},
|
|
{"https://[::1]", true},
|
|
{"http://example.com", false},
|
|
{"https://example.com", false},
|
|
{"http://localhost.evil.com", false},
|
|
{"http://192.168.1.1", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.origin, func(t *testing.T) {
|
|
result := isLocalhostOrigin(tt.origin)
|
|
if result != tt.expected {
|
|
t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|