feat: add Streamable HTTP transport support
Add support for running the MCP server over HTTP with Server-Sent Events (SSE) using the MCP Streamable HTTP specification, alongside the existing STDIO transport. New features: - Transport abstraction with Transport interface - HTTP transport with session management - SSE support for server-initiated notifications - CORS security with configurable allowed origins - Optional TLS support - CLI flags for HTTP configuration (--transport, --http-address, etc.) - NixOS module options for HTTP transport The HTTP transport implements: - POST /mcp: JSON-RPC requests with session management - GET /mcp: SSE stream for server notifications - DELETE /mcp: Session termination - Origin validation (localhost-only by default) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -11,7 +10,7 @@ import (
|
||||
"git.t-juice.club/torjus/labmcp/internal/database"
|
||||
)
|
||||
|
||||
// Server is an MCP server that handles JSON-RPC requests over stdio.
|
||||
// Server is an MCP server that handles JSON-RPC requests.
|
||||
type Server struct {
|
||||
store database.Store
|
||||
tools map[string]ToolHandler
|
||||
@@ -41,53 +40,34 @@ func (s *Server) registerTools() {
|
||||
// Tools will be implemented in handlers.go
|
||||
}
|
||||
|
||||
// Run starts the server, reading from r and writing to w.
|
||||
// Run starts the server using STDIO transport (backward compatibility).
|
||||
func (s *Server) Run(ctx context.Context, r io.Reader, w io.Writer) error {
|
||||
scanner := bufio.NewScanner(r)
|
||||
encoder := json.NewEncoder(w)
|
||||
transport := NewStdioTransport(s, r, w)
|
||||
return transport.Run(ctx)
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
s.logger.Printf("Failed to parse request: %v", err)
|
||||
resp := Response{
|
||||
JSONRPC: "2.0",
|
||||
Error: &Error{
|
||||
Code: ParseError,
|
||||
Message: "Parse error",
|
||||
Data: err.Error(),
|
||||
},
|
||||
}
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
return fmt.Errorf("failed to write response: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
resp := s.handleRequest(ctx, &req)
|
||||
if resp != nil {
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
return fmt.Errorf("failed to write response: %w", err)
|
||||
}
|
||||
}
|
||||
// HandleMessage parses a JSON-RPC message and returns the response.
|
||||
// Returns (nil, nil) for notifications that don't require a response.
|
||||
func (s *Server) HandleMessage(ctx context.Context, data []byte) (*Response, error) {
|
||||
var req Request
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
return &Response{
|
||||
JSONRPC: "2.0",
|
||||
Error: &Error{
|
||||
Code: ParseError,
|
||||
Message: "Parse error",
|
||||
Data: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("scanner error: %w", err)
|
||||
}
|
||||
return s.HandleRequest(ctx, &req), nil
|
||||
}
|
||||
|
||||
return nil
|
||||
// HandleRequest processes a single request and returns a response.
|
||||
// Returns nil for notifications that don't require a response.
|
||||
func (s *Server) HandleRequest(ctx context.Context, req *Request) *Response {
|
||||
return s.handleRequest(ctx, req)
|
||||
}
|
||||
|
||||
// handleRequest processes a single request and returns a response.
|
||||
|
||||
207
internal/mcp/session.go
Normal file
207
internal/mcp/session.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Session represents an MCP client session.
|
||||
type Session struct {
|
||||
ID string
|
||||
CreatedAt time.Time
|
||||
LastActivity time.Time
|
||||
Initialized bool
|
||||
|
||||
// notifications is a channel for server-initiated notifications.
|
||||
notifications chan *Response
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSession creates a new session with a cryptographically secure random ID.
|
||||
func NewSession() (*Session, error) {
|
||||
id, err := generateSessionID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
return &Session{
|
||||
ID: id,
|
||||
CreatedAt: now,
|
||||
LastActivity: now,
|
||||
notifications: make(chan *Response, 100),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Touch updates the session's last activity time.
|
||||
func (s *Session) Touch() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.LastActivity = time.Now()
|
||||
}
|
||||
|
||||
// SetInitialized marks the session as initialized.
|
||||
func (s *Session) SetInitialized() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.Initialized = true
|
||||
}
|
||||
|
||||
// IsInitialized returns whether the session has been initialized.
|
||||
func (s *Session) IsInitialized() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.Initialized
|
||||
}
|
||||
|
||||
// Notifications returns the channel for server-initiated notifications.
|
||||
func (s *Session) Notifications() <-chan *Response {
|
||||
return s.notifications
|
||||
}
|
||||
|
||||
// SendNotification sends a notification to the session.
|
||||
// Returns false if the channel is full.
|
||||
func (s *Session) SendNotification(notification *Response) bool {
|
||||
select {
|
||||
case s.notifications <- notification:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the session's notification channel.
|
||||
func (s *Session) Close() {
|
||||
close(s.notifications)
|
||||
}
|
||||
|
||||
// SessionStore manages active sessions with TTL-based cleanup.
|
||||
type SessionStore struct {
|
||||
sessions map[string]*Session
|
||||
ttl time.Duration
|
||||
mu sync.RWMutex
|
||||
stopClean chan struct{}
|
||||
cleanDone chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionStore creates a new session store with the given TTL.
|
||||
func NewSessionStore(ttl time.Duration) *SessionStore {
|
||||
s := &SessionStore{
|
||||
sessions: make(map[string]*Session),
|
||||
ttl: ttl,
|
||||
stopClean: make(chan struct{}),
|
||||
cleanDone: make(chan struct{}),
|
||||
}
|
||||
go s.cleanupLoop()
|
||||
return s
|
||||
}
|
||||
|
||||
// Create creates a new session and adds it to the store.
|
||||
func (s *SessionStore) Create() (*Session, error) {
|
||||
session, err := NewSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.sessions[session.ID] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Get retrieves a session by ID. Returns nil if not found or expired.
|
||||
func (s *SessionStore) Get(id string) *Session {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
session, ok := s.sessions[id]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
session.mu.RLock()
|
||||
expired := time.Since(session.LastActivity) > s.ttl
|
||||
session.mu.RUnlock()
|
||||
|
||||
if expired {
|
||||
return nil
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// Delete removes a session from the store.
|
||||
func (s *SessionStore) Delete(id string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
session, ok := s.sessions[id]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
session.Close()
|
||||
delete(s.sessions, id)
|
||||
return true
|
||||
}
|
||||
|
||||
// Count returns the number of active sessions.
|
||||
func (s *SessionStore) Count() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.sessions)
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine and waits for it to finish.
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopClean)
|
||||
<-s.cleanDone
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes expired sessions.
|
||||
func (s *SessionStore) cleanupLoop() {
|
||||
defer close(s.cleanDone)
|
||||
|
||||
ticker := time.NewTicker(s.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopClean:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions.
|
||||
func (s *SessionStore) cleanup() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range s.sessions {
|
||||
session.mu.RLock()
|
||||
expired := now.Sub(session.LastActivity) > s.ttl
|
||||
session.mu.RUnlock()
|
||||
|
||||
if expired {
|
||||
session.Close()
|
||||
delete(s.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateSessionID generates a cryptographically secure random session ID.
|
||||
func generateSessionID() (string, error) {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
267
internal/mcp/session_test.go
Normal file
267
internal/mcp/session_test.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewSession(t *testing.T) {
|
||||
session, err := NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
if session.ID == "" {
|
||||
t.Error("Session ID should not be empty")
|
||||
}
|
||||
if len(session.ID) != 32 {
|
||||
t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID))
|
||||
}
|
||||
if session.Initialized {
|
||||
t.Error("New session should not be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionTouch(t *testing.T) {
|
||||
session, _ := NewSession()
|
||||
originalActivity := session.LastActivity
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
session.Touch()
|
||||
|
||||
if !session.LastActivity.After(originalActivity) {
|
||||
t.Error("Touch should update LastActivity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionInitialized(t *testing.T) {
|
||||
session, _ := NewSession()
|
||||
|
||||
if session.IsInitialized() {
|
||||
t.Error("New session should not be initialized")
|
||||
}
|
||||
|
||||
session.SetInitialized()
|
||||
|
||||
if !session.IsInitialized() {
|
||||
t.Error("Session should be initialized after SetInitialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionNotifications(t *testing.T) {
|
||||
session, _ := NewSession()
|
||||
defer session.Close()
|
||||
|
||||
notification := &Response{JSONRPC: "2.0", ID: 1}
|
||||
|
||||
if !session.SendNotification(notification) {
|
||||
t.Error("SendNotification should return true on success")
|
||||
}
|
||||
|
||||
select {
|
||||
case received := <-session.Notifications():
|
||||
if received.ID != notification.ID {
|
||||
t.Error("Received notification should match sent")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Should receive notification")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreCreate(t *testing.T) {
|
||||
store := NewSessionStore(30 * time.Minute)
|
||||
defer store.Stop()
|
||||
|
||||
session, err := store.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
if store.Count() != 1 {
|
||||
t.Errorf("Store should have 1 session, got %d", store.Count())
|
||||
}
|
||||
|
||||
// Verify we can retrieve it
|
||||
retrieved := store.Get(session.ID)
|
||||
if retrieved == nil {
|
||||
t.Error("Should be able to retrieve created session")
|
||||
}
|
||||
if retrieved.ID != session.ID {
|
||||
t.Error("Retrieved session ID should match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreGet(t *testing.T) {
|
||||
store := NewSessionStore(30 * time.Minute)
|
||||
defer store.Stop()
|
||||
|
||||
// Get non-existent session
|
||||
if store.Get("nonexistent") != nil {
|
||||
t.Error("Should return nil for non-existent session")
|
||||
}
|
||||
|
||||
// Create and retrieve
|
||||
session, _ := store.Create()
|
||||
retrieved := store.Get(session.ID)
|
||||
if retrieved == nil {
|
||||
t.Error("Should find created session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreDelete(t *testing.T) {
|
||||
store := NewSessionStore(30 * time.Minute)
|
||||
defer store.Stop()
|
||||
|
||||
session, _ := store.Create()
|
||||
if store.Count() != 1 {
|
||||
t.Error("Should have 1 session after create")
|
||||
}
|
||||
|
||||
if !store.Delete(session.ID) {
|
||||
t.Error("Delete should return true for existing session")
|
||||
}
|
||||
|
||||
if store.Count() != 0 {
|
||||
t.Error("Should have 0 sessions after delete")
|
||||
}
|
||||
|
||||
if store.Delete(session.ID) {
|
||||
t.Error("Delete should return false for non-existent session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreTTLExpiration(t *testing.T) {
|
||||
ttl := 50 * time.Millisecond
|
||||
store := NewSessionStore(ttl)
|
||||
defer store.Stop()
|
||||
|
||||
session, _ := store.Create()
|
||||
|
||||
// Should be retrievable immediately
|
||||
if store.Get(session.ID) == nil {
|
||||
t.Error("Session should be retrievable immediately")
|
||||
}
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(ttl + 10*time.Millisecond)
|
||||
|
||||
// Should not be retrievable after TTL
|
||||
if store.Get(session.ID) != nil {
|
||||
t.Error("Expired session should not be retrievable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreTTLRefresh(t *testing.T) {
|
||||
ttl := 100 * time.Millisecond
|
||||
store := NewSessionStore(ttl)
|
||||
defer store.Stop()
|
||||
|
||||
session, _ := store.Create()
|
||||
|
||||
// Touch the session before TTL expires
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
session.Touch()
|
||||
|
||||
// Wait past original TTL but not past refreshed TTL
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// Should still be retrievable because we touched it
|
||||
if store.Get(session.ID) == nil {
|
||||
t.Error("Touched session should still be retrievable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreCleanup(t *testing.T) {
|
||||
ttl := 50 * time.Millisecond
|
||||
store := NewSessionStore(ttl)
|
||||
defer store.Stop()
|
||||
|
||||
// Create multiple sessions
|
||||
for i := 0; i < 5; i++ {
|
||||
store.Create()
|
||||
}
|
||||
|
||||
if store.Count() != 5 {
|
||||
t.Errorf("Should have 5 sessions, got %d", store.Count())
|
||||
}
|
||||
|
||||
// Wait for cleanup to run (runs at ttl/2 intervals)
|
||||
time.Sleep(ttl + ttl/2 + 10*time.Millisecond)
|
||||
|
||||
// All sessions should be cleaned up
|
||||
if store.Count() != 0 {
|
||||
t.Errorf("All sessions should be cleaned up, got %d", store.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStoreConcurrency(t *testing.T) {
|
||||
store := NewSessionStore(30 * time.Minute)
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
sessionIDs := make(chan string, 100)
|
||||
|
||||
// Create sessions concurrently
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
session, err := store.Create()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to create session: %v", err)
|
||||
return
|
||||
}
|
||||
sessionIDs <- session.ID
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(sessionIDs)
|
||||
|
||||
// Verify all sessions were created
|
||||
if store.Count() != 50 {
|
||||
t.Errorf("Should have 50 sessions, got %d", store.Count())
|
||||
}
|
||||
|
||||
// Read and delete concurrently
|
||||
var ids []string
|
||||
for id := range sessionIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
wg.Add(2)
|
||||
go func(id string) {
|
||||
defer wg.Done()
|
||||
store.Get(id)
|
||||
}(id)
|
||||
go func(id string) {
|
||||
defer wg.Done()
|
||||
store.Delete(id)
|
||||
}(id)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestGenerateSessionID(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
|
||||
// Generate 1000 IDs and ensure uniqueness
|
||||
for i := 0; i < 1000; i++ {
|
||||
id, err := generateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate session ID: %v", err)
|
||||
}
|
||||
|
||||
if len(id) != 32 {
|
||||
t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
|
||||
}
|
||||
|
||||
if ids[id] {
|
||||
t.Error("Generated duplicate session ID")
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
10
internal/mcp/transport.go
Normal file
10
internal/mcp/transport.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package mcp
|
||||
|
||||
import "context"
|
||||
|
||||
// Transport defines the interface for MCP server transports.
|
||||
type Transport interface {
|
||||
// Run starts the transport and blocks until the context is cancelled
|
||||
// or an error occurs.
|
||||
Run(ctx context.Context) error
|
||||
}
|
||||
354
internal/mcp/transport_http.go
Normal file
354
internal/mcp/transport_http.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPConfig configures the HTTP transport.
|
||||
type HTTPConfig struct {
|
||||
Address string // Listen address (e.g., "127.0.0.1:8080")
|
||||
Endpoint string // MCP endpoint path (e.g., "/mcp")
|
||||
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
|
||||
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
||||
TLSCertFile string // TLS certificate file (optional)
|
||||
TLSKeyFile string // TLS key file (optional)
|
||||
}
|
||||
|
||||
// HTTPTransport implements the MCP Streamable HTTP transport.
|
||||
type HTTPTransport struct {
|
||||
server *Server
|
||||
config HTTPConfig
|
||||
sessions *SessionStore
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTP transport.
|
||||
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
||||
if config.Address == "" {
|
||||
config.Address = "127.0.0.1:8080"
|
||||
}
|
||||
if config.Endpoint == "" {
|
||||
config.Endpoint = "/mcp"
|
||||
}
|
||||
if config.SessionTTL == 0 {
|
||||
config.SessionTTL = 30 * time.Minute
|
||||
}
|
||||
|
||||
return &HTTPTransport{
|
||||
server: server,
|
||||
config: config,
|
||||
sessions: NewSessionStore(config.SessionTTL),
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the HTTP server and blocks until the context is cancelled.
|
||||
func (t *HTTPTransport) Run(ctx context.Context) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(t.config.Endpoint, t.handleMCP)
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: t.config.Address,
|
||||
Handler: mux,
|
||||
BaseContext: func(l net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
|
||||
// Graceful shutdown on context cancellation
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
t.server.logger.Println("Shutting down HTTP server...")
|
||||
t.sessions.Stop()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
t.server.logger.Printf("HTTP server shutdown error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
|
||||
|
||||
var err error
|
||||
if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
|
||||
err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
|
||||
} else {
|
||||
err = httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
if err == http.ErrServerClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleMCP routes requests based on HTTP method.
|
||||
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate Origin header
|
||||
if !t.isOriginAllowed(r) {
|
||||
http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
t.handlePost(w, r)
|
||||
case http.MethodGet:
|
||||
t.handleGet(w, r)
|
||||
case http.MethodDelete:
|
||||
t.handleDelete(w, r)
|
||||
case http.MethodOptions:
|
||||
t.handleOptions(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handlePost handles JSON-RPC requests.
|
||||
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
|
||||
// Read request body
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request to check method
|
||||
var req Request
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(Response{
|
||||
JSONRPC: "2.0",
|
||||
Error: &Error{
|
||||
Code: ParseError,
|
||||
Message: "Parse error",
|
||||
Data: err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Handle initialize request - create session
|
||||
if req.Method == MethodInitialize {
|
||||
t.handleInitialize(w, r, &req)
|
||||
return
|
||||
}
|
||||
|
||||
// All other requests require a valid session
|
||||
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session := t.sessions.Get(sessionID)
|
||||
if session == nil {
|
||||
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Update session activity
|
||||
session.Touch()
|
||||
|
||||
// Handle notifications (no response expected)
|
||||
if req.Method == MethodInitialized {
|
||||
session.SetInitialized()
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this is a notification (no ID)
|
||||
if req.ID == nil {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the request
|
||||
resp := t.server.HandleRequest(r.Context(), &req)
|
||||
if resp == nil {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleInitialize handles the initialize request and creates a new session.
|
||||
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
|
||||
// Create a new session
|
||||
session, err := t.sessions.Create()
|
||||
if err != nil {
|
||||
t.server.logger.Printf("Failed to create session: %v", err)
|
||||
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Process initialize request
|
||||
resp := t.server.HandleRequest(r.Context(), req)
|
||||
if resp == nil {
|
||||
t.sessions.Delete(session.ID)
|
||||
http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// If initialize failed, clean up session
|
||||
if resp.Error != nil {
|
||||
t.sessions.Delete(session.ID)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Mcp-Session-Id", session.ID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
// handleGet handles SSE stream for server-initiated notifications.
|
||||
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session := t.sessions.Get(sessionID)
|
||||
if session == nil {
|
||||
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if client accepts SSE
|
||||
accept := r.Header.Get("Accept")
|
||||
if !strings.Contains(accept, "text/event-stream") {
|
||||
http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
|
||||
return
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
// Flush headers
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Stream notifications
|
||||
ctx := r.Context()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case notification, ok := <-session.Notifications():
|
||||
if !ok {
|
||||
// Session closed
|
||||
return
|
||||
}
|
||||
|
||||
data, err := json.Marshal(notification)
|
||||
if err != nil {
|
||||
t.server.logger.Printf("Failed to marshal notification: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Write SSE event
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
|
||||
// Touch session to keep it alive
|
||||
session.Touch()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleDelete terminates a session.
|
||||
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if t.sessions.Delete(sessionID) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
} else {
|
||||
http.Error(w, "Session not found", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOptions handles CORS preflight requests.
|
||||
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin != "" && t.isOriginAllowed(r) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
|
||||
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
|
||||
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if the request origin is allowed.
|
||||
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// No Origin header (same-origin request) is always allowed
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If no allowed origins configured, only allow localhost
|
||||
if len(t.config.AllowedOrigins) == 0 {
|
||||
return isLocalhostOrigin(origin)
|
||||
}
|
||||
|
||||
// Check against allowed origins
|
||||
for _, allowed := range t.config.AllowedOrigins {
|
||||
if allowed == "*" || allowed == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLocalhostOrigin checks if the origin is a localhost address.
|
||||
func isLocalhostOrigin(origin string) bool {
|
||||
origin = strings.ToLower(origin)
|
||||
|
||||
// Check for localhost patterns (must be followed by :, /, or end of string)
|
||||
localhostPatterns := []string{
|
||||
"http://localhost",
|
||||
"https://localhost",
|
||||
"http://127.0.0.1",
|
||||
"https://127.0.0.1",
|
||||
"http://[::1]",
|
||||
"https://[::1]",
|
||||
}
|
||||
|
||||
for _, pattern := range localhostPatterns {
|
||||
if origin == pattern {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(origin, pattern+":") || strings.HasPrefix(origin, pattern+"/") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
513
internal/mcp/transport_http_test.go
Normal file
513
internal/mcp/transport_http_test.go
Normal file
@@ -0,0 +1,513 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testHTTPTransport creates a transport with a test server
|
||||
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
|
||||
// Use a mock store
|
||||
server := NewServer(nil, log.New(io.Discard, "", 0))
|
||||
|
||||
if config.SessionTTL == 0 {
|
||||
config.SessionTTL = 30 * time.Minute
|
||||
}
|
||||
transport := NewHTTPTransport(server, config)
|
||||
|
||||
// Create test server
|
||||
mux := http.NewServeMux()
|
||||
endpoint := config.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/mcp"
|
||||
}
|
||||
mux.HandleFunc(endpoint, transport.handleMCP)
|
||||
|
||||
ts := httptest.NewServer(mux)
|
||||
t.Cleanup(func() {
|
||||
ts.Close()
|
||||
transport.sessions.Stop()
|
||||
})
|
||||
|
||||
return transport, ts
|
||||
}
|
||||
|
||||
func TestHTTPTransportInitialize(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send initialize request
|
||||
initReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodInitialize,
|
||||
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||
}
|
||||
body, _ := json.Marshal(initReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check session ID header
|
||||
sessionID := resp.Header.Get("Mcp-Session-Id")
|
||||
if sessionID == "" {
|
||||
t.Error("Expected Mcp-Session-Id header")
|
||||
}
|
||||
if len(sessionID) != 32 {
|
||||
t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var initResp Response
|
||||
if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if initResp.Error != nil {
|
||||
t.Errorf("Initialize failed: %v", initResp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportSessionRequired(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send tools/list without session
|
||||
listReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodToolsList,
|
||||
}
|
||||
body, _ := json.Marshal(listReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportInvalidSession(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send request with invalid session
|
||||
listReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodToolsList,
|
||||
}
|
||||
body, _ := json.Marshal(listReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Mcp-Session-Id", "invalid-session-id")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportValidSession(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Create session manually
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Send tools/list with valid session
|
||||
listReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodToolsList,
|
||||
}
|
||||
body, _ := json.Marshal(listReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportNotificationAccepted(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Send notification (no ID)
|
||||
notification := Request{
|
||||
JSONRPC: "2.0",
|
||||
Method: MethodInitialized,
|
||||
}
|
||||
body, _ := json.Marshal(notification)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusAccepted {
|
||||
t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Verify session is marked as initialized
|
||||
if !session.IsInitialized() {
|
||||
t.Error("Session should be marked as initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportDeleteSession(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Delete session
|
||||
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Verify session is gone
|
||||
if transport.sessions.Get(session.ID) != nil {
|
||||
t.Error("Session should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", "nonexistent")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportOriginValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedOrigins []string
|
||||
origin string
|
||||
expectAllowed bool
|
||||
}{
|
||||
{
|
||||
name: "no origin header",
|
||||
allowedOrigins: nil,
|
||||
origin: "",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "localhost allowed by default",
|
||||
allowedOrigins: nil,
|
||||
origin: "http://localhost:3000",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 allowed by default",
|
||||
allowedOrigins: nil,
|
||||
origin: "http://127.0.0.1:8080",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "external origin blocked by default",
|
||||
allowedOrigins: nil,
|
||||
origin: "http://evil.com",
|
||||
expectAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "explicit allow",
|
||||
allowedOrigins: []string{"http://example.com"},
|
||||
origin: "http://example.com",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "explicit allow wildcard",
|
||||
allowedOrigins: []string{"*"},
|
||||
origin: "http://anything.com",
|
||||
expectAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "not in allowed list",
|
||||
allowedOrigins: []string{"http://example.com"},
|
||||
origin: "http://other.com",
|
||||
expectAllowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||
AllowedOrigins: tt.allowedOrigins,
|
||||
})
|
||||
|
||||
// Use initialize since it doesn't require a session
|
||||
initReq := Request{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: MethodInitialize,
|
||||
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||
}
|
||||
body, _ := json.Marshal(initReq)
|
||||
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if tt.origin != "" {
|
||||
req.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
|
||||
t.Error("Expected request to be allowed but was forbidden")
|
||||
}
|
||||
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// GET without Accept: text/event-stream
|
||||
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotAcceptable {
|
||||
t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportSSEStream(t *testing.T) {
|
||||
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
session, _ := transport.sessions.Create()
|
||||
|
||||
// Start SSE stream in goroutine
|
||||
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("Expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != "text/event-stream" {
|
||||
t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
notification := &Response{
|
||||
JSONRPC: "2.0",
|
||||
ID: 42,
|
||||
Result: map[string]string{"test": "data"},
|
||||
}
|
||||
session.SendNotification(notification)
|
||||
|
||||
// Read the SSE event
|
||||
buf := make([]byte, 1024)
|
||||
n, err := resp.Body.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Failed to read SSE event: %v", err)
|
||||
}
|
||||
|
||||
data := string(buf[:n])
|
||||
if !strings.HasPrefix(data, "data: ") {
|
||||
t.Errorf("Expected SSE data event, got: %s", data)
|
||||
}
|
||||
|
||||
// Parse the JSON from the SSE event
|
||||
jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ")
|
||||
var received Response
|
||||
if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
|
||||
t.Fatalf("Failed to parse notification JSON: %v", err)
|
||||
}
|
||||
|
||||
// JSON unmarshal converts numbers to float64, so compare as float64
|
||||
receivedID, ok := received.ID.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("Expected numeric ID, got %T", received.ID)
|
||||
}
|
||||
if int(receivedID) != 42 {
|
||||
t.Errorf("Expected notification ID 42, got %v", receivedID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportParseError(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
// Send invalid JSON
|
||||
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var jsonResp Response
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if jsonResp.Error == nil {
|
||||
t.Error("Expected JSON-RPC error for parse error")
|
||||
}
|
||||
if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
|
||||
t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||
|
||||
req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected 405, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportOptionsRequest(t *testing.T) {
|
||||
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||
AllowedOrigins: []string{"http://example.com"},
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("Expected 204, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||
t.Error("Expected CORS origin header")
|
||||
}
|
||||
if resp.Header.Get("Access-Control-Allow-Methods") == "" {
|
||||
t.Error("Expected CORS methods header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalhostOrigin(t *testing.T) {
|
||||
tests := []struct {
|
||||
origin string
|
||||
expected bool
|
||||
}{
|
||||
{"http://localhost", true},
|
||||
{"http://localhost:3000", true},
|
||||
{"https://localhost", true},
|
||||
{"https://localhost:8443", true},
|
||||
{"http://127.0.0.1", true},
|
||||
{"http://127.0.0.1:8080", true},
|
||||
{"https://127.0.0.1", true},
|
||||
{"http://[::1]", true},
|
||||
{"http://[::1]:8080", true},
|
||||
{"https://[::1]", true},
|
||||
{"http://example.com", false},
|
||||
{"https://example.com", false},
|
||||
{"http://localhost.evil.com", false},
|
||||
{"http://192.168.1.1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.origin, func(t *testing.T) {
|
||||
result := isLocalhostOrigin(tt.origin)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
63
internal/mcp/transport_stdio.go
Normal file
63
internal/mcp/transport_stdio.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// StdioTransport implements the MCP protocol over STDIO using line-delimited JSON-RPC.
|
||||
type StdioTransport struct {
|
||||
server *Server
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
// NewStdioTransport creates a new STDIO transport.
|
||||
func NewStdioTransport(server *Server, r io.Reader, w io.Writer) *StdioTransport {
|
||||
return &StdioTransport{
|
||||
server: server,
|
||||
reader: r,
|
||||
writer: w,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the STDIO transport, reading line-delimited JSON-RPC from the reader
|
||||
// and writing responses to the writer.
|
||||
func (t *StdioTransport) Run(ctx context.Context) error {
|
||||
scanner := bufio.NewScanner(t.reader)
|
||||
encoder := json.NewEncoder(t.writer)
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := t.server.HandleMessage(ctx, line)
|
||||
if err != nil {
|
||||
t.server.logger.Printf("Error handling message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
return fmt.Errorf("failed to write response: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("scanner error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user