feat: add Streamable HTTP transport support

Add support for running the MCP server over HTTP with Server-Sent Events
(SSE) using the MCP Streamable HTTP specification, alongside the existing
STDIO transport.

New features:
- Transport abstraction with Transport interface
- HTTP transport with session management
- SSE support for server-initiated notifications
- CORS security with configurable allowed origins
- Optional TLS support
- CLI flags for HTTP configuration (--transport, --http-address, etc.)
- NixOS module options for HTTP transport

The HTTP transport implements:
- POST /mcp: JSON-RPC requests with session management
- GET /mcp: SSE stream for server notifications
- DELETE /mcp: Session termination
- Origin validation (localhost-only by default)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-03 22:02:40 +01:00
parent 0b7333844a
commit cbe55d6456
9 changed files with 1575 additions and 54 deletions

View File

@@ -5,7 +5,10 @@ import (
"fmt"
"log"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/urfave/cli/v2"
@@ -36,7 +39,42 @@ func main() {
Commands: []*cli.Command{
{
Name: "serve",
Usage: "Run MCP server (stdio)",
Usage: "Run MCP server",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "transport",
Aliases: []string{"t"},
Usage: "Transport type: 'stdio' or 'http'",
Value: "stdio",
},
&cli.StringFlag{
Name: "http-address",
Usage: "HTTP listen address",
Value: "127.0.0.1:8080",
},
&cli.StringFlag{
Name: "http-endpoint",
Usage: "HTTP endpoint path",
Value: "/mcp",
},
&cli.StringSliceFlag{
Name: "allowed-origins",
Usage: "Allowed Origin headers for CORS (can be specified multiple times)",
},
&cli.StringFlag{
Name: "tls-cert",
Usage: "TLS certificate file",
},
&cli.StringFlag{
Name: "tls-key",
Usage: "TLS key file",
},
&cli.DurationFlag{
Name: "session-ttl",
Usage: "Session TTL for HTTP transport",
Value: 30 * time.Minute,
},
},
Action: func(c *cli.Context) error {
return runServe(c)
},
@@ -145,7 +183,8 @@ func openStore(connStr string) (database.Store, error) {
}
func runServe(c *cli.Context) error {
ctx := context.Background()
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
store, err := openStore(c.String("database"))
if err != nil {
@@ -163,8 +202,27 @@ func runServe(c *cli.Context) error {
indexer := nixos.NewIndexer(store)
server.RegisterHandlers(indexer)
transport := c.String("transport")
switch transport {
case "stdio":
logger.Println("Starting MCP server on stdio...")
return server.Run(ctx, os.Stdin, os.Stdout)
case "http":
config := mcp.HTTPConfig{
Address: c.String("http-address"),
Endpoint: c.String("http-endpoint"),
AllowedOrigins: c.StringSlice("allowed-origins"),
SessionTTL: c.Duration("session-ttl"),
TLSCertFile: c.String("tls-cert"),
TLSKeyFile: c.String("tls-key"),
}
httpTransport := mcp.NewHTTPTransport(server, config)
return httpTransport.Run(ctx)
default:
return fmt.Errorf("unknown transport: %s (use 'stdio' or 'http')", transport)
}
}
func runIndex(c *cli.Context, revision string, indexFiles bool, force bool) error {

View File

@@ -1,7 +1,6 @@
package mcp
import (
"bufio"
"context"
"encoding/json"
"fmt"
@@ -11,7 +10,7 @@ import (
"git.t-juice.club/torjus/labmcp/internal/database"
)
// Server is an MCP server that handles JSON-RPC requests over stdio.
// Server is an MCP server that handles JSON-RPC requests.
type Server struct {
store database.Store
tools map[string]ToolHandler
@@ -41,53 +40,34 @@ func (s *Server) registerTools() {
// Tools will be implemented in handlers.go
}
// Run starts the server, reading from r and writing to w.
// Run starts the server using STDIO transport (backward compatibility).
func (s *Server) Run(ctx context.Context, r io.Reader, w io.Writer) error {
scanner := bufio.NewScanner(r)
encoder := json.NewEncoder(w)
for scanner.Scan() {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
line := scanner.Bytes()
if len(line) == 0 {
continue
transport := NewStdioTransport(s, r, w)
return transport.Run(ctx)
}
// HandleMessage parses a JSON-RPC message and returns the response.
// Returns (nil, nil) for notifications that don't require a response.
func (s *Server) HandleMessage(ctx context.Context, data []byte) (*Response, error) {
var req Request
if err := json.Unmarshal(line, &req); err != nil {
s.logger.Printf("Failed to parse request: %v", err)
resp := Response{
if err := json.Unmarshal(data, &req); err != nil {
return &Response{
JSONRPC: "2.0",
Error: &Error{
Code: ParseError,
Message: "Parse error",
Data: err.Error(),
},
}
if err := encoder.Encode(resp); err != nil {
return fmt.Errorf("failed to write response: %w", err)
}
continue
}, nil
}
resp := s.handleRequest(ctx, &req)
if resp != nil {
if err := encoder.Encode(resp); err != nil {
return fmt.Errorf("failed to write response: %w", err)
}
}
return s.HandleRequest(ctx, &req), nil
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error: %w", err)
}
return nil
// HandleRequest processes a single request and returns a response.
// Returns nil for notifications that don't require a response.
func (s *Server) HandleRequest(ctx context.Context, req *Request) *Response {
return s.handleRequest(ctx, req)
}
// handleRequest processes a single request and returns a response.

207
internal/mcp/session.go Normal file
View File

@@ -0,0 +1,207 @@
package mcp
import (
"crypto/rand"
"encoding/hex"
"sync"
"time"
)
// Session represents an MCP client session.
type Session struct {
ID string
CreatedAt time.Time
LastActivity time.Time
Initialized bool
// notifications is a channel for server-initiated notifications.
notifications chan *Response
mu sync.RWMutex
}
// NewSession creates a new session with a cryptographically secure random ID.
func NewSession() (*Session, error) {
id, err := generateSessionID()
if err != nil {
return nil, err
}
now := time.Now()
return &Session{
ID: id,
CreatedAt: now,
LastActivity: now,
notifications: make(chan *Response, 100),
}, nil
}
// Touch updates the session's last activity time.
func (s *Session) Touch() {
s.mu.Lock()
defer s.mu.Unlock()
s.LastActivity = time.Now()
}
// SetInitialized marks the session as initialized.
func (s *Session) SetInitialized() {
s.mu.Lock()
defer s.mu.Unlock()
s.Initialized = true
}
// IsInitialized returns whether the session has been initialized.
func (s *Session) IsInitialized() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Initialized
}
// Notifications returns the channel for server-initiated notifications.
func (s *Session) Notifications() <-chan *Response {
return s.notifications
}
// SendNotification sends a notification to the session.
// Returns false if the channel is full.
func (s *Session) SendNotification(notification *Response) bool {
select {
case s.notifications <- notification:
return true
default:
return false
}
}
// Close closes the session's notification channel.
func (s *Session) Close() {
close(s.notifications)
}
// SessionStore manages active sessions with TTL-based cleanup.
type SessionStore struct {
sessions map[string]*Session
ttl time.Duration
mu sync.RWMutex
stopClean chan struct{}
cleanDone chan struct{}
}
// NewSessionStore creates a new session store with the given TTL.
func NewSessionStore(ttl time.Duration) *SessionStore {
s := &SessionStore{
sessions: make(map[string]*Session),
ttl: ttl,
stopClean: make(chan struct{}),
cleanDone: make(chan struct{}),
}
go s.cleanupLoop()
return s
}
// Create creates a new session and adds it to the store.
func (s *SessionStore) Create() (*Session, error) {
session, err := NewSession()
if err != nil {
return nil, err
}
s.mu.Lock()
defer s.mu.Unlock()
s.sessions[session.ID] = session
return session, nil
}
// Get retrieves a session by ID. Returns nil if not found or expired.
func (s *SessionStore) Get(id string) *Session {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[id]
if !ok {
return nil
}
// Check if expired
session.mu.RLock()
expired := time.Since(session.LastActivity) > s.ttl
session.mu.RUnlock()
if expired {
return nil
}
return session
}
// Delete removes a session from the store.
func (s *SessionStore) Delete(id string) bool {
s.mu.Lock()
defer s.mu.Unlock()
session, ok := s.sessions[id]
if !ok {
return false
}
session.Close()
delete(s.sessions, id)
return true
}
// Count returns the number of active sessions.
func (s *SessionStore) Count() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.sessions)
}
// Stop stops the cleanup goroutine and waits for it to finish.
func (s *SessionStore) Stop() {
close(s.stopClean)
<-s.cleanDone
}
// cleanupLoop periodically removes expired sessions.
func (s *SessionStore) cleanupLoop() {
defer close(s.cleanDone)
ticker := time.NewTicker(s.ttl / 2)
defer ticker.Stop()
for {
select {
case <-s.stopClean:
return
case <-ticker.C:
s.cleanup()
}
}
}
// cleanup removes expired sessions.
func (s *SessionStore) cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for id, session := range s.sessions {
session.mu.RLock()
expired := now.Sub(session.LastActivity) > s.ttl
session.mu.RUnlock()
if expired {
session.Close()
delete(s.sessions, id)
}
}
}
// generateSessionID generates a cryptographically secure random session ID.
func generateSessionID() (string, error) {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -0,0 +1,267 @@
package mcp
import (
"sync"
"testing"
"time"
)
func TestNewSession(t *testing.T) {
session, err := NewSession()
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
if session.ID == "" {
t.Error("Session ID should not be empty")
}
if len(session.ID) != 32 {
t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID))
}
if session.Initialized {
t.Error("New session should not be initialized")
}
}
func TestSessionTouch(t *testing.T) {
session, _ := NewSession()
originalActivity := session.LastActivity
time.Sleep(10 * time.Millisecond)
session.Touch()
if !session.LastActivity.After(originalActivity) {
t.Error("Touch should update LastActivity")
}
}
func TestSessionInitialized(t *testing.T) {
session, _ := NewSession()
if session.IsInitialized() {
t.Error("New session should not be initialized")
}
session.SetInitialized()
if !session.IsInitialized() {
t.Error("Session should be initialized after SetInitialized")
}
}
func TestSessionNotifications(t *testing.T) {
session, _ := NewSession()
defer session.Close()
notification := &Response{JSONRPC: "2.0", ID: 1}
if !session.SendNotification(notification) {
t.Error("SendNotification should return true on success")
}
select {
case received := <-session.Notifications():
if received.ID != notification.ID {
t.Error("Received notification should match sent")
}
case <-time.After(100 * time.Millisecond):
t.Error("Should receive notification")
}
}
func TestSessionStoreCreate(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
session, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
if store.Count() != 1 {
t.Errorf("Store should have 1 session, got %d", store.Count())
}
// Verify we can retrieve it
retrieved := store.Get(session.ID)
if retrieved == nil {
t.Error("Should be able to retrieve created session")
}
if retrieved.ID != session.ID {
t.Error("Retrieved session ID should match")
}
}
func TestSessionStoreGet(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
// Get non-existent session
if store.Get("nonexistent") != nil {
t.Error("Should return nil for non-existent session")
}
// Create and retrieve
session, _ := store.Create()
retrieved := store.Get(session.ID)
if retrieved == nil {
t.Error("Should find created session")
}
}
func TestSessionStoreDelete(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
session, _ := store.Create()
if store.Count() != 1 {
t.Error("Should have 1 session after create")
}
if !store.Delete(session.ID) {
t.Error("Delete should return true for existing session")
}
if store.Count() != 0 {
t.Error("Should have 0 sessions after delete")
}
if store.Delete(session.ID) {
t.Error("Delete should return false for non-existent session")
}
}
func TestSessionStoreTTLExpiration(t *testing.T) {
ttl := 50 * time.Millisecond
store := NewSessionStore(ttl)
defer store.Stop()
session, _ := store.Create()
// Should be retrievable immediately
if store.Get(session.ID) == nil {
t.Error("Session should be retrievable immediately")
}
// Wait for expiration
time.Sleep(ttl + 10*time.Millisecond)
// Should not be retrievable after TTL
if store.Get(session.ID) != nil {
t.Error("Expired session should not be retrievable")
}
}
func TestSessionStoreTTLRefresh(t *testing.T) {
ttl := 100 * time.Millisecond
store := NewSessionStore(ttl)
defer store.Stop()
session, _ := store.Create()
// Touch the session before TTL expires
time.Sleep(60 * time.Millisecond)
session.Touch()
// Wait past original TTL but not past refreshed TTL
time.Sleep(60 * time.Millisecond)
// Should still be retrievable because we touched it
if store.Get(session.ID) == nil {
t.Error("Touched session should still be retrievable")
}
}
func TestSessionStoreCleanup(t *testing.T) {
ttl := 50 * time.Millisecond
store := NewSessionStore(ttl)
defer store.Stop()
// Create multiple sessions
for i := 0; i < 5; i++ {
store.Create()
}
if store.Count() != 5 {
t.Errorf("Should have 5 sessions, got %d", store.Count())
}
// Wait for cleanup to run (runs at ttl/2 intervals)
time.Sleep(ttl + ttl/2 + 10*time.Millisecond)
// All sessions should be cleaned up
if store.Count() != 0 {
t.Errorf("All sessions should be cleaned up, got %d", store.Count())
}
}
func TestSessionStoreConcurrency(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
var wg sync.WaitGroup
sessionIDs := make(chan string, 100)
// Create sessions concurrently
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
session, err := store.Create()
if err != nil {
t.Errorf("Failed to create session: %v", err)
return
}
sessionIDs <- session.ID
}()
}
wg.Wait()
close(sessionIDs)
// Verify all sessions were created
if store.Count() != 50 {
t.Errorf("Should have 50 sessions, got %d", store.Count())
}
// Read and delete concurrently
var ids []string
for id := range sessionIDs {
ids = append(ids, id)
}
for _, id := range ids {
wg.Add(2)
go func(id string) {
defer wg.Done()
store.Get(id)
}(id)
go func(id string) {
defer wg.Done()
store.Delete(id)
}(id)
}
wg.Wait()
}
func TestGenerateSessionID(t *testing.T) {
ids := make(map[string]bool)
// Generate 1000 IDs and ensure uniqueness
for i := 0; i < 1000; i++ {
id, err := generateSessionID()
if err != nil {
t.Fatalf("Failed to generate session ID: %v", err)
}
if len(id) != 32 {
t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
}
if ids[id] {
t.Error("Generated duplicate session ID")
}
ids[id] = true
}
}

10
internal/mcp/transport.go Normal file
View File

@@ -0,0 +1,10 @@
package mcp
import "context"
// Transport defines the interface for MCP server transports.
type Transport interface {
// Run starts the transport and blocks until the context is cancelled
// or an error occurs.
Run(ctx context.Context) error
}

View File

@@ -0,0 +1,354 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
)
// HTTPConfig configures the HTTP transport.
type HTTPConfig struct {
Address string // Listen address (e.g., "127.0.0.1:8080")
Endpoint string // MCP endpoint path (e.g., "/mcp")
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
SessionTTL time.Duration // Session TTL (default: 30 minutes)
TLSCertFile string // TLS certificate file (optional)
TLSKeyFile string // TLS key file (optional)
}
// HTTPTransport implements the MCP Streamable HTTP transport.
type HTTPTransport struct {
server *Server
config HTTPConfig
sessions *SessionStore
}
// NewHTTPTransport creates a new HTTP transport.
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
if config.Address == "" {
config.Address = "127.0.0.1:8080"
}
if config.Endpoint == "" {
config.Endpoint = "/mcp"
}
if config.SessionTTL == 0 {
config.SessionTTL = 30 * time.Minute
}
return &HTTPTransport{
server: server,
config: config,
sessions: NewSessionStore(config.SessionTTL),
}
}
// Run starts the HTTP server and blocks until the context is cancelled.
func (t *HTTPTransport) Run(ctx context.Context) error {
mux := http.NewServeMux()
mux.HandleFunc(t.config.Endpoint, t.handleMCP)
httpServer := &http.Server{
Addr: t.config.Address,
Handler: mux,
BaseContext: func(l net.Listener) context.Context {
return ctx
},
}
// Graceful shutdown on context cancellation
go func() {
<-ctx.Done()
t.server.logger.Println("Shutting down HTTP server...")
t.sessions.Stop()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
t.server.logger.Printf("HTTP server shutdown error: %v", err)
}
}()
t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
var err error
if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
} else {
err = httpServer.ListenAndServe()
}
if err == http.ErrServerClosed {
return nil
}
return err
}
// handleMCP routes requests based on HTTP method.
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
// Validate Origin header
if !t.isOriginAllowed(r) {
http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
return
}
switch r.Method {
case http.MethodPost:
t.handlePost(w, r)
case http.MethodGet:
t.handleGet(w, r)
case http.MethodDelete:
t.handleDelete(w, r)
case http.MethodOptions:
t.handleOptions(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
// handlePost handles JSON-RPC requests.
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
// Read request body
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
// Parse request to check method
var req Request
if err := json.Unmarshal(body, &req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{
JSONRPC: "2.0",
Error: &Error{
Code: ParseError,
Message: "Parse error",
Data: err.Error(),
},
})
return
}
// Handle initialize request - create session
if req.Method == MethodInitialize {
t.handleInitialize(w, r, &req)
return
}
// All other requests require a valid session
sessionID := r.Header.Get("Mcp-Session-Id")
if sessionID == "" {
http.Error(w, "Session ID required", http.StatusBadRequest)
return
}
session := t.sessions.Get(sessionID)
if session == nil {
http.Error(w, "Invalid or expired session", http.StatusNotFound)
return
}
// Update session activity
session.Touch()
// Handle notifications (no response expected)
if req.Method == MethodInitialized {
session.SetInitialized()
w.WriteHeader(http.StatusAccepted)
return
}
// Check if this is a notification (no ID)
if req.ID == nil {
w.WriteHeader(http.StatusAccepted)
return
}
// Process the request
resp := t.server.HandleRequest(r.Context(), &req)
if resp == nil {
w.WriteHeader(http.StatusAccepted)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
// handleInitialize handles the initialize request and creates a new session.
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
// Create a new session
session, err := t.sessions.Create()
if err != nil {
t.server.logger.Printf("Failed to create session: %v", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
// Process initialize request
resp := t.server.HandleRequest(r.Context(), req)
if resp == nil {
t.sessions.Delete(session.ID)
http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
return
}
// If initialize failed, clean up session
if resp.Error != nil {
t.sessions.Delete(session.ID)
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Mcp-Session-Id", session.ID)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
// handleGet handles SSE stream for server-initiated notifications.
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
sessionID := r.Header.Get("Mcp-Session-Id")
if sessionID == "" {
http.Error(w, "Session ID required", http.StatusBadRequest)
return
}
session := t.sessions.Get(sessionID)
if session == nil {
http.Error(w, "Invalid or expired session", http.StatusNotFound)
return
}
// Check if client accepts SSE
accept := r.Header.Get("Accept")
if !strings.Contains(accept, "text/event-stream") {
http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
return
}
// Set SSE headers
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
// Flush headers
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
return
}
flusher.Flush()
// Stream notifications
ctx := r.Context()
for {
select {
case <-ctx.Done():
return
case notification, ok := <-session.Notifications():
if !ok {
// Session closed
return
}
data, err := json.Marshal(notification)
if err != nil {
t.server.logger.Printf("Failed to marshal notification: %v", err)
continue
}
// Write SSE event
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
// Touch session to keep it alive
session.Touch()
}
}
}
// handleDelete terminates a session.
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
sessionID := r.Header.Get("Mcp-Session-Id")
if sessionID == "" {
http.Error(w, "Session ID required", http.StatusBadRequest)
return
}
if t.sessions.Delete(sessionID) {
w.WriteHeader(http.StatusNoContent)
} else {
http.Error(w, "Session not found", http.StatusNotFound)
}
}
// handleOptions handles CORS preflight requests.
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin != "" && t.isOriginAllowed(r) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
w.Header().Set("Access-Control-Max-Age", "86400")
}
w.WriteHeader(http.StatusNoContent)
}
// isOriginAllowed checks if the request origin is allowed.
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
origin := r.Header.Get("Origin")
// No Origin header (same-origin request) is always allowed
if origin == "" {
return true
}
// If no allowed origins configured, only allow localhost
if len(t.config.AllowedOrigins) == 0 {
return isLocalhostOrigin(origin)
}
// Check against allowed origins
for _, allowed := range t.config.AllowedOrigins {
if allowed == "*" || allowed == origin {
return true
}
}
return false
}
// isLocalhostOrigin checks if the origin is a localhost address.
func isLocalhostOrigin(origin string) bool {
origin = strings.ToLower(origin)
// Check for localhost patterns (must be followed by :, /, or end of string)
localhostPatterns := []string{
"http://localhost",
"https://localhost",
"http://127.0.0.1",
"https://127.0.0.1",
"http://[::1]",
"https://[::1]",
}
for _, pattern := range localhostPatterns {
if origin == pattern {
return true
}
if strings.HasPrefix(origin, pattern+":") || strings.HasPrefix(origin, pattern+"/") {
return true
}
}
return false
}

