Compare commits

...

9 Commits

Author SHA1 Message Date
6b6be83e50 Merge pull request 'feature/streamable-http-transport' (#1) from feature/streamable-http-transport into master
Reviewed-on: #1
2026-02-03 21:23:38 +00:00
e6315eb94b docs: fix flake URL and add nix run MCP example
- Update flake URL from github:torjus/labmcp to the correct
  git+https://git.t-juice.club/torjus/labmcp
- Add alternative MCP client configuration using nix run with
  the flake URL directly (no installation required)
- Fix NixOS module example to use correct flake URL

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:21:50 +01:00
921005179e docs: add HTTP transport documentation
Update README.md:
- Add HTTP transport usage section with examples
- Document HTTP endpoints (POST/GET/DELETE)
- Add HTTP-related NixOS module options to the table

Update CLAUDE.md:
- Update protocol description to include HTTP/SSE
- Add new transport files to repository structure
- Add Transports section explaining STDIO vs HTTP
- Add HTTP security hardening details
- Update CLI commands with HTTP transport examples

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:16:46 +01:00
08f8b2cd83 feat: add SSE keepalive messages for connection health
Add configurable SSEKeepAlive interval (default: 15s) that sends SSE
comment lines (`:keepalive`) to maintain connection health.

Benefits:
- Keeps connections alive through proxies/load balancers that timeout
  idle connections
- Detects stale connections earlier (write failures terminate the
  handler)
- Standard SSE pattern - comments are ignored by compliant clients

Configuration:
- SSEKeepAlive > 0: send keepalives at specified interval
- SSEKeepAlive = 0: use default (15s)
- SSEKeepAlive < 0: disable keepalives

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:10:58 +01:00
684baf63da security: add maximum session limit to prevent memory exhaustion
Add configurable MaxSessions limit (default: 10000) to SessionStore.
When the limit is reached, new session creation returns ErrTooManySessions
and HTTP transport responds with 503 Service Unavailable.

This prevents attackers from exhausting server memory by creating
unlimited sessions through repeated initialize requests.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:07:51 +01:00
1565cb5e1b security: add HTTP server timeouts to prevent slowloris attacks
Configure HTTP server with sensible timeouts:
- ReadTimeout: 30s (time to read entire request)
- WriteTimeout: 30s (time to write response)
- IdleTimeout: 120s (keep-alive connection timeout)
- ReadHeaderTimeout: 10s (time to read request headers)

For SSE connections, use http.ResponseController to extend write
deadlines before each write, preventing timeout on long-lived streams
while still protecting against slow clients.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:05:33 +01:00
149832e4e5 security: add request body size limit to prevent DoS
Add MaxRequestSize configuration to HTTPConfig with a default of 1MB.
Use http.MaxBytesReader to enforce the limit, returning 413 Request
Entity Too Large when exceeded.

This prevents memory exhaustion attacks where an attacker sends
arbitrarily large request bodies.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:04:11 +01:00
cbe55d6456 feat: add Streamable HTTP transport support
Add support for running the MCP server over HTTP with Server-Sent Events
(SSE) using the MCP Streamable HTTP specification, alongside the existing
STDIO transport.

New features:
- Transport abstraction with Transport interface
- HTTP transport with session management
- SSE support for server-initiated notifications
- CORS security with configurable allowed origins
- Optional TLS support
- CLI flags for HTTP configuration (--transport, --http-address, etc.)
- NixOS module options for HTTP transport

The HTTP transport implements:
- POST /mcp: JSON-RPC requests with session management
- GET /mcp: SSE stream for server notifications
- DELETE /mcp: Session termination
- Origin validation (localhost-only by default)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:02:40 +01:00
0b7333844a docs: update CLAUDE.md to reflect current project state
The file was still showing "Planning phase" with outdated next steps.
Updated to reflect the complete implementation:

- Changed status to "Complete and maintained"
- Updated repository structure to match actual layout
- Documented all 6 MCP tools as implemented
- Added key implementation details (database, indexing, security)
- Added CLI command reference
- Consolidated development notes
- Removed obsolete planning sections

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 21:40:59 +01:00
12 changed files with 2180 additions and 202 deletions

1
.gitignore vendored
View File

@@ -1 +1,2 @@
result result
*.db

269
CLAUDE.md
View File

@@ -12,176 +12,157 @@ The first MCP server provides search and query capabilities for NixOS configurat
## Technology Stack ## Technology Stack
- **Language**: Go 1.25.5 - **Language**: Go 1.24+
- **Build System**: Nix flakes - **Build System**: Nix flakes
- **Databases**: PostgreSQL (primary) and SQLite (lightweight alternative) - **Databases**: PostgreSQL and SQLite (both fully supported)
- **Protocol**: MCP (Model Context Protocol) - JSON-RPC over stdio - **Protocol**: MCP (Model Context Protocol) - JSON-RPC over STDIO or HTTP/SSE
- **Module Path**: `git.t-juice.club/torjus/labmcp` - **Module Path**: `git.t-juice.club/torjus/labmcp`
## Key Architectural Decisions ## Project Status
1. **Database Support**: Both PostgreSQL and SQLite **Complete and maintained** - All core features implemented:
- PostgreSQL is preferred for production use (user's preference) - Full MCP server with 6 tools
- SQLite provides lightweight alternative for simpler deployments - PostgreSQL and SQLite backends with FTS
- Use Go's `database/sql` interface for abstraction - NixOS module for deployment
- CLI for manual operations
- Comprehensive test suite
2. **File Storage**: Store nixpkgs file contents in database during indexing ## Repository Structure
- Better performance for the `get_file` tool
- PostgreSQL handles large text storage well
3. **Revision Management**: Support multiple indexed nixpkgs revisions
- Store git hash, date, channel name, option count
- Allow querying specific revisions or use defaults
- Default revision: nixos-stable (configurable)
4. **Indexing Approach**: Part of MCP server, blocking operation (initially)
- Allows Claude to read flake.lock and request indexing
- Can optimize to async later if needed
5. **Testing**: Aim for >80% test coverage
- Unit tests for all components
- Integration tests for full workflows
- Benchmarks for indexing and query performance
## MCP Tools to Implement
### Core Search & Query
1. **`search_options`** - Fuzzy/partial matching search
- Parameters: revision, query, optional filters (type, namespace, hasDefault)
- Returns: matching options with basic metadata
2. **`get_option`** - Get full details for specific option
- Parameters: revision, option_path, optional depth
- Returns: name, type, default, example, description, file paths
- Default: direct children only (one level deep)
- Includes related/nearby options in same namespace
3. **`get_file`** - Fetch nixpkgs source file contents
- Parameters: revision, file_path
- Returns: file contents
- Security: validate paths, no traversal, nixpkgs-only
### Revision Management
4. **`index_revision`** - Index a specific nixpkgs revision
- Parameters: git_hash (full or short)
- Process: fetch nixpkgs, extract options.json, populate DB
- Returns: summary (option count, duration, etc.)
5. **`list_revisions`** - List indexed revisions
- Returns: git hash, date, channel name, option count
6. **`delete_revision`** - Prune old/unused revisions
- Parameters: revision identifier
- Returns: confirmation of deletion
### Channel Support
- Support friendly aliases: `nixos-unstable`, `nixos-24.05`, `nixos-23.11`, etc.
- Can be used in place of git hashes in all tools
## Database Schema
**Tables:**
1. `revisions` - Indexed nixpkgs versions
- id, git_hash (unique), channel_name, commit_date, indexed_at, option_count
2. `options` - NixOS options with hierarchy support
- id, revision_id (FK), name, parent_path, type, default_value (JSON text), example (JSON text), description, read_only
- parent_path enables efficient "list children" queries (derived from name)
3. `declarations` - File paths where options are declared
- id, option_id (FK), file_path, line_number
4. `files` - Cached file contents
- id, revision_id (FK), file_path, extension, content
- Configurable whitelist of extensions (default: .nix, .json, .md, .txt, .toml, .yaml, .yml)
**Indexes:**
- Full-text search: PostgreSQL (tsvector/GIN), SQLite (FTS5)
- B-tree on (revision_id, name) and (revision_id, parent_path)
- B-tree on (revision_id, file_path) for file lookups
**Cross-DB Compatibility:**
- JSON stored as TEXT (not JSONB) for SQLite compatibility
- Separate FTS implementations per database engine
## Repository Structure (Planned)
``` ```
labmcp/ labmcp/
├── cmd/ ├── cmd/
│ └── nixos-options/ # MCP server binary │ └── nixos-options/
│ └── main.go │ └── main.go # CLI entry point
├── internal/ ├── internal/
│ ├── mcp/ # MCP protocol implementation │ ├── database/
│ │ ├── server.go │ │ ├── interface.go # Store interface
│ │ ── types.go │ │ ── schema.go # Schema versioning
│ ├── database/ # Database abstraction │ ├── postgres.go # PostgreSQL implementation
│ │ ├── interface.go │ │ ├── sqlite.go # SQLite implementation
│ │ ── postgres.go │ │ ── *_test.go # Database tests
│ └── sqlite.go ├── mcp/
└── nixos/ # NixOS options specific logic │ ├── server.go # MCP server core
├── search.go ├── handlers.go # Tool implementations
── types.go ── types.go # Protocol types
├── scripts/ │ │ ├── transport.go # Transport interface
└── populate-db.go # Tool to populate database │ ├── transport_stdio.go # STDIO transport
├── schema/ │ │ ├── transport_http.go # HTTP/SSE transport
└── schema.sql # Database schema │ ├── session.go # HTTP session management
├── flake.nix # Nix build configuration └── *_test.go # MCP tests
│ └── nixos/
│ ├── indexer.go # Nixpkgs indexing
│ ├── parser.go # options.json parsing
│ ├── types.go # Channel aliases, extensions
│ └── *_test.go # Indexer tests
├── nix/
│ ├── module.nix # NixOS module
│ └── package.nix # Nix package definition
├── testdata/
│ └── options-sample.json # Test fixture
├── flake.nix
├── go.mod ├── go.mod
├── TODO.md # Detailed task list ├── .mcp.json # MCP client configuration
├── CLAUDE.md # This file ├── CLAUDE.md # This file
── README.md ── README.md
└── TODO.md # Future improvements
``` ```
## Use Cases ## MCP Tools
**Primary Use Case**: Claude can help users find and understand NixOS options All tools are implemented and functional:
- "What options are available for nginx?"
- "Show me the services.caddy.* options"
- "What's the default value for services.postgresql.enable?"
- User shares a flake.lock → Claude indexes that nixpkgs version → answers questions about options in that specific version
**Secondary Use Case**: Explore module implementations | Tool | Description |
- If option description is unclear, fetch the actual module source |------|-------------|
- Understand how complex options are structured | `search_options` | Full-text search across option names and descriptions |
| `get_option` | Get full details for a specific option with children |
| `get_file` | Fetch source file contents from indexed nixpkgs |
| `index_revision` | Index a nixpkgs revision (by hash or channel name) |
| `list_revisions` | List all indexed revisions |
| `delete_revision` | Delete an indexed revision |
## Testing Strategy ## Key Implementation Details
- **Unit Tests**: All components with mocks where appropriate ### Database
- **Integration Tests**: Full indexing pipeline, MCP tool invocations - Schema versioning with automatic recreation on version mismatch
- **Benchmarks**: Indexing time, query performance, memory usage - Full-text search: SQLite FTS5, PostgreSQL tsvector/GIN
- **Test Fixtures**: Sample options.json, mock repositories - Path-based queries use LIKE for exact prefix matching
- **Coverage Goal**: >80% on core logic, 100% on database operations - Batch operations for efficient indexing
## Open Questions ### Indexing
- Uses `nix-build` to evaluate NixOS options from any nixpkgs revision
- File indexing downloads tarball and stores allowed extensions (.nix, .json, .md, etc.)
- File indexing enabled by default (use `--no-files` to skip)
- Skips already-indexed revisions (use `--force` to re-index)
1. Should `index_revision` be blocking or async? (Currently: blocking, optimize later) ### Transports
2. Should we auto-update channel aliases or manual only? - **STDIO**: Default transport, line-delimited JSON-RPC (for CLI/desktop MCP clients)
- **HTTP**: Streamable HTTP transport with SSE (for web-based MCP clients)
- Session management with cryptographically secure IDs
- Configurable CORS (localhost-only by default)
- Optional TLS support
- SSE keepalive messages (15s default)
## Current Status ### Security
- Revision parameter validated against strict regex to prevent Nix injection
- Path traversal protection using `filepath.Clean()` and `filepath.IsAbs()`
- NixOS module supports `connectionStringFile` for PostgreSQL secrets
- Systemd service runs with extensive hardening options
- HTTP transport hardening:
- Request body size limit (1MB default)
- Server timeouts (read: 30s, write: 30s, idle: 120s, header: 10s)
- Maximum session limit (10,000 default)
- Origin validation for CORS
**Planning phase** - architecture and features defined, ready to begin implementation. ## CLI Commands
## Next Steps ```bash
nixos-options serve # Run MCP server on STDIO (default)
1. Design and implement database schema nixos-options serve --transport http # Run MCP server on HTTP
2. Set up project structure (directories, Go modules) nixos-options serve --transport http \
3. Implement database abstraction layer --http-address 0.0.0.0:8080 \
4. Implement MCP protocol basics --allowed-origins https://example.com # HTTP with custom config
5. Build indexing logic nixos-options index <revision> # Index a nixpkgs revision
6. Implement MCP tools nixos-options index --force <r> # Force re-index existing revision
7. Create Nix package in flake.nix nixos-options index --no-files # Skip file content indexing
8. Write tests and benchmarks nixos-options list # List indexed revisions
nixos-options search <query> # Search options
nixos-options get <option> # Get option details
nixos-options delete <revision> # Delete indexed revision
nixos-options --version # Show version
```
## Notes for Claude ## Notes for Claude
### Development Workflow
- **Always run `go fmt ./...` before committing Go code**
- **Run Go commands using `nix develop -c`** (e.g., `nix develop -c go test ./...`)
- **Use `nix run` to run binaries** (e.g., `nix run .#nixos-options -- serve`)
- File paths in responses should use format `path/to/file.go:123`
### User Preferences
- User prefers PostgreSQL over SQLite (has homelab infrastructure) - User prefers PostgreSQL over SQLite (has homelab infrastructure)
- User values good test coverage and benchmarking - User values good test coverage and benchmarking
- Project should remain generic to support future MCP servers - Project should remain generic to support future MCP servers
- Nix flake must provide importable packages for other repos
- Use `database/sql` interface for database abstraction ### Testing
- File paths in responses should use format `path/to/file.go:123` ```bash
- **Always run `go fmt ./...` before committing Go code** # Run all tests
- **Run Go commands using `nix develop -c`** (e.g., `nix develop -c go test ./...`) to ensure proper build environment with all dependencies nix develop -c go test ./... -short
- **Use `nix run` to run binaries** instead of `go build` followed by running the binary (e.g., `nix run .#nixos-options -- serve`)
# Run with verbose output
nix develop -c go test ./... -v
# Run benchmarks (requires nix-build)
nix develop -c go test -bench=. -benchtime=1x -timeout=30m ./internal/nixos/...
```
### Building
```bash
# Build with nix
nix build
# Run directly
nix run . -- serve
nix run . -- index nixos-unstable
```

View File

@@ -20,10 +20,10 @@ Search and query NixOS configuration options across multiple nixpkgs revisions.
```bash ```bash
# Build the package # Build the package
nix build github:torjus/labmcp nix build git+https://git.t-juice.club/torjus/labmcp
# Or run directly # Or run directly
nix run github:torjus/labmcp -- --help nix run git+https://git.t-juice.club/torjus/labmcp -- --help
``` ```
### From Source ### From Source
@@ -34,7 +34,7 @@ go install git.t-juice.club/torjus/labmcp/cmd/nixos-options@latest
## Usage ## Usage
### As MCP Server ### As MCP Server (STDIO)
Configure in your MCP client (e.g., Claude Desktop): Configure in your MCP client (e.g., Claude Desktop):
@@ -52,12 +52,52 @@ Configure in your MCP client (e.g., Claude Desktop):
} }
``` ```
Alternatively, if you have Nix installed, you can use the flake directly without installing the package:
```json
{
"mcpServers": {
"nixos-options": {
"command": "nix",
"args": ["run", "git+https://git.t-juice.club/torjus/labmcp", "--", "serve"],
"env": {
"NIXOS_OPTIONS_DATABASE": "sqlite:///path/to/nixos-options.db"
}
}
}
}
```
Then start the server: Then start the server:
```bash ```bash
nixos-options serve nixos-options serve
``` ```
### As MCP Server (HTTP)
The server can also run over HTTP with Server-Sent Events (SSE) for web-based MCP clients:
```bash
# Start HTTP server on default address (127.0.0.1:8080)
nixos-options serve --transport http
# Custom address and CORS configuration
nixos-options serve --transport http \
--http-address 0.0.0.0:8080 \
--allowed-origins https://example.com
# With TLS
nixos-options serve --transport http \
--tls-cert /path/to/cert.pem \
--tls-key /path/to/key.pem
```
HTTP transport endpoints:
- `POST /mcp` - JSON-RPC requests (returns `Mcp-Session-Id` header on initialize)
- `GET /mcp` - SSE stream for server notifications (requires `Mcp-Session-Id` header)
- `DELETE /mcp` - Terminate session
### CLI Examples ### CLI Examples
**Index a nixpkgs revision:** **Index a nixpkgs revision:**
@@ -154,7 +194,7 @@ A NixOS module is provided for running the MCP server as a systemd service.
```nix ```nix
{ {
inputs.labmcp.url = "github:torjus/labmcp"; inputs.labmcp.url = "git+https://git.t-juice.club/torjus/labmcp";
outputs = { self, nixpkgs, labmcp }: { outputs = { self, nixpkgs, labmcp }: {
nixosConfigurations.myhost = nixpkgs.lib.nixosSystem { nixosConfigurations.myhost = nixpkgs.lib.nixosSystem {
@@ -187,6 +227,14 @@ A NixOS module is provided for running the MCP server as a systemd service.
| `user` | string | `"nixos-options-mcp"` | User to run the service as | | `user` | string | `"nixos-options-mcp"` | User to run the service as |
| `group` | string | `"nixos-options-mcp"` | Group to run the service as | | `group` | string | `"nixos-options-mcp"` | Group to run the service as |
| `dataDir` | path | `/var/lib/nixos-options-mcp` | Directory for data storage | | `dataDir` | path | `/var/lib/nixos-options-mcp` | Directory for data storage |
| `http.address` | string | `"127.0.0.1:8080"` | HTTP listen address |
| `http.endpoint` | string | `"/mcp"` | HTTP endpoint path |
| `http.allowedOrigins` | list of string | `[]` | Allowed CORS origins (empty = localhost only) |
| `http.sessionTTL` | string | `"30m"` | Session timeout (Go duration format) |
| `http.tls.enable` | bool | `false` | Enable TLS |
| `http.tls.certFile` | path | `null` | TLS certificate file |
| `http.tls.keyFile` | path | `null` | TLS private key file |
| `openFirewall` | bool | `false` | Open firewall for HTTP port |
### PostgreSQL Example ### PostgreSQL Example

View File

@@ -5,7 +5,10 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"os/signal"
"strings" "strings"
"syscall"
"time"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@@ -36,7 +39,42 @@ func main() {
Commands: []*cli.Command{ Commands: []*cli.Command{
{ {
Name: "serve", Name: "serve",
Usage: "Run MCP server (stdio)", Usage: "Run MCP server",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "transport",
Aliases: []string{"t"},
Usage: "Transport type: 'stdio' or 'http'",
Value: "stdio",
},
&cli.StringFlag{
Name: "http-address",
Usage: "HTTP listen address",
Value: "127.0.0.1:8080",
},
&cli.StringFlag{
Name: "http-endpoint",
Usage: "HTTP endpoint path",
Value: "/mcp",
},
&cli.StringSliceFlag{
Name: "allowed-origins",
Usage: "Allowed Origin headers for CORS (can be specified multiple times)",
},
&cli.StringFlag{
Name: "tls-cert",
Usage: "TLS certificate file",
},
&cli.StringFlag{
Name: "tls-key",
Usage: "TLS key file",
},
&cli.DurationFlag{
Name: "session-ttl",
Usage: "Session TTL for HTTP transport",
Value: 30 * time.Minute,
},
},
Action: func(c *cli.Context) error { Action: func(c *cli.Context) error {
return runServe(c) return runServe(c)
}, },
@@ -145,7 +183,8 @@ func openStore(connStr string) (database.Store, error) {
} }
func runServe(c *cli.Context) error { func runServe(c *cli.Context) error {
ctx := context.Background() ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
store, err := openStore(c.String("database")) store, err := openStore(c.String("database"))
if err != nil { if err != nil {
@@ -163,8 +202,27 @@ func runServe(c *cli.Context) error {
indexer := nixos.NewIndexer(store) indexer := nixos.NewIndexer(store)
server.RegisterHandlers(indexer) server.RegisterHandlers(indexer)
logger.Println("Starting MCP server on stdio...") transport := c.String("transport")
return server.Run(ctx, os.Stdin, os.Stdout) switch transport {
case "stdio":
logger.Println("Starting MCP server on stdio...")
return server.Run(ctx, os.Stdin, os.Stdout)
case "http":
config := mcp.HTTPConfig{
Address: c.String("http-address"),
Endpoint: c.String("http-endpoint"),
AllowedOrigins: c.StringSlice("allowed-origins"),
SessionTTL: c.Duration("session-ttl"),
TLSCertFile: c.String("tls-cert"),
TLSKeyFile: c.String("tls-key"),
}
httpTransport := mcp.NewHTTPTransport(server, config)
return httpTransport.Run(ctx)
default:
return fmt.Errorf("unknown transport: %s (use 'stdio' or 'http')", transport)
}
} }
func runIndex(c *cli.Context, revision string, indexFiles bool, force bool) error { func runIndex(c *cli.Context, revision string, indexFiles bool, force bool) error {

View File

@@ -1,7 +1,6 @@
package mcp package mcp
import ( import (
"bufio"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -11,7 +10,7 @@ import (
"git.t-juice.club/torjus/labmcp/internal/database" "git.t-juice.club/torjus/labmcp/internal/database"
) )
// Server is an MCP server that handles JSON-RPC requests over stdio. // Server is an MCP server that handles JSON-RPC requests.
type Server struct { type Server struct {
store database.Store store database.Store
tools map[string]ToolHandler tools map[string]ToolHandler
@@ -41,53 +40,34 @@ func (s *Server) registerTools() {
// Tools will be implemented in handlers.go // Tools will be implemented in handlers.go
} }
// Run starts the server, reading from r and writing to w. // Run starts the server using STDIO transport (backward compatibility).
func (s *Server) Run(ctx context.Context, r io.Reader, w io.Writer) error { func (s *Server) Run(ctx context.Context, r io.Reader, w io.Writer) error {
scanner := bufio.NewScanner(r) transport := NewStdioTransport(s, r, w)
encoder := json.NewEncoder(w) return transport.Run(ctx)
}
for scanner.Scan() { // HandleMessage parses a JSON-RPC message and returns the response.
select { // Returns (nil, nil) for notifications that don't require a response.
case <-ctx.Done(): func (s *Server) HandleMessage(ctx context.Context, data []byte) (*Response, error) {
return ctx.Err() var req Request
default: if err := json.Unmarshal(data, &req); err != nil {
} return &Response{
JSONRPC: "2.0",
line := scanner.Bytes() Error: &Error{
if len(line) == 0 { Code: ParseError,
continue Message: "Parse error",
} Data: err.Error(),
},
var req Request }, nil
if err := json.Unmarshal(line, &req); err != nil {
s.logger.Printf("Failed to parse request: %v", err)
resp := Response{
JSONRPC: "2.0",
Error: &Error{
Code: ParseError,
Message: "Parse error",
Data: err.Error(),
},
}
if err := encoder.Encode(resp); err != nil {
return fmt.Errorf("failed to write response: %w", err)
}
continue
}
resp := s.handleRequest(ctx, &req)
if resp != nil {
if err := encoder.Encode(resp); err != nil {
return fmt.Errorf("failed to write response: %w", err)
}
}
} }
if err := scanner.Err(); err != nil { return s.HandleRequest(ctx, &req), nil
return fmt.Errorf("scanner error: %w", err) }
}
return nil // HandleRequest processes a single request and returns a response.
// Returns nil for notifications that don't require a response.
func (s *Server) HandleRequest(ctx context.Context, req *Request) *Response {
return s.handleRequest(ctx, req)
} }
// handleRequest processes a single request and returns a response. // handleRequest processes a single request and returns a response.

231
internal/mcp/session.go Normal file
View File

@@ -0,0 +1,231 @@
package mcp
import (
"crypto/rand"
"encoding/hex"
"fmt"
"sync"
"time"
)
// Session represents an MCP client session.
type Session struct {
ID string
CreatedAt time.Time
LastActivity time.Time
Initialized bool
// notifications is a channel for server-initiated notifications.
notifications chan *Response
mu sync.RWMutex
}
// NewSession creates a new session with a cryptographically secure random ID.
func NewSession() (*Session, error) {
id, err := generateSessionID()
if err != nil {
return nil, err
}
now := time.Now()
return &Session{
ID: id,
CreatedAt: now,
LastActivity: now,
notifications: make(chan *Response, 100),
}, nil
}
// Touch updates the session's last activity time.
func (s *Session) Touch() {
s.mu.Lock()
defer s.mu.Unlock()
s.LastActivity = time.Now()
}
// SetInitialized marks the session as initialized.
func (s *Session) SetInitialized() {
s.mu.Lock()
defer s.mu.Unlock()
s.Initialized = true
}
// IsInitialized returns whether the session has been initialized.
func (s *Session) IsInitialized() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.Initialized
}
// Notifications returns the channel for server-initiated notifications.
func (s *Session) Notifications() <-chan *Response {
return s.notifications
}
// SendNotification sends a notification to the session.
// Returns false if the channel is full.
func (s *Session) SendNotification(notification *Response) bool {
select {
case s.notifications <- notification:
return true
default:
return false
}
}
// Close closes the session's notification channel.
func (s *Session) Close() {
close(s.notifications)
}
// SessionStore manages active sessions with TTL-based cleanup.
type SessionStore struct {
sessions map[string]*Session
ttl time.Duration
maxSessions int
mu sync.RWMutex
stopClean chan struct{}
cleanDone chan struct{}
}
// ErrTooManySessions is returned when the session limit is reached.
var ErrTooManySessions = fmt.Errorf("too many active sessions")
// DefaultMaxSessions is the default maximum number of concurrent sessions.
const DefaultMaxSessions = 10000
// NewSessionStore creates a new session store with the given TTL.
func NewSessionStore(ttl time.Duration) *SessionStore {
return NewSessionStoreWithLimit(ttl, DefaultMaxSessions)
}
// NewSessionStoreWithLimit creates a new session store with TTL and max session limit.
func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore {
if maxSessions <= 0 {
maxSessions = DefaultMaxSessions
}
s := &SessionStore{
sessions: make(map[string]*Session),
ttl: ttl,
maxSessions: maxSessions,
stopClean: make(chan struct{}),
cleanDone: make(chan struct{}),
}
go s.cleanupLoop()
return s
}
// Create creates a new session and adds it to the store.
// Returns ErrTooManySessions if the maximum session limit is reached.
func (s *SessionStore) Create() (*Session, error) {
s.mu.Lock()
defer s.mu.Unlock()
// Check session limit
if len(s.sessions) >= s.maxSessions {
return nil, ErrTooManySessions
}
session, err := NewSession()
if err != nil {
return nil, err
}
s.sessions[session.ID] = session
return session, nil
}
// Get retrieves a session by ID. Returns nil if not found or expired.
func (s *SessionStore) Get(id string) *Session {
s.mu.RLock()
defer s.mu.RUnlock()
session, ok := s.sessions[id]
if !ok {
return nil
}
// Check if expired
session.mu.RLock()
expired := time.Since(session.LastActivity) > s.ttl
session.mu.RUnlock()
if expired {
return nil
}
return session
}
// Delete removes a session from the store.
func (s *SessionStore) Delete(id string) bool {
s.mu.Lock()
defer s.mu.Unlock()
session, ok := s.sessions[id]
if !ok {
return false
}
session.Close()
delete(s.sessions, id)
return true
}
// Count returns the number of active sessions.
func (s *SessionStore) Count() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.sessions)
}
// Stop stops the cleanup goroutine and waits for it to finish.
func (s *SessionStore) Stop() {
close(s.stopClean)
<-s.cleanDone
}
// cleanupLoop periodically removes expired sessions.
func (s *SessionStore) cleanupLoop() {
defer close(s.cleanDone)
ticker := time.NewTicker(s.ttl / 2)
defer ticker.Stop()
for {
select {
case <-s.stopClean:
return
case <-ticker.C:
s.cleanup()
}
}
}
// cleanup removes expired sessions.
func (s *SessionStore) cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
now := time.Now()
for id, session := range s.sessions {
session.mu.RLock()
expired := now.Sub(session.LastActivity) > s.ttl
session.mu.RUnlock()
if expired {
session.Close()
delete(s.sessions, id)
}
}
}
// generateSessionID generates a cryptographically secure random session ID.
func generateSessionID() (string, error) {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -0,0 +1,337 @@
package mcp
import (
"sync"
"testing"
"time"
)
func TestNewSession(t *testing.T) {
session, err := NewSession()
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
if session.ID == "" {
t.Error("Session ID should not be empty")
}
if len(session.ID) != 32 {
t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID))
}
if session.Initialized {
t.Error("New session should not be initialized")
}
}
func TestSessionTouch(t *testing.T) {
session, _ := NewSession()
originalActivity := session.LastActivity
time.Sleep(10 * time.Millisecond)
session.Touch()
if !session.LastActivity.After(originalActivity) {
t.Error("Touch should update LastActivity")
}
}
func TestSessionInitialized(t *testing.T) {
session, _ := NewSession()
if session.IsInitialized() {
t.Error("New session should not be initialized")
}
session.SetInitialized()
if !session.IsInitialized() {
t.Error("Session should be initialized after SetInitialized")
}
}
func TestSessionNotifications(t *testing.T) {
session, _ := NewSession()
defer session.Close()
notification := &Response{JSONRPC: "2.0", ID: 1}
if !session.SendNotification(notification) {
t.Error("SendNotification should return true on success")
}
select {
case received := <-session.Notifications():
if received.ID != notification.ID {
t.Error("Received notification should match sent")
}
case <-time.After(100 * time.Millisecond):
t.Error("Should receive notification")
}
}
func TestSessionStoreCreate(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
session, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
if store.Count() != 1 {
t.Errorf("Store should have 1 session, got %d", store.Count())
}
// Verify we can retrieve it
retrieved := store.Get(session.ID)
if retrieved == nil {
t.Error("Should be able to retrieve created session")
}
if retrieved.ID != session.ID {
t.Error("Retrieved session ID should match")
}
}
func TestSessionStoreGet(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
// Get non-existent session
if store.Get("nonexistent") != nil {
t.Error("Should return nil for non-existent session")
}
// Create and retrieve
session, _ := store.Create()
retrieved := store.Get(session.ID)
if retrieved == nil {
t.Error("Should find created session")
}
}
func TestSessionStoreDelete(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
session, _ := store.Create()
if store.Count() != 1 {
t.Error("Should have 1 session after create")
}
if !store.Delete(session.ID) {
t.Error("Delete should return true for existing session")
}
if store.Count() != 0 {
t.Error("Should have 0 sessions after delete")
}
if store.Delete(session.ID) {
t.Error("Delete should return false for non-existent session")
}
}
func TestSessionStoreTTLExpiration(t *testing.T) {
ttl := 50 * time.Millisecond
store := NewSessionStore(ttl)
defer store.Stop()
session, _ := store.Create()
// Should be retrievable immediately
if store.Get(session.ID) == nil {
t.Error("Session should be retrievable immediately")
}
// Wait for expiration
time.Sleep(ttl + 10*time.Millisecond)
// Should not be retrievable after TTL
if store.Get(session.ID) != nil {
t.Error("Expired session should not be retrievable")
}
}
func TestSessionStoreTTLRefresh(t *testing.T) {
ttl := 100 * time.Millisecond
store := NewSessionStore(ttl)
defer store.Stop()
session, _ := store.Create()
// Touch the session before TTL expires
time.Sleep(60 * time.Millisecond)
session.Touch()
// Wait past original TTL but not past refreshed TTL
time.Sleep(60 * time.Millisecond)
// Should still be retrievable because we touched it
if store.Get(session.ID) == nil {
t.Error("Touched session should still be retrievable")
}
}
func TestSessionStoreCleanup(t *testing.T) {
ttl := 50 * time.Millisecond
store := NewSessionStore(ttl)
defer store.Stop()
// Create multiple sessions
for i := 0; i < 5; i++ {
store.Create()
}
if store.Count() != 5 {
t.Errorf("Should have 5 sessions, got %d", store.Count())
}
// Wait for cleanup to run (runs at ttl/2 intervals)
time.Sleep(ttl + ttl/2 + 10*time.Millisecond)
// All sessions should be cleaned up
if store.Count() != 0 {
t.Errorf("All sessions should be cleaned up, got %d", store.Count())
}
}
func TestSessionStoreConcurrency(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
var wg sync.WaitGroup
sessionIDs := make(chan string, 100)
// Create sessions concurrently
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
session, err := store.Create()
if err != nil {
t.Errorf("Failed to create session: %v", err)
return
}
sessionIDs <- session.ID
}()
}
wg.Wait()
close(sessionIDs)
// Verify all sessions were created
if store.Count() != 50 {
t.Errorf("Should have 50 sessions, got %d", store.Count())
}
// Read and delete concurrently
var ids []string
for id := range sessionIDs {
ids = append(ids, id)
}
for _, id := range ids {
wg.Add(2)
go func(id string) {
defer wg.Done()
store.Get(id)
}(id)
go func(id string) {
defer wg.Done()
store.Delete(id)
}(id)
}
wg.Wait()
}
func TestSessionStoreMaxSessions(t *testing.T) {
maxSessions := 5
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
defer store.Stop()
// Create sessions up to limit
for i := 0; i < maxSessions; i++ {
_, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session %d: %v", i, err)
}
}
if store.Count() != maxSessions {
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
}
// Try to create one more - should fail
_, err := store.Create()
if err != ErrTooManySessions {
t.Errorf("Expected ErrTooManySessions, got %v", err)
}
// Count should still be at max
if store.Count() != maxSessions {
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
}
}
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
maxSessions := 3
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
defer store.Stop()
// Fill up the store
sessions := make([]*Session, maxSessions)
for i := 0; i < maxSessions; i++ {
s, err := store.Create()
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
sessions[i] = s
}
// Should be full
_, err := store.Create()
if err != ErrTooManySessions {
t.Error("Expected ErrTooManySessions when full")
}
// Delete one session
store.Delete(sessions[0].ID)
// Should be able to create again
_, err = store.Create()
if err != nil {
t.Errorf("Should be able to create after deletion: %v", err)
}
}
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
store := NewSessionStore(30 * time.Minute)
defer store.Stop()
// Just verify it uses the default (don't create 10000 sessions)
if store.maxSessions != DefaultMaxSessions {
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
}
}
func TestGenerateSessionID(t *testing.T) {
ids := make(map[string]bool)
// Generate 1000 IDs and ensure uniqueness
for i := 0; i < 1000; i++ {
id, err := generateSessionID()
if err != nil {
t.Fatalf("Failed to generate session ID: %v", err)
}
if len(id) != 32 {
t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
}
if ids[id] {
t.Error("Generated duplicate session ID")
}
ids[id] = true
}
}

10
internal/mcp/transport.go Normal file
View File

@@ -0,0 +1,10 @@
package mcp
import "context"
// Transport defines the interface for MCP server transports.
type Transport interface {
// Run starts the transport and blocks until the context is cancelled
// or an error occurs.
Run(ctx context.Context) error
}

View File

@@ -0,0 +1,448 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
)
// HTTPConfig configures the HTTP transport.
type HTTPConfig struct {
Address string // Listen address (e.g., "127.0.0.1:8080")
Endpoint string // MCP endpoint path (e.g., "/mcp")
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
SessionTTL time.Duration // Session TTL (default: 30 minutes)
MaxSessions int // Maximum concurrent sessions (default: 10000)
TLSCertFile string // TLS certificate file (optional)
TLSKeyFile string // TLS key file (optional)
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
ReadTimeout time.Duration // HTTP server read timeout (default: 30s)
WriteTimeout time.Duration // HTTP server write timeout (default: 30s)
IdleTimeout time.Duration // HTTP server idle timeout (default: 120s)
ReadHeaderTimeout time.Duration // HTTP server read header timeout (default: 10s)
SSEKeepAlive time.Duration // SSE keepalive interval (default: 15s, 0 to disable)
}
const (
// DefaultMaxRequestSize is the default maximum request body size (1MB).
DefaultMaxRequestSize = 1 << 20 // 1MB
// Default HTTP server timeouts
DefaultReadTimeout = 30 * time.Second
DefaultWriteTimeout = 30 * time.Second
DefaultIdleTimeout = 120 * time.Second
DefaultReadHeaderTimeout = 10 * time.Second
// DefaultSSEKeepAlive is the default interval for SSE keepalive messages.
// These are sent as SSE comments to keep the connection alive through
// proxies and load balancers, and to detect stale connections.
DefaultSSEKeepAlive = 15 * time.Second
)
// HTTPTransport implements the MCP Streamable HTTP transport.
type HTTPTransport struct {
server *Server
config HTTPConfig
sessions *SessionStore
}
// NewHTTPTransport creates a new HTTP transport.
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
if config.Address == "" {
config.Address = "127.0.0.1:8080"
}
if config.Endpoint == "" {
config.Endpoint = "/mcp"
}
if config.SessionTTL == 0 {
config.SessionTTL = 30 * time.Minute
}
if config.MaxSessions == 0 {
config.MaxSessions = DefaultMaxSessions
}
if config.MaxRequestSize == 0 {
config.MaxRequestSize = DefaultMaxRequestSize
}
if config.ReadTimeout == 0 {
config.ReadTimeout = DefaultReadTimeout
}
if config.WriteTimeout == 0 {
config.WriteTimeout = DefaultWriteTimeout
}
if config.IdleTimeout == 0 {
config.IdleTimeout = DefaultIdleTimeout
}
if config.ReadHeaderTimeout == 0 {
config.ReadHeaderTimeout = DefaultReadHeaderTimeout
}
// SSEKeepAlive: 0 means use default, negative means disabled
if config.SSEKeepAlive == 0 {
config.SSEKeepAlive = DefaultSSEKeepAlive
}
return &HTTPTransport{
server: server,
config: config,
sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
}
}
// Run starts the HTTP server and blocks until the context is cancelled.
func (t *HTTPTransport) Run(ctx context.Context) error {
mux := http.NewServeMux()
mux.HandleFunc(t.config.Endpoint, t.handleMCP)
httpServer := &http.Server{
Addr: t.config.Address,
Handler: mux,
ReadTimeout: t.config.ReadTimeout,
WriteTimeout: t.config.WriteTimeout,
IdleTimeout: t.config.IdleTimeout,
ReadHeaderTimeout: t.config.ReadHeaderTimeout,
BaseContext: func(l net.Listener) context.Context {
return ctx
},
}
// Graceful shutdown on context cancellation
go func() {
<-ctx.Done()
t.server.logger.Println("Shutting down HTTP server...")
t.sessions.Stop()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
t.server.logger.Printf("HTTP server shutdown error: %v", err)
}
}()
t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
var err error
if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
} else {
err = httpServer.ListenAndServe()
}
if err == http.ErrServerClosed {
return nil
}
return err
}
// handleMCP routes requests based on HTTP method.
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
// Validate Origin header
if !t.isOriginAllowed(r) {
http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
return
}
switch r.Method {
case http.MethodPost:
t.handlePost(w, r)
case http.MethodGet:
t.handleGet(w, r)
case http.MethodDelete:
t.handleDelete(w, r)
case http.MethodOptions:
t.handleOptions(w, r)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
// handlePost handles JSON-RPC requests.
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
// Limit request body size to prevent memory exhaustion attacks
r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize)
// Read request body
body, err := io.ReadAll(r.Body)
if err != nil {
// Check if this is a size limit error
if err.Error() == "http: request body too large" {
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
return
}
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
// Parse request to check method
var req Request
if err := json.Unmarshal(body, &req); err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{
JSONRPC: "2.0",
Error: &Error{
Code: ParseError,
Message: "Parse error",
Data: err.Error(),
},
})
return
}
// Handle initialize request - create session
if req.Method == MethodInitialize {
t.handleInitialize(w, r, &req)
return
}
// All other requests require a valid session
sessionID := r.Header.Get("Mcp-Session-Id")
if sessionID == "" {
http.Error(w, "Session ID required", http.StatusBadRequest)
return
}
session := t.sessions.Get(sessionID)
if session == nil {
http.Error(w, "Invalid or expired session", http.StatusNotFound)
return
}
// Update session activity
session.Touch()
// Handle notifications (no response expected)
if req.Method == MethodInitialized {
session.SetInitialized()
w.WriteHeader(http.StatusAccepted)
return
}
// Check if this is a notification (no ID)
if req.ID == nil {
w.WriteHeader(http.StatusAccepted)
return
}
// Process the request
resp := t.server.HandleRequest(r.Context(), &req)
if resp == nil {
w.WriteHeader(http.StatusAccepted)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
// handleInitialize handles the initialize request and creates a new session.
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
// Create a new session
session, err := t.sessions.Create()
if err != nil {
if err == ErrTooManySessions {
t.server.logger.Printf("Session limit reached")
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
return
}
t.server.logger.Printf("Failed to create session: %v", err)
http.Error(w, "Failed to create session", http.StatusInternalServerError)
return
}
// Process initialize request
resp := t.server.HandleRequest(r.Context(), req)
if resp == nil {
t.sessions.Delete(session.ID)
http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
return
}
// If initialize failed, clean up session
if resp.Error != nil {
t.sessions.Delete(session.ID)
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Mcp-Session-Id", session.ID)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
// handleGet handles SSE stream for server-initiated notifications.
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
sessionID := r.Header.Get("Mcp-Session-Id")
if sessionID == "" {
http.Error(w, "Session ID required", http.StatusBadRequest)
return
}
session := t.sessions.Get(sessionID)
if session == nil {
http.Error(w, "Invalid or expired session", http.StatusNotFound)
return
}
// Check if client accepts SSE
accept := r.Header.Get("Accept")
if !strings.Contains(accept, "text/event-stream") {
http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
return
}
// Set SSE headers
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.WriteHeader(http.StatusOK)
// Flush headers
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
return
}
flusher.Flush()
// Use ResponseController to manage write deadlines for long-lived SSE connections
rc := http.NewResponseController(w)
// Set up keepalive ticker if enabled
var keepaliveTicker *time.Ticker
var keepaliveChan <-chan time.Time
if t.config.SSEKeepAlive > 0 {
keepaliveTicker = time.NewTicker(t.config.SSEKeepAlive)
keepaliveChan = keepaliveTicker.C
defer keepaliveTicker.Stop()
}
// Stream notifications
ctx := r.Context()
for {
select {
case <-ctx.Done():
return
case <-keepaliveChan:
// Send SSE comment as keepalive (ignored by clients)
if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
t.server.logger.Printf("Failed to set write deadline: %v", err)
}
if _, err := fmt.Fprintf(w, ":keepalive\n\n"); err != nil {
// Write failed, connection likely closed
return
}
flusher.Flush()
case notification, ok := <-session.Notifications():
if !ok {
// Session closed
return
}
// Extend write deadline before each write
if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
t.server.logger.Printf("Failed to set write deadline: %v", err)
}
data, err := json.Marshal(notification)
if err != nil {
t.server.logger.Printf("Failed to marshal notification: %v", err)
continue
}
// Write SSE event
if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
// Write failed, connection likely closed
return
}
flusher.Flush()
// Touch session to keep it alive
session.Touch()
}
}
}
// handleDelete terminates a session.
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
sessionID := r.Header.Get("Mcp-Session-Id")
if sessionID == "" {
http.Error(w, "Session ID required", http.StatusBadRequest)
return
}
if t.sessions.Delete(sessionID) {
w.WriteHeader(http.StatusNoContent)
} else {
http.Error(w, "Session not found", http.StatusNotFound)
}
}
// handleOptions handles CORS preflight requests.
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin != "" && t.isOriginAllowed(r) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
w.Header().Set("Access-Control-Max-Age", "86400")
}
w.WriteHeader(http.StatusNoContent)
}
// isOriginAllowed checks if the request origin is allowed.
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
origin := r.Header.Get("Origin")
// No Origin header (same-origin request) is always allowed
if origin == "" {
return true
}
// If no allowed origins configured, only allow localhost
if len(t.config.AllowedOrigins) == 0 {
return isLocalhostOrigin(origin)
}
// Check against allowed origins
for _, allowed := range t.config.AllowedOrigins {
if allowed == "*" || allowed == origin {
return true
}
}
return false
}
// isLocalhostOrigin checks if the origin is a localhost address.
func isLocalhostOrigin(origin string) bool {
origin = strings.ToLower(origin)
// Check for localhost patterns (must be followed by :, /, or end of string)
localhostPatterns := []string{
"http://localhost",
"https://localhost",
"http://127.0.0.1",
"https://127.0.0.1",
"http://[::1]",
"https://[::1]",
}
for _, pattern := range localhostPatterns {
if origin == pattern {
return true
}
if strings.HasPrefix(origin, pattern+":") || strings.HasPrefix(origin, pattern+"/") {
return true
}
}
return false
}

View File

@@ -0,0 +1,752 @@
package mcp
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// testHTTPTransport creates a transport with a test server
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
// Use a mock store
server := NewServer(nil, log.New(io.Discard, "", 0))
if config.SessionTTL == 0 {
config.SessionTTL = 30 * time.Minute
}
transport := NewHTTPTransport(server, config)
// Create test server
mux := http.NewServeMux()
endpoint := config.Endpoint
if endpoint == "" {
endpoint = "/mcp"
}
mux.HandleFunc(endpoint, transport.handleMCP)
ts := httptest.NewServer(mux)
t.Cleanup(func() {
ts.Close()
transport.sessions.Stop()
})
return transport, ts
}
func TestHTTPTransportInitialize(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send initialize request
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected 200, got %d", resp.StatusCode)
}
// Check session ID header
sessionID := resp.Header.Get("Mcp-Session-Id")
if sessionID == "" {
t.Error("Expected Mcp-Session-Id header")
}
if len(sessionID) != 32 {
t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
}
// Check response body
var initResp Response
if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if initResp.Error != nil {
t.Errorf("Initialize failed: %v", initResp.Error)
}
}
func TestHTTPTransportSessionRequired(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send tools/list without session
listReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodToolsList,
}
body, _ := json.Marshal(listReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportInvalidSession(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send request with invalid session
listReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodToolsList,
}
body, _ := json.Marshal(listReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Mcp-Session-Id", "invalid-session-id")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportValidSession(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
// Create session manually
session, _ := transport.sessions.Create()
// Send tools/list with valid session
listReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodToolsList,
}
body, _ := json.Marshal(listReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportNotificationAccepted(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// Send notification (no ID)
notification := Request{
JSONRPC: "2.0",
Method: MethodInitialized,
}
body, _ := json.Marshal(notification)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusAccepted {
t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
}
// Verify session is marked as initialized
if !session.IsInitialized() {
t.Error("Session should be marked as initialized")
}
}
func TestHTTPTransportDeleteSession(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// Delete session
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
}
// Verify session is gone
if transport.sessions.Get(session.ID) != nil {
t.Error("Session should be deleted")
}
}
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", "nonexistent")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
}
}
func TestHTTPTransportOriginValidation(t *testing.T) {
tests := []struct {
name string
allowedOrigins []string
origin string
expectAllowed bool
}{
{
name: "no origin header",
allowedOrigins: nil,
origin: "",
expectAllowed: true,
},
{
name: "localhost allowed by default",
allowedOrigins: nil,
origin: "http://localhost:3000",
expectAllowed: true,
},
{
name: "127.0.0.1 allowed by default",
allowedOrigins: nil,
origin: "http://127.0.0.1:8080",
expectAllowed: true,
},
{
name: "external origin blocked by default",
allowedOrigins: nil,
origin: "http://evil.com",
expectAllowed: false,
},
{
name: "explicit allow",
allowedOrigins: []string{"http://example.com"},
origin: "http://example.com",
expectAllowed: true,
},
{
name: "explicit allow wildcard",
allowedOrigins: []string{"*"},
origin: "http://anything.com",
expectAllowed: true,
},
{
name: "not in allowed list",
allowedOrigins: []string{"http://example.com"},
origin: "http://other.com",
expectAllowed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
AllowedOrigins: tt.allowedOrigins,
})
// Use initialize since it doesn't require a session
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
if tt.origin != "" {
req.Header.Set("Origin", tt.origin)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
t.Error("Expected request to be allowed but was forbidden")
}
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
}
})
}
}
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// GET without Accept: text/event-stream
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", session.ID)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotAcceptable {
t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
}
}
func TestHTTPTransportSSEStream(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{})
session, _ := transport.sessions.Create()
// Start SSE stream in goroutine
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", session.ID)
req.Header.Set("Accept", "text/event-stream")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected 200, got %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if contentType != "text/event-stream" {
t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
}
// Send a notification
notification := &Response{
JSONRPC: "2.0",
ID: 42,
Result: map[string]string{"test": "data"},
}
session.SendNotification(notification)
// Read the SSE event
buf := make([]byte, 1024)
n, err := resp.Body.Read(buf)
if err != nil && err != io.EOF {
t.Fatalf("Failed to read SSE event: %v", err)
}
data := string(buf[:n])
if !strings.HasPrefix(data, "data: ") {
t.Errorf("Expected SSE data event, got: %s", data)
}
// Parse the JSON from the SSE event
jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ")
var received Response
if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
t.Fatalf("Failed to parse notification JSON: %v", err)
}
// JSON unmarshal converts numbers to float64, so compare as float64
receivedID, ok := received.ID.(float64)
if !ok {
t.Fatalf("Expected numeric ID, got %T", received.ID)
}
if int(receivedID) != 42 {
t.Errorf("Expected notification ID 42, got %v", receivedID)
}
}
func TestHTTPTransportSSEKeepalive(t *testing.T) {
transport, ts := testHTTPTransport(t, HTTPConfig{
SSEKeepAlive: 50 * time.Millisecond, // Short interval for testing
})
session, _ := transport.sessions.Create()
// Start SSE stream
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
req.Header.Set("Mcp-Session-Id", session.ID)
req.Header.Set("Accept", "text/event-stream")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("Expected 200, got %d", resp.StatusCode)
}
// Read with timeout - should receive keepalive within 100ms
buf := make([]byte, 256)
done := make(chan struct{})
var readData string
var readErr error
go func() {
n, err := resp.Body.Read(buf)
readData = string(buf[:n])
readErr = err
close(done)
}()
select {
case <-done:
if readErr != nil && readErr.Error() != "EOF" {
t.Fatalf("Read error: %v", readErr)
}
// Should receive SSE comment keepalive
if !strings.Contains(readData, ":keepalive") {
t.Errorf("Expected keepalive comment, got: %q", readData)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timeout waiting for keepalive")
}
}
func TestHTTPTransportSSEKeepaliveDisabled(t *testing.T) {
server := NewServer(nil, log.New(io.Discard, "", 0))
config := HTTPConfig{
SSEKeepAlive: -1, // Explicitly disabled
}
transport := NewHTTPTransport(server, config)
defer transport.sessions.Stop()
// When SSEKeepAlive is negative, it should remain negative (disabled)
if transport.config.SSEKeepAlive != -1 {
t.Errorf("Expected SSEKeepAlive to remain -1 (disabled), got %v", transport.config.SSEKeepAlive)
}
}
func TestHTTPTransportParseError(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
// Send invalid JSON
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
}
var jsonResp Response
if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if jsonResp.Error == nil {
t.Error("Expected JSON-RPC error for parse error")
}
if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
}
}
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{})
req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusMethodNotAllowed {
t.Errorf("Expected 405, got %d", resp.StatusCode)
}
}
func TestHTTPTransportOptionsRequest(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
AllowedOrigins: []string{"http://example.com"},
})
req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
req.Header.Set("Origin", "http://example.com")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusNoContent {
t.Errorf("Expected 204, got %d", resp.StatusCode)
}
if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
t.Error("Expected CORS origin header")
}
if resp.Header.Get("Access-Control-Allow-Methods") == "" {
t.Error("Expected CORS methods header")
}
}
func TestHTTPTransportDefaultConfig(t *testing.T) {
server := NewServer(nil, log.New(io.Discard, "", 0))
transport := NewHTTPTransport(server, HTTPConfig{})
// Verify defaults are applied
if transport.config.Address != "127.0.0.1:8080" {
t.Errorf("Expected default address 127.0.0.1:8080, got %s", transport.config.Address)
}
if transport.config.Endpoint != "/mcp" {
t.Errorf("Expected default endpoint /mcp, got %s", transport.config.Endpoint)
}
if transport.config.SessionTTL != 30*time.Minute {
t.Errorf("Expected default session TTL 30m, got %v", transport.config.SessionTTL)
}
if transport.config.MaxRequestSize != DefaultMaxRequestSize {
t.Errorf("Expected default max request size %d, got %d", DefaultMaxRequestSize, transport.config.MaxRequestSize)
}
if transport.config.ReadTimeout != DefaultReadTimeout {
t.Errorf("Expected default read timeout %v, got %v", DefaultReadTimeout, transport.config.ReadTimeout)
}
if transport.config.WriteTimeout != DefaultWriteTimeout {
t.Errorf("Expected default write timeout %v, got %v", DefaultWriteTimeout, transport.config.WriteTimeout)
}
if transport.config.IdleTimeout != DefaultIdleTimeout {
t.Errorf("Expected default idle timeout %v, got %v", DefaultIdleTimeout, transport.config.IdleTimeout)
}
if transport.config.ReadHeaderTimeout != DefaultReadHeaderTimeout {
t.Errorf("Expected default read header timeout %v, got %v", DefaultReadHeaderTimeout, transport.config.ReadHeaderTimeout)
}
if transport.config.SSEKeepAlive != DefaultSSEKeepAlive {
t.Errorf("Expected default SSE keepalive %v, got %v", DefaultSSEKeepAlive, transport.config.SSEKeepAlive)
}
transport.sessions.Stop()
}
func TestHTTPTransportCustomConfig(t *testing.T) {
server := NewServer(nil, log.New(io.Discard, "", 0))
config := HTTPConfig{
Address: "0.0.0.0:9090",
Endpoint: "/api/mcp",
SessionTTL: 1 * time.Hour,
MaxRequestSize: 5 << 20, // 5MB
ReadTimeout: 60 * time.Second,
WriteTimeout: 60 * time.Second,
IdleTimeout: 300 * time.Second,
ReadHeaderTimeout: 20 * time.Second,
}
transport := NewHTTPTransport(server, config)
// Verify custom values are preserved
if transport.config.Address != "0.0.0.0:9090" {
t.Errorf("Expected custom address, got %s", transport.config.Address)
}
if transport.config.Endpoint != "/api/mcp" {
t.Errorf("Expected custom endpoint, got %s", transport.config.Endpoint)
}
if transport.config.SessionTTL != 1*time.Hour {
t.Errorf("Expected custom session TTL, got %v", transport.config.SessionTTL)
}
if transport.config.MaxRequestSize != 5<<20 {
t.Errorf("Expected custom max request size, got %d", transport.config.MaxRequestSize)
}
if transport.config.ReadTimeout != 60*time.Second {
t.Errorf("Expected custom read timeout, got %v", transport.config.ReadTimeout)
}
if transport.config.WriteTimeout != 60*time.Second {
t.Errorf("Expected custom write timeout, got %v", transport.config.WriteTimeout)
}
if transport.config.IdleTimeout != 300*time.Second {
t.Errorf("Expected custom idle timeout, got %v", transport.config.IdleTimeout)
}
if transport.config.ReadHeaderTimeout != 20*time.Second {
t.Errorf("Expected custom read header timeout, got %v", transport.config.ReadHeaderTimeout)
}
transport.sessions.Stop()
}
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
MaxRequestSize: 100, // Very small limit for testing
})
// Create a request body larger than the limit
largeBody := make([]byte, 200)
for i := range largeBody {
largeBody[i] = 'x'
}
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusRequestEntityTooLarge {
t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
}
}
func TestHTTPTransportSessionLimitReached(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
MaxSessions: 2, // Very low limit for testing
})
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
// Create sessions up to the limit
for i := 0; i < 2; i++ {
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request %d failed: %v", i, err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode)
}
}
// Third request should fail with 503
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode)
}
}
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
_, ts := testHTTPTransport(t, HTTPConfig{
MaxRequestSize: 10000, // Reasonable limit
})
// Send initialize request (should be well within limit)
initReq := Request{
JSONRPC: "2.0",
ID: 1,
Method: MethodInitialize,
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
}
body, _ := json.Marshal(initReq)
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode)
}
}
func TestIsLocalhostOrigin(t *testing.T) {
tests := []struct {
origin string
expected bool
}{
{"http://localhost", true},
{"http://localhost:3000", true},
{"https://localhost", true},
{"https://localhost:8443", true},
{"http://127.0.0.1", true},
{"http://127.0.0.1:8080", true},
{"https://127.0.0.1", true},
{"http://[::1]", true},
{"http://[::1]:8080", true},
{"https://[::1]", true},
{"http://example.com", false},
{"https://example.com", false},
{"http://localhost.evil.com", false},
{"http://192.168.1.1", false},
}
for _, tt := range tests {
t.Run(tt.origin, func(t *testing.T) {
result := isLocalhostOrigin(tt.origin)
if result != tt.expected {
t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,63 @@
package mcp
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
)
// StdioTransport implements the MCP protocol over STDIO using line-delimited JSON-RPC.
type StdioTransport struct {
server *Server
reader io.Reader
writer io.Writer
}
// NewStdioTransport creates a new STDIO transport.
func NewStdioTransport(server *Server, r io.Reader, w io.Writer) *StdioTransport {
return &StdioTransport{
server: server,
reader: r,
writer: w,
}
}
// Run starts the STDIO transport, reading line-delimited JSON-RPC from the reader
// and writing responses to the writer.
func (t *StdioTransport) Run(ctx context.Context) error {
scanner := bufio.NewScanner(t.reader)
encoder := json.NewEncoder(t.writer)
for scanner.Scan() {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
line := scanner.Bytes()
if len(line) == 0 {
continue
}
resp, err := t.server.HandleMessage(ctx, line)
if err != nil {
t.server.logger.Printf("Error handling message: %v", err)
continue
}
if resp != nil {
if err := encoder.Encode(resp); err != nil {
return fmt.Errorf("failed to write response: %w", err)
}
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error: %w", err)
}
return nil
}

View File

@@ -89,13 +89,56 @@ in
''; '';
}; };
http = {
address = lib.mkOption {
type = lib.types.str;
default = "127.0.0.1:8080";
description = "HTTP listen address for the MCP server.";
};
endpoint = lib.mkOption {
type = lib.types.str;
default = "/mcp";
description = "HTTP endpoint path for MCP requests.";
};
allowedOrigins = lib.mkOption {
type = lib.types.listOf lib.types.str;
default = [ ];
example = [ "http://localhost:3000" "https://example.com" ];
description = ''
Allowed Origin headers for CORS.
Empty list means only localhost origins are allowed.
'';
};
sessionTTL = lib.mkOption {
type = lib.types.str;
default = "30m";
description = "Session TTL for HTTP transport (Go duration format).";
};
tls = {
enable = lib.mkEnableOption "TLS for HTTP transport";
certFile = lib.mkOption {
type = lib.types.nullOr lib.types.path;
default = null;
description = "Path to TLS certificate file.";
};
keyFile = lib.mkOption {
type = lib.types.nullOr lib.types.path;
default = null;
description = "Path to TLS private key file.";
};
};
};
openFirewall = lib.mkOption { openFirewall = lib.mkOption {
type = lib.types.bool; type = lib.types.bool;
default = false; default = false;
description = '' description = "Whether to open the firewall for the MCP HTTP server.";
Whether to open the firewall for the MCP server.
Note: MCP typically runs over stdio, so this is usually not needed.
'';
}; };
}; };
@@ -111,6 +154,10 @@ in
assertion = cfg.database.connectionString == "" || cfg.database.connectionStringFile == null; assertion = cfg.database.connectionString == "" || cfg.database.connectionStringFile == null;
message = "services.nixos-options-mcp.database: connectionString and connectionStringFile are mutually exclusive"; message = "services.nixos-options-mcp.database: connectionString and connectionStringFile are mutually exclusive";
} }
{
assertion = !cfg.http.tls.enable || (cfg.http.tls.certFile != null && cfg.http.tls.keyFile != null);
message = "services.nixos-options-mcp.http.tls: both certFile and keyFile must be set when TLS is enabled";
}
]; ];
users.users.${cfg.user} = lib.mkIf (cfg.user == "nixos-options-mcp") { users.users.${cfg.user} = lib.mkIf (cfg.user == "nixos-options-mcp") {
@@ -145,6 +192,19 @@ in
nixos-options index "${rev}" || true nixos-options index "${rev}" || true
'') cfg.indexOnStart} '') cfg.indexOnStart}
''; '';
# Build HTTP transport flags
httpFlags = lib.concatStringsSep " " ([
"--transport http"
"--http-address '${cfg.http.address}'"
"--http-endpoint '${cfg.http.endpoint}'"
"--session-ttl '${cfg.http.sessionTTL}'"
] ++ lib.optionals (cfg.http.allowedOrigins != []) (
map (origin: "--allowed-origins '${origin}'") cfg.http.allowedOrigins
) ++ lib.optionals cfg.http.tls.enable [
"--tls-cert '${cfg.http.tls.certFile}'"
"--tls-key '${cfg.http.tls.keyFile}'"
]);
in in
if useConnectionStringFile then '' if useConnectionStringFile then ''
# Read database connection string from file # Read database connection string from file
@@ -155,10 +215,10 @@ in
export NIXOS_OPTIONS_DATABASE="$(cat "${cfg.database.connectionStringFile}")" export NIXOS_OPTIONS_DATABASE="$(cat "${cfg.database.connectionStringFile}")"
${indexCommands} ${indexCommands}
exec nixos-options serve exec nixos-options serve ${httpFlags}
'' else '' '' else ''
${indexCommands} ${indexCommands}
exec nixos-options serve exec nixos-options serve ${httpFlags}
''; '';
serviceConfig = { serviceConfig = {
@@ -188,5 +248,14 @@ in
StateDirectory = lib.mkIf (cfg.dataDir == "/var/lib/nixos-options-mcp") "nixos-options-mcp"; StateDirectory = lib.mkIf (cfg.dataDir == "/var/lib/nixos-options-mcp") "nixos-options-mcp";
}; };
}; };
# Open firewall for HTTP port if configured
networking.firewall = lib.mkIf cfg.openFirewall (let
# Extract port from address (format: "host:port" or ":port")
addressParts = lib.splitString ":" cfg.http.address;
port = lib.toInt (lib.last addressParts);
in {
allowedTCPPorts = [ port ];
});
}; };
} }