From cbe55d645651ecbe1035fc64fa6d96659756da52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Tue, 3 Feb 2026 22:02:40 +0100 Subject: [PATCH] feat: add Streamable HTTP transport support Add support for running the MCP server over HTTP with Server-Sent Events (SSE) using the MCP Streamable HTTP specification, alongside the existing STDIO transport. New features: - Transport abstraction with Transport interface - HTTP transport with session management - SSE support for server-initiated notifications - CORS security with configurable allowed origins - Optional TLS support - CLI flags for HTTP configuration (--transport, --http-address, etc.) - NixOS module options for HTTP transport The HTTP transport implements: - POST /mcp: JSON-RPC requests with session management - GET /mcp: SSE stream for server notifications - DELETE /mcp: Session termination - Origin validation (localhost-only by default) Co-Authored-By: Claude Opus 4.5 --- cmd/nixos-options/main.go | 66 +++- internal/mcp/server.go | 68 ++-- internal/mcp/session.go | 207 +++++++++++ internal/mcp/session_test.go | 267 +++++++++++++++ internal/mcp/transport.go | 10 + internal/mcp/transport_http.go | 354 +++++++++++++++++++ internal/mcp/transport_http_test.go | 513 ++++++++++++++++++++++++++++ internal/mcp/transport_stdio.go | 63 ++++ nix/module.nix | 81 ++++- 9 files changed, 1575 insertions(+), 54 deletions(-) create mode 100644 internal/mcp/session.go create mode 100644 internal/mcp/session_test.go create mode 100644 internal/mcp/transport.go create mode 100644 internal/mcp/transport_http.go create mode 100644 internal/mcp/transport_http_test.go create mode 100644 internal/mcp/transport_stdio.go diff --git a/cmd/nixos-options/main.go b/cmd/nixos-options/main.go index 9aec61f..b9e7451 100644 --- a/cmd/nixos-options/main.go +++ b/cmd/nixos-options/main.go @@ -5,7 +5,10 @@ import ( "fmt" "log" "os" + "os/signal" "strings" + "syscall" + "time" "github.com/urfave/cli/v2" @@ -36,7 +39,42 @@ func main() { Commands: []*cli.Command{ { Name: "serve", - Usage: "Run MCP server (stdio)", + Usage: "Run MCP server", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "transport", + Aliases: []string{"t"}, + Usage: "Transport type: 'stdio' or 'http'", + Value: "stdio", + }, + &cli.StringFlag{ + Name: "http-address", + Usage: "HTTP listen address", + Value: "127.0.0.1:8080", + }, + &cli.StringFlag{ + Name: "http-endpoint", + Usage: "HTTP endpoint path", + Value: "/mcp", + }, + &cli.StringSliceFlag{ + Name: "allowed-origins", + Usage: "Allowed Origin headers for CORS (can be specified multiple times)", + }, + &cli.StringFlag{ + Name: "tls-cert", + Usage: "TLS certificate file", + }, + &cli.StringFlag{ + Name: "tls-key", + Usage: "TLS key file", + }, + &cli.DurationFlag{ + Name: "session-ttl", + Usage: "Session TTL for HTTP transport", + Value: 30 * time.Minute, + }, + }, Action: func(c *cli.Context) error { return runServe(c) }, @@ -145,7 +183,8 @@ func openStore(connStr string) (database.Store, error) { } func runServe(c *cli.Context) error { - ctx := context.Background() + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() store, err := openStore(c.String("database")) if err != nil { @@ -163,8 +202,27 @@ func runServe(c *cli.Context) error { indexer := nixos.NewIndexer(store) server.RegisterHandlers(indexer) - logger.Println("Starting MCP server on stdio...") - return server.Run(ctx, os.Stdin, os.Stdout) + transport := c.String("transport") + switch transport { + case "stdio": + logger.Println("Starting MCP server on stdio...") + return server.Run(ctx, os.Stdin, os.Stdout) + + case "http": + config := mcp.HTTPConfig{ + Address: c.String("http-address"), + Endpoint: c.String("http-endpoint"), + AllowedOrigins: c.StringSlice("allowed-origins"), + SessionTTL: c.Duration("session-ttl"), + TLSCertFile: c.String("tls-cert"), + TLSKeyFile: c.String("tls-key"), + } + httpTransport := mcp.NewHTTPTransport(server, config) + return httpTransport.Run(ctx) + + default: + return fmt.Errorf("unknown transport: %s (use 'stdio' or 'http')", transport) + } } func runIndex(c *cli.Context, revision string, indexFiles bool, force bool) error { diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 22874d6..743f4c1 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -1,7 +1,6 @@ package mcp import ( - "bufio" "context" "encoding/json" "fmt" @@ -11,7 +10,7 @@ import ( "git.t-juice.club/torjus/labmcp/internal/database" ) -// Server is an MCP server that handles JSON-RPC requests over stdio. +// Server is an MCP server that handles JSON-RPC requests. type Server struct { store database.Store tools map[string]ToolHandler @@ -41,53 +40,34 @@ func (s *Server) registerTools() { // Tools will be implemented in handlers.go } -// Run starts the server, reading from r and writing to w. +// Run starts the server using STDIO transport (backward compatibility). func (s *Server) Run(ctx context.Context, r io.Reader, w io.Writer) error { - scanner := bufio.NewScanner(r) - encoder := json.NewEncoder(w) + transport := NewStdioTransport(s, r, w) + return transport.Run(ctx) +} - for scanner.Scan() { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - line := scanner.Bytes() - if len(line) == 0 { - continue - } - - var req Request - if err := json.Unmarshal(line, &req); err != nil { - s.logger.Printf("Failed to parse request: %v", err) - resp := Response{ - JSONRPC: "2.0", - Error: &Error{ - Code: ParseError, - Message: "Parse error", - Data: err.Error(), - }, - } - if err := encoder.Encode(resp); err != nil { - return fmt.Errorf("failed to write response: %w", err) - } - continue - } - - resp := s.handleRequest(ctx, &req) - if resp != nil { - if err := encoder.Encode(resp); err != nil { - return fmt.Errorf("failed to write response: %w", err) - } - } +// HandleMessage parses a JSON-RPC message and returns the response. +// Returns (nil, nil) for notifications that don't require a response. +func (s *Server) HandleMessage(ctx context.Context, data []byte) (*Response, error) { + var req Request + if err := json.Unmarshal(data, &req); err != nil { + return &Response{ + JSONRPC: "2.0", + Error: &Error{ + Code: ParseError, + Message: "Parse error", + Data: err.Error(), + }, + }, nil } - if err := scanner.Err(); err != nil { - return fmt.Errorf("scanner error: %w", err) - } + return s.HandleRequest(ctx, &req), nil +} - return nil +// HandleRequest processes a single request and returns a response. +// Returns nil for notifications that don't require a response. +func (s *Server) HandleRequest(ctx context.Context, req *Request) *Response { + return s.handleRequest(ctx, req) } // handleRequest processes a single request and returns a response. diff --git a/internal/mcp/session.go b/internal/mcp/session.go new file mode 100644 index 0000000..0f6e238 --- /dev/null +++ b/internal/mcp/session.go @@ -0,0 +1,207 @@ +package mcp + +import ( + "crypto/rand" + "encoding/hex" + "sync" + "time" +) + +// Session represents an MCP client session. +type Session struct { + ID string + CreatedAt time.Time + LastActivity time.Time + Initialized bool + + // notifications is a channel for server-initiated notifications. + notifications chan *Response + + mu sync.RWMutex +} + +// NewSession creates a new session with a cryptographically secure random ID. +func NewSession() (*Session, error) { + id, err := generateSessionID() + if err != nil { + return nil, err + } + + now := time.Now() + return &Session{ + ID: id, + CreatedAt: now, + LastActivity: now, + notifications: make(chan *Response, 100), + }, nil +} + +// Touch updates the session's last activity time. +func (s *Session) Touch() { + s.mu.Lock() + defer s.mu.Unlock() + s.LastActivity = time.Now() +} + +// SetInitialized marks the session as initialized. +func (s *Session) SetInitialized() { + s.mu.Lock() + defer s.mu.Unlock() + s.Initialized = true +} + +// IsInitialized returns whether the session has been initialized. +func (s *Session) IsInitialized() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.Initialized +} + +// Notifications returns the channel for server-initiated notifications. +func (s *Session) Notifications() <-chan *Response { + return s.notifications +} + +// SendNotification sends a notification to the session. +// Returns false if the channel is full. +func (s *Session) SendNotification(notification *Response) bool { + select { + case s.notifications <- notification: + return true + default: + return false + } +} + +// Close closes the session's notification channel. +func (s *Session) Close() { + close(s.notifications) +} + +// SessionStore manages active sessions with TTL-based cleanup. +type SessionStore struct { + sessions map[string]*Session + ttl time.Duration + mu sync.RWMutex + stopClean chan struct{} + cleanDone chan struct{} +} + +// NewSessionStore creates a new session store with the given TTL. +func NewSessionStore(ttl time.Duration) *SessionStore { + s := &SessionStore{ + sessions: make(map[string]*Session), + ttl: ttl, + stopClean: make(chan struct{}), + cleanDone: make(chan struct{}), + } + go s.cleanupLoop() + return s +} + +// Create creates a new session and adds it to the store. +func (s *SessionStore) Create() (*Session, error) { + session, err := NewSession() + if err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + s.sessions[session.ID] = session + return session, nil +} + +// Get retrieves a session by ID. Returns nil if not found or expired. +func (s *SessionStore) Get(id string) *Session { + s.mu.RLock() + defer s.mu.RUnlock() + + session, ok := s.sessions[id] + if !ok { + return nil + } + + // Check if expired + session.mu.RLock() + expired := time.Since(session.LastActivity) > s.ttl + session.mu.RUnlock() + + if expired { + return nil + } + + return session +} + +// Delete removes a session from the store. +func (s *SessionStore) Delete(id string) bool { + s.mu.Lock() + defer s.mu.Unlock() + + session, ok := s.sessions[id] + if !ok { + return false + } + + session.Close() + delete(s.sessions, id) + return true +} + +// Count returns the number of active sessions. +func (s *SessionStore) Count() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.sessions) +} + +// Stop stops the cleanup goroutine and waits for it to finish. +func (s *SessionStore) Stop() { + close(s.stopClean) + <-s.cleanDone +} + +// cleanupLoop periodically removes expired sessions. +func (s *SessionStore) cleanupLoop() { + defer close(s.cleanDone) + + ticker := time.NewTicker(s.ttl / 2) + defer ticker.Stop() + + for { + select { + case <-s.stopClean: + return + case <-ticker.C: + s.cleanup() + } + } +} + +// cleanup removes expired sessions. +func (s *SessionStore) cleanup() { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for id, session := range s.sessions { + session.mu.RLock() + expired := now.Sub(session.LastActivity) > s.ttl + session.mu.RUnlock() + + if expired { + session.Close() + delete(s.sessions, id) + } + } +} + +// generateSessionID generates a cryptographically secure random session ID. +func generateSessionID() (string, error) { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} diff --git a/internal/mcp/session_test.go b/internal/mcp/session_test.go new file mode 100644 index 0000000..3c937fa --- /dev/null +++ b/internal/mcp/session_test.go @@ -0,0 +1,267 @@ +package mcp + +import ( + "sync" + "testing" + "time" +) + +func TestNewSession(t *testing.T) { + session, err := NewSession() + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if session.ID == "" { + t.Error("Session ID should not be empty") + } + if len(session.ID) != 32 { + t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID)) + } + if session.Initialized { + t.Error("New session should not be initialized") + } +} + +func TestSessionTouch(t *testing.T) { + session, _ := NewSession() + originalActivity := session.LastActivity + + time.Sleep(10 * time.Millisecond) + session.Touch() + + if !session.LastActivity.After(originalActivity) { + t.Error("Touch should update LastActivity") + } +} + +func TestSessionInitialized(t *testing.T) { + session, _ := NewSession() + + if session.IsInitialized() { + t.Error("New session should not be initialized") + } + + session.SetInitialized() + + if !session.IsInitialized() { + t.Error("Session should be initialized after SetInitialized") + } +} + +func TestSessionNotifications(t *testing.T) { + session, _ := NewSession() + defer session.Close() + + notification := &Response{JSONRPC: "2.0", ID: 1} + + if !session.SendNotification(notification) { + t.Error("SendNotification should return true on success") + } + + select { + case received := <-session.Notifications(): + if received.ID != notification.ID { + t.Error("Received notification should match sent") + } + case <-time.After(100 * time.Millisecond): + t.Error("Should receive notification") + } +} + +func TestSessionStoreCreate(t *testing.T) { + store := NewSessionStore(30 * time.Minute) + defer store.Stop() + + session, err := store.Create() + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if store.Count() != 1 { + t.Errorf("Store should have 1 session, got %d", store.Count()) + } + + // Verify we can retrieve it + retrieved := store.Get(session.ID) + if retrieved == nil { + t.Error("Should be able to retrieve created session") + } + if retrieved.ID != session.ID { + t.Error("Retrieved session ID should match") + } +} + +func TestSessionStoreGet(t *testing.T) { + store := NewSessionStore(30 * time.Minute) + defer store.Stop() + + // Get non-existent session + if store.Get("nonexistent") != nil { + t.Error("Should return nil for non-existent session") + } + + // Create and retrieve + session, _ := store.Create() + retrieved := store.Get(session.ID) + if retrieved == nil { + t.Error("Should find created session") + } +} + +func TestSessionStoreDelete(t *testing.T) { + store := NewSessionStore(30 * time.Minute) + defer store.Stop() + + session, _ := store.Create() + if store.Count() != 1 { + t.Error("Should have 1 session after create") + } + + if !store.Delete(session.ID) { + t.Error("Delete should return true for existing session") + } + + if store.Count() != 0 { + t.Error("Should have 0 sessions after delete") + } + + if store.Delete(session.ID) { + t.Error("Delete should return false for non-existent session") + } +} + +func TestSessionStoreTTLExpiration(t *testing.T) { + ttl := 50 * time.Millisecond + store := NewSessionStore(ttl) + defer store.Stop() + + session, _ := store.Create() + + // Should be retrievable immediately + if store.Get(session.ID) == nil { + t.Error("Session should be retrievable immediately") + } + + // Wait for expiration + time.Sleep(ttl + 10*time.Millisecond) + + // Should not be retrievable after TTL + if store.Get(session.ID) != nil { + t.Error("Expired session should not be retrievable") + } +} + +func TestSessionStoreTTLRefresh(t *testing.T) { + ttl := 100 * time.Millisecond + store := NewSessionStore(ttl) + defer store.Stop() + + session, _ := store.Create() + + // Touch the session before TTL expires + time.Sleep(60 * time.Millisecond) + session.Touch() + + // Wait past original TTL but not past refreshed TTL + time.Sleep(60 * time.Millisecond) + + // Should still be retrievable because we touched it + if store.Get(session.ID) == nil { + t.Error("Touched session should still be retrievable") + } +} + +func TestSessionStoreCleanup(t *testing.T) { + ttl := 50 * time.Millisecond + store := NewSessionStore(ttl) + defer store.Stop() + + // Create multiple sessions + for i := 0; i < 5; i++ { + store.Create() + } + + if store.Count() != 5 { + t.Errorf("Should have 5 sessions, got %d", store.Count()) + } + + // Wait for cleanup to run (runs at ttl/2 intervals) + time.Sleep(ttl + ttl/2 + 10*time.Millisecond) + + // All sessions should be cleaned up + if store.Count() != 0 { + t.Errorf("All sessions should be cleaned up, got %d", store.Count()) + } +} + +func TestSessionStoreConcurrency(t *testing.T) { + store := NewSessionStore(30 * time.Minute) + defer store.Stop() + + var wg sync.WaitGroup + sessionIDs := make(chan string, 100) + + // Create sessions concurrently + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + session, err := store.Create() + if err != nil { + t.Errorf("Failed to create session: %v", err) + return + } + sessionIDs <- session.ID + }() + } + + wg.Wait() + close(sessionIDs) + + // Verify all sessions were created + if store.Count() != 50 { + t.Errorf("Should have 50 sessions, got %d", store.Count()) + } + + // Read and delete concurrently + var ids []string + for id := range sessionIDs { + ids = append(ids, id) + } + + for _, id := range ids { + wg.Add(2) + go func(id string) { + defer wg.Done() + store.Get(id) + }(id) + go func(id string) { + defer wg.Done() + store.Delete(id) + }(id) + } + + wg.Wait() +} + +func TestGenerateSessionID(t *testing.T) { + ids := make(map[string]bool) + + // Generate 1000 IDs and ensure uniqueness + for i := 0; i < 1000; i++ { + id, err := generateSessionID() + if err != nil { + t.Fatalf("Failed to generate session ID: %v", err) + } + + if len(id) != 32 { + t.Errorf("Session ID should be 32 hex chars, got %d", len(id)) + } + + if ids[id] { + t.Error("Generated duplicate session ID") + } + ids[id] = true + } +} diff --git a/internal/mcp/transport.go b/internal/mcp/transport.go new file mode 100644 index 0000000..db35dbd --- /dev/null +++ b/internal/mcp/transport.go @@ -0,0 +1,10 @@ +package mcp + +import "context" + +// Transport defines the interface for MCP server transports. +type Transport interface { + // Run starts the transport and blocks until the context is cancelled + // or an error occurs. + Run(ctx context.Context) error +} diff --git a/internal/mcp/transport_http.go b/internal/mcp/transport_http.go new file mode 100644 index 0000000..5d21ac3 --- /dev/null +++ b/internal/mcp/transport_http.go @@ -0,0 +1,354 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" +) + +// HTTPConfig configures the HTTP transport. +type HTTPConfig struct { + Address string // Listen address (e.g., "127.0.0.1:8080") + Endpoint string // MCP endpoint path (e.g., "/mcp") + AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only) + SessionTTL time.Duration // Session TTL (default: 30 minutes) + TLSCertFile string // TLS certificate file (optional) + TLSKeyFile string // TLS key file (optional) +} + +// HTTPTransport implements the MCP Streamable HTTP transport. +type HTTPTransport struct { + server *Server + config HTTPConfig + sessions *SessionStore +} + +// NewHTTPTransport creates a new HTTP transport. +func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport { + if config.Address == "" { + config.Address = "127.0.0.1:8080" + } + if config.Endpoint == "" { + config.Endpoint = "/mcp" + } + if config.SessionTTL == 0 { + config.SessionTTL = 30 * time.Minute + } + + return &HTTPTransport{ + server: server, + config: config, + sessions: NewSessionStore(config.SessionTTL), + } +} + +// Run starts the HTTP server and blocks until the context is cancelled. +func (t *HTTPTransport) Run(ctx context.Context) error { + mux := http.NewServeMux() + mux.HandleFunc(t.config.Endpoint, t.handleMCP) + + httpServer := &http.Server{ + Addr: t.config.Address, + Handler: mux, + BaseContext: func(l net.Listener) context.Context { + return ctx + }, + } + + // Graceful shutdown on context cancellation + go func() { + <-ctx.Done() + t.server.logger.Println("Shutting down HTTP server...") + t.sessions.Stop() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := httpServer.Shutdown(shutdownCtx); err != nil { + t.server.logger.Printf("HTTP server shutdown error: %v", err) + } + }() + + t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint) + + var err error + if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" { + err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile) + } else { + err = httpServer.ListenAndServe() + } + + if err == http.ErrServerClosed { + return nil + } + return err +} + +// handleMCP routes requests based on HTTP method. +func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) { + // Validate Origin header + if !t.isOriginAllowed(r) { + http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden) + return + } + + switch r.Method { + case http.MethodPost: + t.handlePost(w, r) + case http.MethodGet: + t.handleGet(w, r) + case http.MethodDelete: + t.handleDelete(w, r) + case http.MethodOptions: + t.handleOptions(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handlePost handles JSON-RPC requests. +func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) { + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + + // Parse request to check method + var req Request + if err := json.Unmarshal(body, &req); err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(Response{ + JSONRPC: "2.0", + Error: &Error{ + Code: ParseError, + Message: "Parse error", + Data: err.Error(), + }, + }) + return + } + + // Handle initialize request - create session + if req.Method == MethodInitialize { + t.handleInitialize(w, r, &req) + return + } + + // All other requests require a valid session + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + http.Error(w, "Session ID required", http.StatusBadRequest) + return + } + + session := t.sessions.Get(sessionID) + if session == nil { + http.Error(w, "Invalid or expired session", http.StatusNotFound) + return + } + + // Update session activity + session.Touch() + + // Handle notifications (no response expected) + if req.Method == MethodInitialized { + session.SetInitialized() + w.WriteHeader(http.StatusAccepted) + return + } + + // Check if this is a notification (no ID) + if req.ID == nil { + w.WriteHeader(http.StatusAccepted) + return + } + + // Process the request + resp := t.server.HandleRequest(r.Context(), &req) + if resp == nil { + w.WriteHeader(http.StatusAccepted) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) +} + +// handleInitialize handles the initialize request and creates a new session. +func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) { + // Create a new session + session, err := t.sessions.Create() + if err != nil { + t.server.logger.Printf("Failed to create session: %v", err) + http.Error(w, "Failed to create session", http.StatusInternalServerError) + return + } + + // Process initialize request + resp := t.server.HandleRequest(r.Context(), req) + if resp == nil { + t.sessions.Delete(session.ID) + http.Error(w, "Initialize failed: no response", http.StatusInternalServerError) + return + } + + // If initialize failed, clean up session + if resp.Error != nil { + t.sessions.Delete(session.ID) + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", session.ID) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) +} + +// handleGet handles SSE stream for server-initiated notifications. +func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) { + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + http.Error(w, "Session ID required", http.StatusBadRequest) + return + } + + session := t.sessions.Get(sessionID) + if session == nil { + http.Error(w, "Invalid or expired session", http.StatusNotFound) + return + } + + // Check if client accepts SSE + accept := r.Header.Get("Accept") + if !strings.Contains(accept, "text/event-stream") { + http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable) + return + } + + // Set SSE headers + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + // Flush headers + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + flusher.Flush() + + // Stream notifications + ctx := r.Context() + for { + select { + case <-ctx.Done(): + return + case notification, ok := <-session.Notifications(): + if !ok { + // Session closed + return + } + + data, err := json.Marshal(notification) + if err != nil { + t.server.logger.Printf("Failed to marshal notification: %v", err) + continue + } + + // Write SSE event + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + + // Touch session to keep it alive + session.Touch() + } + } +} + +// handleDelete terminates a session. +func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) { + sessionID := r.Header.Get("Mcp-Session-Id") + if sessionID == "" { + http.Error(w, "Session ID required", http.StatusBadRequest) + return + } + + if t.sessions.Delete(sessionID) { + w.WriteHeader(http.StatusNoContent) + } else { + http.Error(w, "Session not found", http.StatusNotFound) + } +} + +// handleOptions handles CORS preflight requests. +func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" && t.isOriginAllowed(r) { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id") + w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id") + w.Header().Set("Access-Control-Max-Age", "86400") + } + w.WriteHeader(http.StatusNoContent) +} + +// isOriginAllowed checks if the request origin is allowed. +func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool { + origin := r.Header.Get("Origin") + + // No Origin header (same-origin request) is always allowed + if origin == "" { + return true + } + + // If no allowed origins configured, only allow localhost + if len(t.config.AllowedOrigins) == 0 { + return isLocalhostOrigin(origin) + } + + // Check against allowed origins + for _, allowed := range t.config.AllowedOrigins { + if allowed == "*" || allowed == origin { + return true + } + } + + return false +} + +// isLocalhostOrigin checks if the origin is a localhost address. +func isLocalhostOrigin(origin string) bool { + origin = strings.ToLower(origin) + + // Check for localhost patterns (must be followed by :, /, or end of string) + localhostPatterns := []string{ + "http://localhost", + "https://localhost", + "http://127.0.0.1", + "https://127.0.0.1", + "http://[::1]", + "https://[::1]", + } + + for _, pattern := range localhostPatterns { + if origin == pattern { + return true + } + if strings.HasPrefix(origin, pattern+":") || strings.HasPrefix(origin, pattern+"/") { + return true + } + } + + return false +} diff --git a/internal/mcp/transport_http_test.go b/internal/mcp/transport_http_test.go new file mode 100644 index 0000000..06f4655 --- /dev/null +++ b/internal/mcp/transport_http_test.go @@ -0,0 +1,513 @@ +package mcp + +import ( + "bytes" + "encoding/json" + "io" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// testHTTPTransport creates a transport with a test server +func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) { + // Use a mock store + server := NewServer(nil, log.New(io.Discard, "", 0)) + + if config.SessionTTL == 0 { + config.SessionTTL = 30 * time.Minute + } + transport := NewHTTPTransport(server, config) + + // Create test server + mux := http.NewServeMux() + endpoint := config.Endpoint + if endpoint == "" { + endpoint = "/mcp" + } + mux.HandleFunc(endpoint, transport.handleMCP) + + ts := httptest.NewServer(mux) + t.Cleanup(func() { + ts.Close() + transport.sessions.Stop() + }) + + return transport, ts +} + +func TestHTTPTransportInitialize(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{}) + + // Send initialize request + initReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodInitialize, + Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`), + } + body, _ := json.Marshal(initReq) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200, got %d", resp.StatusCode) + } + + // Check session ID header + sessionID := resp.Header.Get("Mcp-Session-Id") + if sessionID == "" { + t.Error("Expected Mcp-Session-Id header") + } + if len(sessionID) != 32 { + t.Errorf("Session ID should be 32 chars, got %d", len(sessionID)) + } + + // Check response body + var initResp Response + if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if initResp.Error != nil { + t.Errorf("Initialize failed: %v", initResp.Error) + } +} + +func TestHTTPTransportSessionRequired(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{}) + + // Send tools/list without session + listReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodToolsList, + } + body, _ := json.Marshal(listReq) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected 400 without session, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportInvalidSession(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{}) + + // Send request with invalid session + listReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodToolsList, + } + body, _ := json.Marshal(listReq) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", "invalid-session-id") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportValidSession(t *testing.T) { + transport, ts := testHTTPTransport(t, HTTPConfig{}) + + // Create session manually + session, _ := transport.sessions.Create() + + // Send tools/list with valid session + listReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodToolsList, + } + body, _ := json.Marshal(listReq) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", session.ID) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportNotificationAccepted(t *testing.T) { + transport, ts := testHTTPTransport(t, HTTPConfig{}) + + session, _ := transport.sessions.Create() + + // Send notification (no ID) + notification := Request{ + JSONRPC: "2.0", + Method: MethodInitialized, + } + body, _ := json.Marshal(notification) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", session.ID) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected 202 for notification, got %d", resp.StatusCode) + } + + // Verify session is marked as initialized + if !session.IsInitialized() { + t.Error("Session should be marked as initialized") + } +} + +func TestHTTPTransportDeleteSession(t *testing.T) { + transport, ts := testHTTPTransport(t, HTTPConfig{}) + + session, _ := transport.sessions.Create() + + // Delete session + req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil) + req.Header.Set("Mcp-Session-Id", session.ID) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + t.Errorf("Expected 204 for delete, got %d", resp.StatusCode) + } + + // Verify session is gone + if transport.sessions.Get(session.ID) != nil { + t.Error("Session should be deleted") + } +} + +func TestHTTPTransportDeleteNonexistentSession(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{}) + + req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil) + req.Header.Set("Mcp-Session-Id", "nonexistent") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportOriginValidation(t *testing.T) { + tests := []struct { + name string + allowedOrigins []string + origin string + expectAllowed bool + }{ + { + name: "no origin header", + allowedOrigins: nil, + origin: "", + expectAllowed: true, + }, + { + name: "localhost allowed by default", + allowedOrigins: nil, + origin: "http://localhost:3000", + expectAllowed: true, + }, + { + name: "127.0.0.1 allowed by default", + allowedOrigins: nil, + origin: "http://127.0.0.1:8080", + expectAllowed: true, + }, + { + name: "external origin blocked by default", + allowedOrigins: nil, + origin: "http://evil.com", + expectAllowed: false, + }, + { + name: "explicit allow", + allowedOrigins: []string{"http://example.com"}, + origin: "http://example.com", + expectAllowed: true, + }, + { + name: "explicit allow wildcard", + allowedOrigins: []string{"*"}, + origin: "http://anything.com", + expectAllowed: true, + }, + { + name: "not in allowed list", + allowedOrigins: []string{"http://example.com"}, + origin: "http://other.com", + expectAllowed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + AllowedOrigins: tt.allowedOrigins, + }) + + // Use initialize since it doesn't require a session + initReq := Request{ + JSONRPC: "2.0", + ID: 1, + Method: MethodInitialize, + Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`), + } + body, _ := json.Marshal(initReq) + + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + if tt.origin != "" { + req.Header.Set("Origin", tt.origin) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if tt.expectAllowed && resp.StatusCode == http.StatusForbidden { + t.Error("Expected request to be allowed but was forbidden") + } + if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden { + t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode) + } + }) + } +} + +func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) { + transport, ts := testHTTPTransport(t, HTTPConfig{}) + + session, _ := transport.sessions.Create() + + // GET without Accept: text/event-stream + req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil) + req.Header.Set("Mcp-Session-Id", session.ID) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotAcceptable { + t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportSSEStream(t *testing.T) { + transport, ts := testHTTPTransport(t, HTTPConfig{}) + + session, _ := transport.sessions.Create() + + // Start SSE stream in goroutine + req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil) + req.Header.Set("Mcp-Session-Id", session.ID) + req.Header.Set("Accept", "text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected 200, got %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + t.Errorf("Expected Content-Type text/event-stream, got %s", contentType) + } + + // Send a notification + notification := &Response{ + JSONRPC: "2.0", + ID: 42, + Result: map[string]string{"test": "data"}, + } + session.SendNotification(notification) + + // Read the SSE event + buf := make([]byte, 1024) + n, err := resp.Body.Read(buf) + if err != nil && err != io.EOF { + t.Fatalf("Failed to read SSE event: %v", err) + } + + data := string(buf[:n]) + if !strings.HasPrefix(data, "data: ") { + t.Errorf("Expected SSE data event, got: %s", data) + } + + // Parse the JSON from the SSE event + jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ") + var received Response + if err := json.Unmarshal([]byte(jsonData), &received); err != nil { + t.Fatalf("Failed to parse notification JSON: %v", err) + } + + // JSON unmarshal converts numbers to float64, so compare as float64 + receivedID, ok := received.ID.(float64) + if !ok { + t.Fatalf("Expected numeric ID, got %T", received.ID) + } + if int(receivedID) != 42 { + t.Errorf("Expected notification ID 42, got %v", receivedID) + } +} + +func TestHTTPTransportParseError(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{}) + + // Send invalid JSON + req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json"))) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode) + } + + var jsonResp Response + if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + if jsonResp.Error == nil { + t.Error("Expected JSON-RPC error for parse error") + } + if jsonResp.Error != nil && jsonResp.Error.Code != ParseError { + t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code) + } +} + +func TestHTTPTransportMethodNotAllowed(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{}) + + req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Expected 405, got %d", resp.StatusCode) + } +} + +func TestHTTPTransportOptionsRequest(t *testing.T) { + _, ts := testHTTPTransport(t, HTTPConfig{ + AllowedOrigins: []string{"http://example.com"}, + }) + + req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil) + req.Header.Set("Origin", "http://example.com") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNoContent { + t.Errorf("Expected 204, got %d", resp.StatusCode) + } + + if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" { + t.Error("Expected CORS origin header") + } + if resp.Header.Get("Access-Control-Allow-Methods") == "" { + t.Error("Expected CORS methods header") + } +} + +func TestIsLocalhostOrigin(t *testing.T) { + tests := []struct { + origin string + expected bool + }{ + {"http://localhost", true}, + {"http://localhost:3000", true}, + {"https://localhost", true}, + {"https://localhost:8443", true}, + {"http://127.0.0.1", true}, + {"http://127.0.0.1:8080", true}, + {"https://127.0.0.1", true}, + {"http://[::1]", true}, + {"http://[::1]:8080", true}, + {"https://[::1]", true}, + {"http://example.com", false}, + {"https://example.com", false}, + {"http://localhost.evil.com", false}, + {"http://192.168.1.1", false}, + } + + for _, tt := range tests { + t.Run(tt.origin, func(t *testing.T) { + result := isLocalhostOrigin(tt.origin) + if result != tt.expected { + t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected) + } + }) + } +} diff --git a/internal/mcp/transport_stdio.go b/internal/mcp/transport_stdio.go new file mode 100644 index 0000000..20809e6 --- /dev/null +++ b/internal/mcp/transport_stdio.go @@ -0,0 +1,63 @@ +package mcp + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" +) + +// StdioTransport implements the MCP protocol over STDIO using line-delimited JSON-RPC. +type StdioTransport struct { + server *Server + reader io.Reader + writer io.Writer +} + +// NewStdioTransport creates a new STDIO transport. +func NewStdioTransport(server *Server, r io.Reader, w io.Writer) *StdioTransport { + return &StdioTransport{ + server: server, + reader: r, + writer: w, + } +} + +// Run starts the STDIO transport, reading line-delimited JSON-RPC from the reader +// and writing responses to the writer. +func (t *StdioTransport) Run(ctx context.Context) error { + scanner := bufio.NewScanner(t.reader) + encoder := json.NewEncoder(t.writer) + + for scanner.Scan() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + resp, err := t.server.HandleMessage(ctx, line) + if err != nil { + t.server.logger.Printf("Error handling message: %v", err) + continue + } + + if resp != nil { + if err := encoder.Encode(resp); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("scanner error: %w", err) + } + + return nil +} diff --git a/nix/module.nix b/nix/module.nix index 0e98369..be650ed 100644 --- a/nix/module.nix +++ b/nix/module.nix @@ -89,13 +89,56 @@ in ''; }; + http = { + address = lib.mkOption { + type = lib.types.str; + default = "127.0.0.1:8080"; + description = "HTTP listen address for the MCP server."; + }; + + endpoint = lib.mkOption { + type = lib.types.str; + default = "/mcp"; + description = "HTTP endpoint path for MCP requests."; + }; + + allowedOrigins = lib.mkOption { + type = lib.types.listOf lib.types.str; + default = [ ]; + example = [ "http://localhost:3000" "https://example.com" ]; + description = '' + Allowed Origin headers for CORS. + Empty list means only localhost origins are allowed. + ''; + }; + + sessionTTL = lib.mkOption { + type = lib.types.str; + default = "30m"; + description = "Session TTL for HTTP transport (Go duration format)."; + }; + + tls = { + enable = lib.mkEnableOption "TLS for HTTP transport"; + + certFile = lib.mkOption { + type = lib.types.nullOr lib.types.path; + default = null; + description = "Path to TLS certificate file."; + }; + + keyFile = lib.mkOption { + type = lib.types.nullOr lib.types.path; + default = null; + description = "Path to TLS private key file."; + }; + }; + }; + openFirewall = lib.mkOption { type = lib.types.bool; default = false; - description = '' - Whether to open the firewall for the MCP server. - Note: MCP typically runs over stdio, so this is usually not needed. - ''; + description = "Whether to open the firewall for the MCP HTTP server."; }; }; @@ -111,6 +154,10 @@ in assertion = cfg.database.connectionString == "" || cfg.database.connectionStringFile == null; message = "services.nixos-options-mcp.database: connectionString and connectionStringFile are mutually exclusive"; } + { + assertion = !cfg.http.tls.enable || (cfg.http.tls.certFile != null && cfg.http.tls.keyFile != null); + message = "services.nixos-options-mcp.http.tls: both certFile and keyFile must be set when TLS is enabled"; + } ]; users.users.${cfg.user} = lib.mkIf (cfg.user == "nixos-options-mcp") { @@ -145,6 +192,19 @@ in nixos-options index "${rev}" || true '') cfg.indexOnStart} ''; + + # Build HTTP transport flags + httpFlags = lib.concatStringsSep " " ([ + "--transport http" + "--http-address '${cfg.http.address}'" + "--http-endpoint '${cfg.http.endpoint}'" + "--session-ttl '${cfg.http.sessionTTL}'" + ] ++ lib.optionals (cfg.http.allowedOrigins != []) ( + map (origin: "--allowed-origins '${origin}'") cfg.http.allowedOrigins + ) ++ lib.optionals cfg.http.tls.enable [ + "--tls-cert '${cfg.http.tls.certFile}'" + "--tls-key '${cfg.http.tls.keyFile}'" + ]); in if useConnectionStringFile then '' # Read database connection string from file @@ -155,10 +215,10 @@ in export NIXOS_OPTIONS_DATABASE="$(cat "${cfg.database.connectionStringFile}")" ${indexCommands} - exec nixos-options serve + exec nixos-options serve ${httpFlags} '' else '' ${indexCommands} - exec nixos-options serve + exec nixos-options serve ${httpFlags} ''; serviceConfig = { @@ -188,5 +248,14 @@ in StateDirectory = lib.mkIf (cfg.dataDir == "/var/lib/nixos-options-mcp") "nixos-options-mcp"; }; }; + + # Open firewall for HTTP port if configured + networking.firewall = lib.mkIf cfg.openFirewall (let + # Extract port from address (format: "host:port" or ":port") + addressParts = lib.splitString ":" cfg.http.address; + port = lib.toInt (lib.last addressParts); + in { + allowedTCPPorts = [ port ]; + }); }; }