View File

@@ -0,0 +1,513 @@
package mcp
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// testHTTPTransport creates a transport with a test server
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
// Use a mock store
server := NewServer(nil, log.New(io.Discard, "", 0))
if config.SessionTTL == 0 {
config.SessionTTL = 30 * time.Minute
}
transport := NewHTTPTransport(server, config)
// Create test server
mux := http.NewServeMux()
endpoint := config.Endpoint
if endpoint == "" {
endpoint = "/mcp"
}
mux.HandleFunc(endpoint, transport.handleMCP)
ts := httptest.NewServer(mux)
t.Cleanup(func() {
ts.Close()
transport.sessions.Stop()
})
return transport, ts
}
func TestHTTPTransportInitialize(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send initialize request
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected 200, got %d", resp.StatusCode)
}
// Check session ID header
sessionID := resp.Header.Get("Mcp-Session-Id")
if sessionID == "" {
t.Error("Expected Mcp-Session-Id header")
}
if len(sessionID) != 32 {
t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
}
// Check response body
var initResp Response
if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if initResp.Error != nil {
t.Errorf("Initialize failed: %v", initResp.Error)
}
}
func TestHTTPTransportSessionRequired(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send tools/list without session
listReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodToolsList,
}
body, _ := json.Marshal(listReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportInvalidSession(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send request with invalid session
listReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodToolsList,
}
body, _ := json.Marshal(listReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Mcp-Session-Id", "invalid-session-id")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportValidSession(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
// Create session manually
session, _ := transport.sessions.Create()
// Send tools/list with valid session
listReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodToolsList,
}
body, _ := json.Marshal(listReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportNotificationAccepted(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// Send notification (no ID)
notification := Request{
JSONRPC: "2.0",
Method: MethodInitialized,
}
body, _ := json.Marshal(notification)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
}
// Verify session is marked as initialized
if !session.IsInitialized() {
t.Error("Session should be marked as initialized")
}
}
func TestHTTPTransportDeleteSession(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// Delete session
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
}
// Verify session is gone
if transport.sessions.Get(session.ID) != nil {
t.Error("Session should be deleted")
}
}
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", "nonexistent")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportOriginValidation(t *testing.T) {
tests := []struct {
name string
allowedOrigins []string
origin string
expectAllowed bool
}{
{
name: "no origin header",
allowedOrigins: nil,
origin: "",
expectAllowed: true,
},
{
name: "localhost allowed by default",
allowedOrigins: nil,
origin: "http://localhost:3000",
expectAllowed: true,
},
{
name: "127.0.0.1 allowed by default",
allowedOrigins: nil,
origin: "http://127.0.0.1:8080",
expectAllowed: true,
},
{
name: "external origin blocked by default",
allowedOrigins: nil,
origin: "http://evil.com",
expectAllowed: false,
},
{
name: "explicit allow",
allowedOrigins: []string{"http://example.com"},
origin: "http://example.com",
expectAllowed: true,
},
{
name: "explicit allow wildcard",
allowedOrigins: []string{"*"},
origin: "http://anything.com",
expectAllowed: true,
},
{
name: "not in allowed list",
allowedOrigins: []string{"http://example.com"},
origin: "http://other.com",
expectAllowed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
AllowedOrigins: tt.allowedOrigins,
})
// Use initialize since it doesn't require a session
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
t.Error("Expected request to be allowed but was forbidden")
}
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
}
})
}
}
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// GET without Accept: text/event-stream
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotAcceptable {
t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
}
}
func TestHTTPTransportSSEStream(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// Start SSE stream in goroutine
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", session.ID)
req.Header.Set("Accept", "text/event-stream")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected 200, got %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if contentType != "text/event-stream" {
t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
}
// Send a notification
notification := &Response{
JSONRPC: "2.0",
ID: 42,
Result: map[string]string{"test": "data"},
}
session.SendNotification(notification)
// Read the SSE event
buf := make([]byte, 1024)
n, err := resp.Body.Read(buf)
if err != nil && err != io.EOF {
t.Fatalf("Failed to read SSE event: %v", err)
}
data := string(buf[:n])
if !strings.HasPrefix(data, "data: ") {
t.Errorf("Expected SSE data event, got: %s", data)
}
// Parse the JSON from the SSE event
jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ")
var received Response
if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
t.Fatalf("Failed to parse notification JSON: %v", err)
}
// JSON unmarshal converts numbers to float64, so compare as float64
receivedID, ok := received.ID.(float64)
if !ok {
t.Fatalf("Expected numeric ID, got %T", received.ID)
}
if int(receivedID) != 42 {
t.Errorf("Expected notification ID 42, got %v", receivedID)
}
}
func TestHTTPTransportParseError(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send invalid JSON
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
}
var jsonResp Response
if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if jsonResp.Error == nil {
t.Error("Expected JSON-RPC error for parse error")
}
if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
}
}
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected 405, got %d", resp.StatusCode)
}
}
func TestHTTPTransportOptionsRequest(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
AllowedOrigins: []string{"http://example.com"},
})
req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
req.Header.Set("Origin", "http://example.com")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Errorf("Expected 204, got %d", resp.StatusCode)
}
if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Error("Expected CORS origin header")
}
if resp.Header.Get("Access-Control-Allow-Methods") == "" {
t.Error("Expected CORS methods header")
}
}
func TestIsLocalhostOrigin(t *testing.T) {
tests := []struct {
origin string
expected bool
}{
{"http://localhost", true},
{"http://localhost:3000", true},
{"https://localhost", true},
{"https://localhost:8443", true},
{"http://127.0.0.1", true},
{"http://127.0.0.1:8080", true},
{"https://127.0.0.1", true},
{"http://[::1]", true},
{"http://[::1]:8080", true},
{"https://[::1]", true},
{"http://example.com", false},
{"https://example.com", false},
{"http://localhost.evil.com", false},
{"http://192.168.1.1", false},
}
for _, tt := range tests {
t.Run(tt.origin, func(t *testing.T) {
result := isLocalhostOrigin(tt.origin)
if result != tt.expected {
t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,63 @@
package mcp
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
)
// StdioTransport implements the MCP protocol over STDIO using line-delimited JSON-RPC.
type StdioTransport struct {
server *Server
reader io.Reader
writer io.Writer
}
// NewStdioTransport creates a new STDIO transport.
func NewStdioTransport(server *Server, r io.Reader, w io.Writer) *StdioTransport {
return &StdioTransport{
server: server,
reader: r,
writer: w,
}
}
// Run starts the STDIO transport, reading line-delimited JSON-RPC from the reader
// and writing responses to the writer.
func (t *StdioTransport) Run(ctx context.Context) error {
scanner := bufio.NewScanner(t.reader)
encoder := json.NewEncoder(t.writer)
for scanner.Scan() {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
line := scanner.Bytes()
if len(line) == 0 {
continue
}
resp, err := t.server.HandleMessage(ctx, line)
if err != nil {
t.server.logger.Printf("Error handling message: %v", err)
continue
}
if resp != nil {
if err := encoder.Encode(resp); err != nil {
return fmt.Errorf("failed to write response: %w", err)
}
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error: %w", err)
}
return nil
}

View File

@@ -89,13 +89,56 @@ in
'';
};
http = {
address = lib.mkOption {
type = lib.types.str;
default = "127.0.0.1:8080";
description = "HTTP listen address for the MCP server.";
};
endpoint = lib.mkOption {
type = lib.types.str;
default = "/mcp";
description = "HTTP endpoint path for MCP requests.";
};
allowedOrigins = lib.mkOption {
type = lib.types.listOf lib.types.str;
default = [ ];
example = [ "http://localhost:3000" "https://example.com" ];
description = ''
Allowed Origin headers for CORS.
Empty list means only localhost origins are allowed.
'';
};
sessionTTL = lib.mkOption {
type = lib.types.str;
default = "30m";
description = "Session TTL for HTTP transport (Go duration format).";
};
tls = {
enable = lib.mkEnableOption "TLS for HTTP transport";
certFile = lib.mkOption {
type = lib.types.nullOr lib.types.path;
default = null;
description = "Path to TLS certificate file.";
};
keyFile = lib.mkOption {
type = lib.types.nullOr lib.types.path;
default = null;
description = "Path to TLS private key file.";
};
};
};
openFirewall = lib.mkOption {
type = lib.types.bool;
default = false;
description = ''
Whether to open the firewall for the MCP server.
Note: MCP typically runs over stdio, so this is usually not needed.
'';
description = "Whether to open the firewall for the MCP HTTP server.";
};
};
@@ -111,6 +154,10 @@ in
assertion = cfg.database.connectionString == "" || cfg.database.connectionStringFile == null;
message = "services.nixos-options-mcp.database: connectionString and connectionStringFile are mutually exclusive";
}
{
assertion = !cfg.http.tls.enable || (cfg.http.tls.certFile != null && cfg.http.tls.keyFile != null);
message = "services.nixos-options-mcp.http.tls: both certFile and keyFile must be set when TLS is enabled";
}
];
users.users.${cfg.user} = lib.mkIf (cfg.user == "nixos-options-mcp") {
@@ -145,6 +192,19 @@ in
nixos-options index "${rev}" || true
'') cfg.indexOnStart}
'';
# Build HTTP transport flags
httpFlags = lib.concatStringsSep " " ([
"--transport http"
"--http-address '${cfg.http.address}'"
"--http-endpoint '${cfg.http.endpoint}'"
"--session-ttl '${cfg.http.sessionTTL}'"
] ++ lib.optionals (cfg.http.allowedOrigins != []) (
map (origin: "--allowed-origins '${origin}'") cfg.http.allowedOrigins
) ++ lib.optionals cfg.http.tls.enable [
"--tls-cert '${cfg.http.tls.certFile}'"
"--tls-key '${cfg.http.tls.keyFile}'"
]);
in
if useConnectionStringFile then ''
# Read database connection string from file
@@ -155,10 +215,10 @@ in
export NIXOS_OPTIONS_DATABASE="$(cat "${cfg.database.connectionStringFile}")"
${indexCommands}
exec nixos-options serve
exec nixos-options serve ${httpFlags}
'' else ''
${indexCommands}
exec nixos-options serve
exec nixos-options serve ${httpFlags}
'';
serviceConfig = {
@@ -188,5 +248,14 @@ in
StateDirectory = lib.mkIf (cfg.dataDir == "/var/lib/nixos-options-mcp") "nixos-options-mcp";
};
};
# Open firewall for HTTP port if configured
networking.firewall = lib.mkIf cfg.openFirewall (let
# Extract port from address (format: "host:port" or ":port")
addressParts = lib.splitString ":" cfg.http.address;
port = lib.toInt (lib.last addressParts);
in {
allowedTCPPorts = [ port ];
});
};
}