Merge pull request 'feature/streamable-http-transport' (#1) from feature/streamable-http-transport into master
Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
result
|
result
|
||||||
|
*.db
|
||||||
|
|||||||
269
CLAUDE.md
269
CLAUDE.md
@@ -12,176 +12,157 @@ The first MCP server provides search and query capabilities for NixOS configurat
|
|||||||
|
|
||||||
## Technology Stack
|
## Technology Stack
|
||||||
|
|
||||||
- **Language**: Go 1.25.5
|
- **Language**: Go 1.24+
|
||||||
- **Build System**: Nix flakes
|
- **Build System**: Nix flakes
|
||||||
- **Databases**: PostgreSQL (primary) and SQLite (lightweight alternative)
|
- **Databases**: PostgreSQL and SQLite (both fully supported)
|
||||||
- **Protocol**: MCP (Model Context Protocol) - JSON-RPC over stdio
|
- **Protocol**: MCP (Model Context Protocol) - JSON-RPC over STDIO or HTTP/SSE
|
||||||
- **Module Path**: `git.t-juice.club/torjus/labmcp`
|
- **Module Path**: `git.t-juice.club/torjus/labmcp`
|
||||||
|
|
||||||
## Key Architectural Decisions
|
## Project Status
|
||||||
|
|
||||||
1. **Database Support**: Both PostgreSQL and SQLite
|
**Complete and maintained** - All core features implemented:
|
||||||
- PostgreSQL is preferred for production use (user's preference)
|
- Full MCP server with 6 tools
|
||||||
- SQLite provides lightweight alternative for simpler deployments
|
- PostgreSQL and SQLite backends with FTS
|
||||||
- Use Go's `database/sql` interface for abstraction
|
- NixOS module for deployment
|
||||||
|
- CLI for manual operations
|
||||||
|
- Comprehensive test suite
|
||||||
|
|
||||||
2. **File Storage**: Store nixpkgs file contents in database during indexing
|
## Repository Structure
|
||||||
- Better performance for the `get_file` tool
|
|
||||||
- PostgreSQL handles large text storage well
|
|
||||||
|
|
||||||
3. **Revision Management**: Support multiple indexed nixpkgs revisions
|
|
||||||
- Store git hash, date, channel name, option count
|
|
||||||
- Allow querying specific revisions or use defaults
|
|
||||||
- Default revision: nixos-stable (configurable)
|
|
||||||
|
|
||||||
4. **Indexing Approach**: Part of MCP server, blocking operation (initially)
|
|
||||||
- Allows Claude to read flake.lock and request indexing
|
|
||||||
- Can optimize to async later if needed
|
|
||||||
|
|
||||||
5. **Testing**: Aim for >80% test coverage
|
|
||||||
- Unit tests for all components
|
|
||||||
- Integration tests for full workflows
|
|
||||||
- Benchmarks for indexing and query performance
|
|
||||||
|
|
||||||
## MCP Tools to Implement
|
|
||||||
|
|
||||||
### Core Search & Query
|
|
||||||
1. **`search_options`** - Fuzzy/partial matching search
|
|
||||||
- Parameters: revision, query, optional filters (type, namespace, hasDefault)
|
|
||||||
- Returns: matching options with basic metadata
|
|
||||||
|
|
||||||
2. **`get_option`** - Get full details for specific option
|
|
||||||
- Parameters: revision, option_path, optional depth
|
|
||||||
- Returns: name, type, default, example, description, file paths
|
|
||||||
- Default: direct children only (one level deep)
|
|
||||||
- Includes related/nearby options in same namespace
|
|
||||||
|
|
||||||
3. **`get_file`** - Fetch nixpkgs source file contents
|
|
||||||
- Parameters: revision, file_path
|
|
||||||
- Returns: file contents
|
|
||||||
- Security: validate paths, no traversal, nixpkgs-only
|
|
||||||
|
|
||||||
### Revision Management
|
|
||||||
4. **`index_revision`** - Index a specific nixpkgs revision
|
|
||||||
- Parameters: git_hash (full or short)
|
|
||||||
- Process: fetch nixpkgs, extract options.json, populate DB
|
|
||||||
- Returns: summary (option count, duration, etc.)
|
|
||||||
|
|
||||||
5. **`list_revisions`** - List indexed revisions
|
|
||||||
- Returns: git hash, date, channel name, option count
|
|
||||||
|
|
||||||
6. **`delete_revision`** - Prune old/unused revisions
|
|
||||||
- Parameters: revision identifier
|
|
||||||
- Returns: confirmation of deletion
|
|
||||||
|
|
||||||
### Channel Support
|
|
||||||
- Support friendly aliases: `nixos-unstable`, `nixos-24.05`, `nixos-23.11`, etc.
|
|
||||||
- Can be used in place of git hashes in all tools
|
|
||||||
|
|
||||||
## Database Schema
|
|
||||||
|
|
||||||
**Tables:**
|
|
||||||
|
|
||||||
1. `revisions` - Indexed nixpkgs versions
|
|
||||||
- id, git_hash (unique), channel_name, commit_date, indexed_at, option_count
|
|
||||||
|
|
||||||
2. `options` - NixOS options with hierarchy support
|
|
||||||
- id, revision_id (FK), name, parent_path, type, default_value (JSON text), example (JSON text), description, read_only
|
|
||||||
- parent_path enables efficient "list children" queries (derived from name)
|
|
||||||
|
|
||||||
3. `declarations` - File paths where options are declared
|
|
||||||
- id, option_id (FK), file_path, line_number
|
|
||||||
|
|
||||||
4. `files` - Cached file contents
|
|
||||||
- id, revision_id (FK), file_path, extension, content
|
|
||||||
- Configurable whitelist of extensions (default: .nix, .json, .md, .txt, .toml, .yaml, .yml)
|
|
||||||
|
|
||||||
**Indexes:**
|
|
||||||
- Full-text search: PostgreSQL (tsvector/GIN), SQLite (FTS5)
|
|
||||||
- B-tree on (revision_id, name) and (revision_id, parent_path)
|
|
||||||
- B-tree on (revision_id, file_path) for file lookups
|
|
||||||
|
|
||||||
**Cross-DB Compatibility:**
|
|
||||||
- JSON stored as TEXT (not JSONB) for SQLite compatibility
|
|
||||||
- Separate FTS implementations per database engine
|
|
||||||
|
|
||||||
## Repository Structure (Planned)
|
|
||||||
|
|
||||||
```
|
```
|
||||||
labmcp/
|
labmcp/
|
||||||
├── cmd/
|
├── cmd/
|
||||||
│ └── nixos-options/ # MCP server binary
|
│ └── nixos-options/
|
||||||
│ └── main.go
|
│ └── main.go # CLI entry point
|
||||||
├── internal/
|
├── internal/
|
||||||
│ ├── mcp/ # MCP protocol implementation
|
│ ├── database/
|
||||||
│ │ ├── server.go
|
│ │ ├── interface.go # Store interface
|
||||||
│ │ └── types.go
|
│ │ ├── schema.go # Schema versioning
|
||||||
│ ├── database/ # Database abstraction
|
│ │ ├── postgres.go # PostgreSQL implementation
|
||||||
│ │ ├── interface.go
|
│ │ ├── sqlite.go # SQLite implementation
|
||||||
│ │ ├── postgres.go
|
│ │ └── *_test.go # Database tests
|
||||||
│ │ └── sqlite.go
|
│ ├── mcp/
|
||||||
│ └── nixos/ # NixOS options specific logic
|
│ │ ├── server.go # MCP server core
|
||||||
│ ├── search.go
|
│ │ ├── handlers.go # Tool implementations
|
||||||
│ └── types.go
|
│ │ ├── types.go # Protocol types
|
||||||
├── scripts/
|
│ │ ├── transport.go # Transport interface
|
||||||
│ └── populate-db.go # Tool to populate database
|
│ │ ├── transport_stdio.go # STDIO transport
|
||||||
├── schema/
|
│ │ ├── transport_http.go # HTTP/SSE transport
|
||||||
│ └── schema.sql # Database schema
|
│ │ ├── session.go # HTTP session management
|
||||||
├── flake.nix # Nix build configuration
|
│ │ └── *_test.go # MCP tests
|
||||||
|
│ └── nixos/
|
||||||
|
│ ├── indexer.go # Nixpkgs indexing
|
||||||
|
│ ├── parser.go # options.json parsing
|
||||||
|
│ ├── types.go # Channel aliases, extensions
|
||||||
|
│ └── *_test.go # Indexer tests
|
||||||
|
├── nix/
|
||||||
|
│ ├── module.nix # NixOS module
|
||||||
|
│ └── package.nix # Nix package definition
|
||||||
|
├── testdata/
|
||||||
|
│ └── options-sample.json # Test fixture
|
||||||
|
├── flake.nix
|
||||||
├── go.mod
|
├── go.mod
|
||||||
├── TODO.md # Detailed task list
|
├── .mcp.json # MCP client configuration
|
||||||
├── CLAUDE.md # This file
|
├── CLAUDE.md # This file
|
||||||
└── README.md
|
├── README.md
|
||||||
|
└── TODO.md # Future improvements
|
||||||
```
|
```
|
||||||
|
|
||||||
## Use Cases
|
## MCP Tools
|
||||||
|
|
||||||
**Primary Use Case**: Claude can help users find and understand NixOS options
|
All tools are implemented and functional:
|
||||||
- "What options are available for nginx?"
|
|
||||||
- "Show me the services.caddy.* options"
|
|
||||||
- "What's the default value for services.postgresql.enable?"
|
|
||||||
- User shares a flake.lock → Claude indexes that nixpkgs version → answers questions about options in that specific version
|
|
||||||
|
|
||||||
**Secondary Use Case**: Explore module implementations
|
| Tool | Description |
|
||||||
- If option description is unclear, fetch the actual module source
|
|------|-------------|
|
||||||
- Understand how complex options are structured
|
| `search_options` | Full-text search across option names and descriptions |
|
||||||
|
| `get_option` | Get full details for a specific option with children |
|
||||||
|
| `get_file` | Fetch source file contents from indexed nixpkgs |
|
||||||
|
| `index_revision` | Index a nixpkgs revision (by hash or channel name) |
|
||||||
|
| `list_revisions` | List all indexed revisions |
|
||||||
|
| `delete_revision` | Delete an indexed revision |
|
||||||
|
|
||||||
## Testing Strategy
|
## Key Implementation Details
|
||||||
|
|
||||||
- **Unit Tests**: All components with mocks where appropriate
|
### Database
|
||||||
- **Integration Tests**: Full indexing pipeline, MCP tool invocations
|
- Schema versioning with automatic recreation on version mismatch
|
||||||
- **Benchmarks**: Indexing time, query performance, memory usage
|
- Full-text search: SQLite FTS5, PostgreSQL tsvector/GIN
|
||||||
- **Test Fixtures**: Sample options.json, mock repositories
|
- Path-based queries use LIKE for exact prefix matching
|
||||||
- **Coverage Goal**: >80% on core logic, 100% on database operations
|
- Batch operations for efficient indexing
|
||||||
|
|
||||||
## Open Questions
|
### Indexing
|
||||||
|
- Uses `nix-build` to evaluate NixOS options from any nixpkgs revision
|
||||||
|
- File indexing downloads tarball and stores allowed extensions (.nix, .json, .md, etc.)
|
||||||
|
- File indexing enabled by default (use `--no-files` to skip)
|
||||||
|
- Skips already-indexed revisions (use `--force` to re-index)
|
||||||
|
|
||||||
1. Should `index_revision` be blocking or async? (Currently: blocking, optimize later)
|
### Transports
|
||||||
2. Should we auto-update channel aliases or manual only?
|
- **STDIO**: Default transport, line-delimited JSON-RPC (for CLI/desktop MCP clients)
|
||||||
|
- **HTTP**: Streamable HTTP transport with SSE (for web-based MCP clients)
|
||||||
|
- Session management with cryptographically secure IDs
|
||||||
|
- Configurable CORS (localhost-only by default)
|
||||||
|
- Optional TLS support
|
||||||
|
- SSE keepalive messages (15s default)
|
||||||
|
|
||||||
## Current Status
|
### Security
|
||||||
|
- Revision parameter validated against strict regex to prevent Nix injection
|
||||||
|
- Path traversal protection using `filepath.Clean()` and `filepath.IsAbs()`
|
||||||
|
- NixOS module supports `connectionStringFile` for PostgreSQL secrets
|
||||||
|
- Systemd service runs with extensive hardening options
|
||||||
|
- HTTP transport hardening:
|
||||||
|
- Request body size limit (1MB default)
|
||||||
|
- Server timeouts (read: 30s, write: 30s, idle: 120s, header: 10s)
|
||||||
|
- Maximum session limit (10,000 default)
|
||||||
|
- Origin validation for CORS
|
||||||
|
|
||||||
**Planning phase** - architecture and features defined, ready to begin implementation.
|
## CLI Commands
|
||||||
|
|
||||||
## Next Steps
|
```bash
|
||||||
|
nixos-options serve # Run MCP server on STDIO (default)
|
||||||
1. Design and implement database schema
|
nixos-options serve --transport http # Run MCP server on HTTP
|
||||||
2. Set up project structure (directories, Go modules)
|
nixos-options serve --transport http \
|
||||||
3. Implement database abstraction layer
|
--http-address 0.0.0.0:8080 \
|
||||||
4. Implement MCP protocol basics
|
--allowed-origins https://example.com # HTTP with custom config
|
||||||
5. Build indexing logic
|
nixos-options index <revision> # Index a nixpkgs revision
|
||||||
6. Implement MCP tools
|
nixos-options index --force <r> # Force re-index existing revision
|
||||||
7. Create Nix package in flake.nix
|
nixos-options index --no-files # Skip file content indexing
|
||||||
8. Write tests and benchmarks
|
nixos-options list # List indexed revisions
|
||||||
|
nixos-options search <query> # Search options
|
||||||
|
nixos-options get <option> # Get option details
|
||||||
|
nixos-options delete <revision> # Delete indexed revision
|
||||||
|
nixos-options --version # Show version
|
||||||
|
```
|
||||||
|
|
||||||
## Notes for Claude
|
## Notes for Claude
|
||||||
|
|
||||||
|
### Development Workflow
|
||||||
|
- **Always run `go fmt ./...` before committing Go code**
|
||||||
|
- **Run Go commands using `nix develop -c`** (e.g., `nix develop -c go test ./...`)
|
||||||
|
- **Use `nix run` to run binaries** (e.g., `nix run .#nixos-options -- serve`)
|
||||||
|
- File paths in responses should use format `path/to/file.go:123`
|
||||||
|
|
||||||
|
### User Preferences
|
||||||
- User prefers PostgreSQL over SQLite (has homelab infrastructure)
|
- User prefers PostgreSQL over SQLite (has homelab infrastructure)
|
||||||
- User values good test coverage and benchmarking
|
- User values good test coverage and benchmarking
|
||||||
- Project should remain generic to support future MCP servers
|
- Project should remain generic to support future MCP servers
|
||||||
- Nix flake must provide importable packages for other repos
|
|
||||||
- Use `database/sql` interface for database abstraction
|
### Testing
|
||||||
- File paths in responses should use format `path/to/file.go:123`
|
```bash
|
||||||
- **Always run `go fmt ./...` before committing Go code**
|
# Run all tests
|
||||||
- **Run Go commands using `nix develop -c`** (e.g., `nix develop -c go test ./...`) to ensure proper build environment with all dependencies
|
nix develop -c go test ./... -short
|
||||||
- **Use `nix run` to run binaries** instead of `go build` followed by running the binary (e.g., `nix run .#nixos-options -- serve`)
|
|
||||||
|
# Run with verbose output
|
||||||
|
nix develop -c go test ./... -v
|
||||||
|
|
||||||
|
# Run benchmarks (requires nix-build)
|
||||||
|
nix develop -c go test -bench=. -benchtime=1x -timeout=30m ./internal/nixos/...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Building
|
||||||
|
```bash
|
||||||
|
# Build with nix
|
||||||
|
nix build
|
||||||
|
|
||||||
|
# Run directly
|
||||||
|
nix run . -- serve
|
||||||
|
nix run . -- index nixos-unstable
|
||||||
|
```
|
||||||
|
|||||||
56
README.md
56
README.md
@@ -20,10 +20,10 @@ Search and query NixOS configuration options across multiple nixpkgs revisions.
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Build the package
|
# Build the package
|
||||||
nix build github:torjus/labmcp
|
nix build git+https://git.t-juice.club/torjus/labmcp
|
||||||
|
|
||||||
# Or run directly
|
# Or run directly
|
||||||
nix run github:torjus/labmcp -- --help
|
nix run git+https://git.t-juice.club/torjus/labmcp -- --help
|
||||||
```
|
```
|
||||||
|
|
||||||
### From Source
|
### From Source
|
||||||
@@ -34,7 +34,7 @@ go install git.t-juice.club/torjus/labmcp/cmd/nixos-options@latest
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### As MCP Server
|
### As MCP Server (STDIO)
|
||||||
|
|
||||||
Configure in your MCP client (e.g., Claude Desktop):
|
Configure in your MCP client (e.g., Claude Desktop):
|
||||||
|
|
||||||
@@ -52,12 +52,52 @@ Configure in your MCP client (e.g., Claude Desktop):
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Alternatively, if you have Nix installed, you can use the flake directly without installing the package:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"nixos-options": {
|
||||||
|
"command": "nix",
|
||||||
|
"args": ["run", "git+https://git.t-juice.club/torjus/labmcp", "--", "serve"],
|
||||||
|
"env": {
|
||||||
|
"NIXOS_OPTIONS_DATABASE": "sqlite:///path/to/nixos-options.db"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
Then start the server:
|
Then start the server:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
nixos-options serve
|
nixos-options serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### As MCP Server (HTTP)
|
||||||
|
|
||||||
|
The server can also run over HTTP with Server-Sent Events (SSE) for web-based MCP clients:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start HTTP server on default address (127.0.0.1:8080)
|
||||||
|
nixos-options serve --transport http
|
||||||
|
|
||||||
|
# Custom address and CORS configuration
|
||||||
|
nixos-options serve --transport http \
|
||||||
|
--http-address 0.0.0.0:8080 \
|
||||||
|
--allowed-origins https://example.com
|
||||||
|
|
||||||
|
# With TLS
|
||||||
|
nixos-options serve --transport http \
|
||||||
|
--tls-cert /path/to/cert.pem \
|
||||||
|
--tls-key /path/to/key.pem
|
||||||
|
```
|
||||||
|
|
||||||
|
HTTP transport endpoints:
|
||||||
|
- `POST /mcp` - JSON-RPC requests (returns `Mcp-Session-Id` header on initialize)
|
||||||
|
- `GET /mcp` - SSE stream for server notifications (requires `Mcp-Session-Id` header)
|
||||||
|
- `DELETE /mcp` - Terminate session
|
||||||
|
|
||||||
### CLI Examples
|
### CLI Examples
|
||||||
|
|
||||||
**Index a nixpkgs revision:**
|
**Index a nixpkgs revision:**
|
||||||
@@ -154,7 +194,7 @@ A NixOS module is provided for running the MCP server as a systemd service.
|
|||||||
|
|
||||||
```nix
|
```nix
|
||||||
{
|
{
|
||||||
inputs.labmcp.url = "github:torjus/labmcp";
|
inputs.labmcp.url = "git+https://git.t-juice.club/torjus/labmcp";
|
||||||
|
|
||||||
outputs = { self, nixpkgs, labmcp }: {
|
outputs = { self, nixpkgs, labmcp }: {
|
||||||
nixosConfigurations.myhost = nixpkgs.lib.nixosSystem {
|
nixosConfigurations.myhost = nixpkgs.lib.nixosSystem {
|
||||||
@@ -187,6 +227,14 @@ A NixOS module is provided for running the MCP server as a systemd service.
|
|||||||
| `user` | string | `"nixos-options-mcp"` | User to run the service as |
|
| `user` | string | `"nixos-options-mcp"` | User to run the service as |
|
||||||
| `group` | string | `"nixos-options-mcp"` | Group to run the service as |
|
| `group` | string | `"nixos-options-mcp"` | Group to run the service as |
|
||||||
| `dataDir` | path | `/var/lib/nixos-options-mcp` | Directory for data storage |
|
| `dataDir` | path | `/var/lib/nixos-options-mcp` | Directory for data storage |
|
||||||
|
| `http.address` | string | `"127.0.0.1:8080"` | HTTP listen address |
|
||||||
|
| `http.endpoint` | string | `"/mcp"` | HTTP endpoint path |
|
||||||
|
| `http.allowedOrigins` | list of string | `[]` | Allowed CORS origins (empty = localhost only) |
|
||||||
|
| `http.sessionTTL` | string | `"30m"` | Session timeout (Go duration format) |
|
||||||
|
| `http.tls.enable` | bool | `false` | Enable TLS |
|
||||||
|
| `http.tls.certFile` | path | `null` | TLS certificate file |
|
||||||
|
| `http.tls.keyFile` | path | `null` | TLS private key file |
|
||||||
|
| `openFirewall` | bool | `false` | Open firewall for HTTP port |
|
||||||
|
|
||||||
### PostgreSQL Example
|
### PostgreSQL Example
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
@@ -36,7 +39,42 @@ func main() {
|
|||||||
Commands: []*cli.Command{
|
Commands: []*cli.Command{
|
||||||
{
|
{
|
||||||
Name: "serve",
|
Name: "serve",
|
||||||
Usage: "Run MCP server (stdio)",
|
Usage: "Run MCP server",
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "transport",
|
||||||
|
Aliases: []string{"t"},
|
||||||
|
Usage: "Transport type: 'stdio' or 'http'",
|
||||||
|
Value: "stdio",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "http-address",
|
||||||
|
Usage: "HTTP listen address",
|
||||||
|
Value: "127.0.0.1:8080",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "http-endpoint",
|
||||||
|
Usage: "HTTP endpoint path",
|
||||||
|
Value: "/mcp",
|
||||||
|
},
|
||||||
|
&cli.StringSliceFlag{
|
||||||
|
Name: "allowed-origins",
|
||||||
|
Usage: "Allowed Origin headers for CORS (can be specified multiple times)",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "tls-cert",
|
||||||
|
Usage: "TLS certificate file",
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "tls-key",
|
||||||
|
Usage: "TLS key file",
|
||||||
|
},
|
||||||
|
&cli.DurationFlag{
|
||||||
|
Name: "session-ttl",
|
||||||
|
Usage: "Session TTL for HTTP transport",
|
||||||
|
Value: 30 * time.Minute,
|
||||||
|
},
|
||||||
|
},
|
||||||
Action: func(c *cli.Context) error {
|
Action: func(c *cli.Context) error {
|
||||||
return runServe(c)
|
return runServe(c)
|
||||||
},
|
},
|
||||||
@@ -145,7 +183,8 @@ func openStore(connStr string) (database.Store, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func runServe(c *cli.Context) error {
|
func runServe(c *cli.Context) error {
|
||||||
ctx := context.Background()
|
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
store, err := openStore(c.String("database"))
|
store, err := openStore(c.String("database"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -163,8 +202,27 @@ func runServe(c *cli.Context) error {
|
|||||||
indexer := nixos.NewIndexer(store)
|
indexer := nixos.NewIndexer(store)
|
||||||
server.RegisterHandlers(indexer)
|
server.RegisterHandlers(indexer)
|
||||||
|
|
||||||
|
transport := c.String("transport")
|
||||||
|
switch transport {
|
||||||
|
case "stdio":
|
||||||
logger.Println("Starting MCP server on stdio...")
|
logger.Println("Starting MCP server on stdio...")
|
||||||
return server.Run(ctx, os.Stdin, os.Stdout)
|
return server.Run(ctx, os.Stdin, os.Stdout)
|
||||||
|
|
||||||
|
case "http":
|
||||||
|
config := mcp.HTTPConfig{
|
||||||
|
Address: c.String("http-address"),
|
||||||
|
Endpoint: c.String("http-endpoint"),
|
||||||
|
AllowedOrigins: c.StringSlice("allowed-origins"),
|
||||||
|
SessionTTL: c.Duration("session-ttl"),
|
||||||
|
TLSCertFile: c.String("tls-cert"),
|
||||||
|
TLSKeyFile: c.String("tls-key"),
|
||||||
|
}
|
||||||
|
httpTransport := mcp.NewHTTPTransport(server, config)
|
||||||
|
return httpTransport.Run(ctx)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown transport: %s (use 'stdio' or 'http')", transport)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runIndex(c *cli.Context, revision string, indexFiles bool, force bool) error {
|
func runIndex(c *cli.Context, revision string, indexFiles bool, force bool) error {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -11,7 +10,7 @@ import (
|
|||||||
"git.t-juice.club/torjus/labmcp/internal/database"
|
"git.t-juice.club/torjus/labmcp/internal/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server is an MCP server that handles JSON-RPC requests over stdio.
|
// Server is an MCP server that handles JSON-RPC requests.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
store database.Store
|
store database.Store
|
||||||
tools map[string]ToolHandler
|
tools map[string]ToolHandler
|
||||||
@@ -41,53 +40,34 @@ func (s *Server) registerTools() {
|
|||||||
// Tools will be implemented in handlers.go
|
// Tools will be implemented in handlers.go
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run starts the server, reading from r and writing to w.
|
// Run starts the server using STDIO transport (backward compatibility).
|
||||||
func (s *Server) Run(ctx context.Context, r io.Reader, w io.Writer) error {
|
func (s *Server) Run(ctx context.Context, r io.Reader, w io.Writer) error {
|
||||||
scanner := bufio.NewScanner(r)
|
transport := NewStdioTransport(s, r, w)
|
||||||
encoder := json.NewEncoder(w)
|
return transport.Run(ctx)
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
line := scanner.Bytes()
|
|
||||||
if len(line) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandleMessage parses a JSON-RPC message and returns the response.
|
||||||
|
// Returns (nil, nil) for notifications that don't require a response.
|
||||||
|
func (s *Server) HandleMessage(ctx context.Context, data []byte) (*Response, error) {
|
||||||
var req Request
|
var req Request
|
||||||
if err := json.Unmarshal(line, &req); err != nil {
|
if err := json.Unmarshal(data, &req); err != nil {
|
||||||
s.logger.Printf("Failed to parse request: %v", err)
|
return &Response{
|
||||||
resp := Response{
|
|
||||||
JSONRPC: "2.0",
|
JSONRPC: "2.0",
|
||||||
Error: &Error{
|
Error: &Error{
|
||||||
Code: ParseError,
|
Code: ParseError,
|
||||||
Message: "Parse error",
|
Message: "Parse error",
|
||||||
Data: err.Error(),
|
Data: err.Error(),
|
||||||
},
|
},
|
||||||
}
|
}, nil
|
||||||
if err := encoder.Encode(resp); err != nil {
|
|
||||||
return fmt.Errorf("failed to write response: %w", err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := s.handleRequest(ctx, &req)
|
return s.HandleRequest(ctx, &req), nil
|
||||||
if resp != nil {
|
|
||||||
if err := encoder.Encode(resp); err != nil {
|
|
||||||
return fmt.Errorf("failed to write response: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
// HandleRequest processes a single request and returns a response.
|
||||||
return fmt.Errorf("scanner error: %w", err)
|
// Returns nil for notifications that don't require a response.
|
||||||
}
|
func (s *Server) HandleRequest(ctx context.Context, req *Request) *Response {
|
||||||
|
return s.handleRequest(ctx, req)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRequest processes a single request and returns a response.
|
// handleRequest processes a single request and returns a response.
|
||||||
|
|||||||
231
internal/mcp/session.go
Normal file
231
internal/mcp/session.go
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Session represents an MCP client session.
|
||||||
|
type Session struct {
|
||||||
|
ID string
|
||||||
|
CreatedAt time.Time
|
||||||
|
LastActivity time.Time
|
||||||
|
Initialized bool
|
||||||
|
|
||||||
|
// notifications is a channel for server-initiated notifications.
|
||||||
|
notifications chan *Response
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSession creates a new session with a cryptographically secure random ID.
|
||||||
|
func NewSession() (*Session, error) {
|
||||||
|
id, err := generateSessionID()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
return &Session{
|
||||||
|
ID: id,
|
||||||
|
CreatedAt: now,
|
||||||
|
LastActivity: now,
|
||||||
|
notifications: make(chan *Response, 100),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Touch updates the session's last activity time.
|
||||||
|
func (s *Session) Touch() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.LastActivity = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetInitialized marks the session as initialized.
|
||||||
|
func (s *Session) SetInitialized() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.Initialized = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInitialized returns whether the session has been initialized.
|
||||||
|
func (s *Session) IsInitialized() bool {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.Initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notifications returns the channel for server-initiated notifications.
|
||||||
|
func (s *Session) Notifications() <-chan *Response {
|
||||||
|
return s.notifications
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendNotification sends a notification to the session.
|
||||||
|
// Returns false if the channel is full.
|
||||||
|
func (s *Session) SendNotification(notification *Response) bool {
|
||||||
|
select {
|
||||||
|
case s.notifications <- notification:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the session's notification channel.
|
||||||
|
func (s *Session) Close() {
|
||||||
|
close(s.notifications)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStore manages active sessions with TTL-based cleanup.
|
||||||
|
type SessionStore struct {
|
||||||
|
sessions map[string]*Session
|
||||||
|
ttl time.Duration
|
||||||
|
maxSessions int
|
||||||
|
mu sync.RWMutex
|
||||||
|
stopClean chan struct{}
|
||||||
|
cleanDone chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrTooManySessions is returned when the session limit is reached.
|
||||||
|
var ErrTooManySessions = fmt.Errorf("too many active sessions")
|
||||||
|
|
||||||
|
// DefaultMaxSessions is the default maximum number of concurrent sessions.
|
||||||
|
const DefaultMaxSessions = 10000
|
||||||
|
|
||||||
|
// NewSessionStore creates a new session store with the given TTL.
|
||||||
|
func NewSessionStore(ttl time.Duration) *SessionStore {
|
||||||
|
return NewSessionStoreWithLimit(ttl, DefaultMaxSessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSessionStoreWithLimit creates a new session store with TTL and max session limit.
|
||||||
|
func NewSessionStoreWithLimit(ttl time.Duration, maxSessions int) *SessionStore {
|
||||||
|
if maxSessions <= 0 {
|
||||||
|
maxSessions = DefaultMaxSessions
|
||||||
|
}
|
||||||
|
s := &SessionStore{
|
||||||
|
sessions: make(map[string]*Session),
|
||||||
|
ttl: ttl,
|
||||||
|
maxSessions: maxSessions,
|
||||||
|
stopClean: make(chan struct{}),
|
||||||
|
cleanDone: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go s.cleanupLoop()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create creates a new session and adds it to the store.
|
||||||
|
// Returns ErrTooManySessions if the maximum session limit is reached.
|
||||||
|
func (s *SessionStore) Create() (*Session, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Check session limit
|
||||||
|
if len(s.sessions) >= s.maxSessions {
|
||||||
|
return nil, ErrTooManySessions
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sessions[session.ID] = session
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a session by ID. Returns nil if not found or expired.
|
||||||
|
func (s *SessionStore) Get(id string) *Session {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
session, ok := s.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if expired
|
||||||
|
session.mu.RLock()
|
||||||
|
expired := time.Since(session.LastActivity) > s.ttl
|
||||||
|
session.mu.RUnlock()
|
||||||
|
|
||||||
|
if expired {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return session
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a session from the store.
|
||||||
|
func (s *SessionStore) Delete(id string) bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
session, ok := s.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
session.Close()
|
||||||
|
delete(s.sessions, id)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of active sessions.
|
||||||
|
func (s *SessionStore) Count() int {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return len(s.sessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the cleanup goroutine and waits for it to finish.
|
||||||
|
func (s *SessionStore) Stop() {
|
||||||
|
close(s.stopClean)
|
||||||
|
<-s.cleanDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupLoop periodically removes expired sessions.
|
||||||
|
func (s *SessionStore) cleanupLoop() {
|
||||||
|
defer close(s.cleanDone)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(s.ttl / 2)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.stopClean:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.cleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup removes expired sessions.
|
||||||
|
func (s *SessionStore) cleanup() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for id, session := range s.sessions {
|
||||||
|
session.mu.RLock()
|
||||||
|
expired := now.Sub(session.LastActivity) > s.ttl
|
||||||
|
session.mu.RUnlock()
|
||||||
|
|
||||||
|
if expired {
|
||||||
|
session.Close()
|
||||||
|
delete(s.sessions, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSessionID generates a cryptographically secure random session ID.
|
||||||
|
func generateSessionID() (string, error) {
|
||||||
|
bytes := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
337
internal/mcp/session_test.go
Normal file
337
internal/mcp/session_test.go
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSession(t *testing.T) {
|
||||||
|
session, err := NewSession()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.ID == "" {
|
||||||
|
t.Error("Session ID should not be empty")
|
||||||
|
}
|
||||||
|
if len(session.ID) != 32 {
|
||||||
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID))
|
||||||
|
}
|
||||||
|
if session.Initialized {
|
||||||
|
t.Error("New session should not be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionTouch(t *testing.T) {
|
||||||
|
session, _ := NewSession()
|
||||||
|
originalActivity := session.LastActivity
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
session.Touch()
|
||||||
|
|
||||||
|
if !session.LastActivity.After(originalActivity) {
|
||||||
|
t.Error("Touch should update LastActivity")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionInitialized(t *testing.T) {
|
||||||
|
session, _ := NewSession()
|
||||||
|
|
||||||
|
if session.IsInitialized() {
|
||||||
|
t.Error("New session should not be initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
session.SetInitialized()
|
||||||
|
|
||||||
|
if !session.IsInitialized() {
|
||||||
|
t.Error("Session should be initialized after SetInitialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionNotifications(t *testing.T) {
|
||||||
|
session, _ := NewSession()
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
notification := &Response{JSONRPC: "2.0", ID: 1}
|
||||||
|
|
||||||
|
if !session.SendNotification(notification) {
|
||||||
|
t.Error("SendNotification should return true on success")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case received := <-session.Notifications():
|
||||||
|
if received.ID != notification.ID {
|
||||||
|
t.Error("Received notification should match sent")
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Error("Should receive notification")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreCreate(t *testing.T) {
|
||||||
|
store := NewSessionStore(30 * time.Minute)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
session, err := store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Count() != 1 {
|
||||||
|
t.Errorf("Store should have 1 session, got %d", store.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we can retrieve it
|
||||||
|
retrieved := store.Get(session.ID)
|
||||||
|
if retrieved == nil {
|
||||||
|
t.Error("Should be able to retrieve created session")
|
||||||
|
}
|
||||||
|
if retrieved.ID != session.ID {
|
||||||
|
t.Error("Retrieved session ID should match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreGet(t *testing.T) {
|
||||||
|
store := NewSessionStore(30 * time.Minute)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Get non-existent session
|
||||||
|
if store.Get("nonexistent") != nil {
|
||||||
|
t.Error("Should return nil for non-existent session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and retrieve
|
||||||
|
session, _ := store.Create()
|
||||||
|
retrieved := store.Get(session.ID)
|
||||||
|
if retrieved == nil {
|
||||||
|
t.Error("Should find created session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreDelete(t *testing.T) {
|
||||||
|
store := NewSessionStore(30 * time.Minute)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
session, _ := store.Create()
|
||||||
|
if store.Count() != 1 {
|
||||||
|
t.Error("Should have 1 session after create")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !store.Delete(session.ID) {
|
||||||
|
t.Error("Delete should return true for existing session")
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Count() != 0 {
|
||||||
|
t.Error("Should have 0 sessions after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Delete(session.ID) {
|
||||||
|
t.Error("Delete should return false for non-existent session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreTTLExpiration(t *testing.T) {
|
||||||
|
ttl := 50 * time.Millisecond
|
||||||
|
store := NewSessionStore(ttl)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
session, _ := store.Create()
|
||||||
|
|
||||||
|
// Should be retrievable immediately
|
||||||
|
if store.Get(session.ID) == nil {
|
||||||
|
t.Error("Session should be retrievable immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for expiration
|
||||||
|
time.Sleep(ttl + 10*time.Millisecond)
|
||||||
|
|
||||||
|
// Should not be retrievable after TTL
|
||||||
|
if store.Get(session.ID) != nil {
|
||||||
|
t.Error("Expired session should not be retrievable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreTTLRefresh(t *testing.T) {
|
||||||
|
ttl := 100 * time.Millisecond
|
||||||
|
store := NewSessionStore(ttl)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
session, _ := store.Create()
|
||||||
|
|
||||||
|
// Touch the session before TTL expires
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
session.Touch()
|
||||||
|
|
||||||
|
// Wait past original TTL but not past refreshed TTL
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// Should still be retrievable because we touched it
|
||||||
|
if store.Get(session.ID) == nil {
|
||||||
|
t.Error("Touched session should still be retrievable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreCleanup(t *testing.T) {
|
||||||
|
ttl := 50 * time.Millisecond
|
||||||
|
store := NewSessionStore(ttl)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Create multiple sessions
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
store.Create()
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Count() != 5 {
|
||||||
|
t.Errorf("Should have 5 sessions, got %d", store.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for cleanup to run (runs at ttl/2 intervals)
|
||||||
|
time.Sleep(ttl + ttl/2 + 10*time.Millisecond)
|
||||||
|
|
||||||
|
// All sessions should be cleaned up
|
||||||
|
if store.Count() != 0 {
|
||||||
|
t.Errorf("All sessions should be cleaned up, got %d", store.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreConcurrency(t *testing.T) {
|
||||||
|
store := NewSessionStore(30 * time.Minute)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
sessionIDs := make(chan string, 100)
|
||||||
|
|
||||||
|
// Create sessions concurrently
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
session, err := store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to create session: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sessionIDs <- session.ID
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(sessionIDs)
|
||||||
|
|
||||||
|
// Verify all sessions were created
|
||||||
|
if store.Count() != 50 {
|
||||||
|
t.Errorf("Should have 50 sessions, got %d", store.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and delete concurrently
|
||||||
|
var ids []string
|
||||||
|
for id := range sessionIDs {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
wg.Add(2)
|
||||||
|
go func(id string) {
|
||||||
|
defer wg.Done()
|
||||||
|
store.Get(id)
|
||||||
|
}(id)
|
||||||
|
go func(id string) {
|
||||||
|
defer wg.Done()
|
||||||
|
store.Delete(id)
|
||||||
|
}(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreMaxSessions(t *testing.T) {
|
||||||
|
maxSessions := 5
|
||||||
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Create sessions up to limit
|
||||||
|
for i := 0; i < maxSessions; i++ {
|
||||||
|
_, err := store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create session %d: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if store.Count() != maxSessions {
|
||||||
|
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to create one more - should fail
|
||||||
|
_, err := store.Create()
|
||||||
|
if err != ErrTooManySessions {
|
||||||
|
t.Errorf("Expected ErrTooManySessions, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count should still be at max
|
||||||
|
if store.Count() != maxSessions {
|
||||||
|
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
|
||||||
|
maxSessions := 3
|
||||||
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Fill up the store
|
||||||
|
sessions := make([]*Session, maxSessions)
|
||||||
|
for i := 0; i < maxSessions; i++ {
|
||||||
|
s, err := store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create session: %v", err)
|
||||||
|
}
|
||||||
|
sessions[i] = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be full
|
||||||
|
_, err := store.Create()
|
||||||
|
if err != ErrTooManySessions {
|
||||||
|
t.Error("Expected ErrTooManySessions when full")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete one session
|
||||||
|
store.Delete(sessions[0].ID)
|
||||||
|
|
||||||
|
// Should be able to create again
|
||||||
|
_, err = store.Create()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should be able to create after deletion: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
|
||||||
|
store := NewSessionStore(30 * time.Minute)
|
||||||
|
defer store.Stop()
|
||||||
|
|
||||||
|
// Just verify it uses the default (don't create 10000 sessions)
|
||||||
|
if store.maxSessions != DefaultMaxSessions {
|
||||||
|
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSessionID(t *testing.T) {
|
||||||
|
ids := make(map[string]bool)
|
||||||
|
|
||||||
|
// Generate 1000 IDs and ensure uniqueness
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
id, err := generateSessionID()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate session ID: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(id) != 32 {
|
||||||
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
if ids[id] {
|
||||||
|
t.Error("Generated duplicate session ID")
|
||||||
|
}
|
||||||
|
ids[id] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
10
internal/mcp/transport.go
Normal file
10
internal/mcp/transport.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Transport defines the interface for MCP server transports.
|
||||||
|
type Transport interface {
|
||||||
|
// Run starts the transport and blocks until the context is cancelled
|
||||||
|
// or an error occurs.
|
||||||
|
Run(ctx context.Context) error
|
||||||
|
}
|
||||||
448
internal/mcp/transport_http.go
Normal file
448
internal/mcp/transport_http.go
Normal file
@@ -0,0 +1,448 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPConfig configures the HTTP transport.
|
||||||
|
type HTTPConfig struct {
|
||||||
|
Address string // Listen address (e.g., "127.0.0.1:8080")
|
||||||
|
Endpoint string // MCP endpoint path (e.g., "/mcp")
|
||||||
|
AllowedOrigins []string // Allowed Origin headers for CORS (empty = localhost only)
|
||||||
|
SessionTTL time.Duration // Session TTL (default: 30 minutes)
|
||||||
|
MaxSessions int // Maximum concurrent sessions (default: 10000)
|
||||||
|
TLSCertFile string // TLS certificate file (optional)
|
||||||
|
TLSKeyFile string // TLS key file (optional)
|
||||||
|
MaxRequestSize int64 // Maximum request body size in bytes (default: 1MB)
|
||||||
|
ReadTimeout time.Duration // HTTP server read timeout (default: 30s)
|
||||||
|
WriteTimeout time.Duration // HTTP server write timeout (default: 30s)
|
||||||
|
IdleTimeout time.Duration // HTTP server idle timeout (default: 120s)
|
||||||
|
ReadHeaderTimeout time.Duration // HTTP server read header timeout (default: 10s)
|
||||||
|
SSEKeepAlive time.Duration // SSE keepalive interval (default: 15s, 0 to disable)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultMaxRequestSize is the default maximum request body size (1MB).
|
||||||
|
DefaultMaxRequestSize = 1 << 20 // 1MB
|
||||||
|
|
||||||
|
// Default HTTP server timeouts
|
||||||
|
DefaultReadTimeout = 30 * time.Second
|
||||||
|
DefaultWriteTimeout = 30 * time.Second
|
||||||
|
DefaultIdleTimeout = 120 * time.Second
|
||||||
|
DefaultReadHeaderTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// DefaultSSEKeepAlive is the default interval for SSE keepalive messages.
|
||||||
|
// These are sent as SSE comments to keep the connection alive through
|
||||||
|
// proxies and load balancers, and to detect stale connections.
|
||||||
|
DefaultSSEKeepAlive = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPTransport implements the MCP Streamable HTTP transport.
|
||||||
|
type HTTPTransport struct {
|
||||||
|
server *Server
|
||||||
|
config HTTPConfig
|
||||||
|
sessions *SessionStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPTransport creates a new HTTP transport.
|
||||||
|
func NewHTTPTransport(server *Server, config HTTPConfig) *HTTPTransport {
|
||||||
|
if config.Address == "" {
|
||||||
|
config.Address = "127.0.0.1:8080"
|
||||||
|
}
|
||||||
|
if config.Endpoint == "" {
|
||||||
|
config.Endpoint = "/mcp"
|
||||||
|
}
|
||||||
|
if config.SessionTTL == 0 {
|
||||||
|
config.SessionTTL = 30 * time.Minute
|
||||||
|
}
|
||||||
|
if config.MaxSessions == 0 {
|
||||||
|
config.MaxSessions = DefaultMaxSessions
|
||||||
|
}
|
||||||
|
if config.MaxRequestSize == 0 {
|
||||||
|
config.MaxRequestSize = DefaultMaxRequestSize
|
||||||
|
}
|
||||||
|
if config.ReadTimeout == 0 {
|
||||||
|
config.ReadTimeout = DefaultReadTimeout
|
||||||
|
}
|
||||||
|
if config.WriteTimeout == 0 {
|
||||||
|
config.WriteTimeout = DefaultWriteTimeout
|
||||||
|
}
|
||||||
|
if config.IdleTimeout == 0 {
|
||||||
|
config.IdleTimeout = DefaultIdleTimeout
|
||||||
|
}
|
||||||
|
if config.ReadHeaderTimeout == 0 {
|
||||||
|
config.ReadHeaderTimeout = DefaultReadHeaderTimeout
|
||||||
|
}
|
||||||
|
// SSEKeepAlive: 0 means use default, negative means disabled
|
||||||
|
if config.SSEKeepAlive == 0 {
|
||||||
|
config.SSEKeepAlive = DefaultSSEKeepAlive
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HTTPTransport{
|
||||||
|
server: server,
|
||||||
|
config: config,
|
||||||
|
sessions: NewSessionStoreWithLimit(config.SessionTTL, config.MaxSessions),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the HTTP server and blocks until the context is cancelled.
|
||||||
|
func (t *HTTPTransport) Run(ctx context.Context) error {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc(t.config.Endpoint, t.handleMCP)
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: t.config.Address,
|
||||||
|
Handler: mux,
|
||||||
|
ReadTimeout: t.config.ReadTimeout,
|
||||||
|
WriteTimeout: t.config.WriteTimeout,
|
||||||
|
IdleTimeout: t.config.IdleTimeout,
|
||||||
|
ReadHeaderTimeout: t.config.ReadHeaderTimeout,
|
||||||
|
BaseContext: func(l net.Listener) context.Context {
|
||||||
|
return ctx
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown on context cancellation
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
t.server.logger.Println("Shutting down HTTP server...")
|
||||||
|
t.sessions.Stop()
|
||||||
|
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||||
|
t.server.logger.Printf("HTTP server shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.server.logger.Printf("Starting HTTP transport on %s%s", t.config.Address, t.config.Endpoint)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if t.config.TLSCertFile != "" && t.config.TLSKeyFile != "" {
|
||||||
|
err = httpServer.ListenAndServeTLS(t.config.TLSCertFile, t.config.TLSKeyFile)
|
||||||
|
} else {
|
||||||
|
err = httpServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == http.ErrServerClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMCP routes requests based on HTTP method.
|
||||||
|
func (t *HTTPTransport) handleMCP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Validate Origin header
|
||||||
|
if !t.isOriginAllowed(r) {
|
||||||
|
http.Error(w, "Forbidden: Origin not allowed", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
t.handlePost(w, r)
|
||||||
|
case http.MethodGet:
|
||||||
|
t.handleGet(w, r)
|
||||||
|
case http.MethodDelete:
|
||||||
|
t.handleDelete(w, r)
|
||||||
|
case http.MethodOptions:
|
||||||
|
t.handleOptions(w, r)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePost handles JSON-RPC requests.
|
||||||
|
func (t *HTTPTransport) handlePost(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Limit request body size to prevent memory exhaustion attacks
|
||||||
|
r.Body = http.MaxBytesReader(w, r.Body, t.config.MaxRequestSize)
|
||||||
|
|
||||||
|
// Read request body
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
// Check if this is a size limit error
|
||||||
|
if err.Error() == "http: request body too large" {
|
||||||
|
http.Error(w, "Request body too large", http.StatusRequestEntityTooLarge)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse request to check method
|
||||||
|
var req Request
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(Response{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Error: &Error{
|
||||||
|
Code: ParseError,
|
||||||
|
Message: "Parse error",
|
||||||
|
Data: err.Error(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle initialize request - create session
|
||||||
|
if req.Method == MethodInitialize {
|
||||||
|
t.handleInitialize(w, r, &req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// All other requests require a valid session
|
||||||
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||||
|
if sessionID == "" {
|
||||||
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := t.sessions.Get(sessionID)
|
||||||
|
if session == nil {
|
||||||
|
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update session activity
|
||||||
|
session.Touch()
|
||||||
|
|
||||||
|
// Handle notifications (no response expected)
|
||||||
|
if req.Method == MethodInitialized {
|
||||||
|
session.SetInitialized()
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a notification (no ID)
|
||||||
|
if req.ID == nil {
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
resp := t.server.HandleRequest(r.Context(), &req)
|
||||||
|
if resp == nil {
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleInitialize handles the initialize request and creates a new session.
|
||||||
|
func (t *HTTPTransport) handleInitialize(w http.ResponseWriter, r *http.Request, req *Request) {
|
||||||
|
// Create a new session
|
||||||
|
session, err := t.sessions.Create()
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrTooManySessions {
|
||||||
|
t.server.logger.Printf("Session limit reached")
|
||||||
|
http.Error(w, "Service unavailable: too many active sessions", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.server.logger.Printf("Failed to create session: %v", err)
|
||||||
|
http.Error(w, "Failed to create session", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process initialize request
|
||||||
|
resp := t.server.HandleRequest(r.Context(), req)
|
||||||
|
if resp == nil {
|
||||||
|
t.sessions.Delete(session.ID)
|
||||||
|
http.Error(w, "Initialize failed: no response", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If initialize failed, clean up session
|
||||||
|
if resp.Error != nil {
|
||||||
|
t.sessions.Delete(session.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Mcp-Session-Id", session.ID)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGet handles SSE stream for server-initiated notifications.
|
||||||
|
func (t *HTTPTransport) handleGet(w http.ResponseWriter, r *http.Request) {
|
||||||
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||||
|
if sessionID == "" {
|
||||||
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
session := t.sessions.Get(sessionID)
|
||||||
|
if session == nil {
|
||||||
|
http.Error(w, "Invalid or expired session", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if client accepts SSE
|
||||||
|
accept := r.Header.Get("Accept")
|
||||||
|
if !strings.Contains(accept, "text/event-stream") {
|
||||||
|
http.Error(w, "Accept header must include text/event-stream", http.StatusNotAcceptable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set SSE headers
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
// Flush headers
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
// Use ResponseController to manage write deadlines for long-lived SSE connections
|
||||||
|
rc := http.NewResponseController(w)
|
||||||
|
|
||||||
|
// Set up keepalive ticker if enabled
|
||||||
|
var keepaliveTicker *time.Ticker
|
||||||
|
var keepaliveChan <-chan time.Time
|
||||||
|
if t.config.SSEKeepAlive > 0 {
|
||||||
|
keepaliveTicker = time.NewTicker(t.config.SSEKeepAlive)
|
||||||
|
keepaliveChan = keepaliveTicker.C
|
||||||
|
defer keepaliveTicker.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream notifications
|
||||||
|
ctx := r.Context()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-keepaliveChan:
|
||||||
|
// Send SSE comment as keepalive (ignored by clients)
|
||||||
|
if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||||
|
t.server.logger.Printf("Failed to set write deadline: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(w, ":keepalive\n\n"); err != nil {
|
||||||
|
// Write failed, connection likely closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
case notification, ok := <-session.Notifications():
|
||||||
|
if !ok {
|
||||||
|
// Session closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend write deadline before each write
|
||||||
|
if err := rc.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||||
|
t.server.logger.Printf("Failed to set write deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(notification)
|
||||||
|
if err != nil {
|
||||||
|
t.server.logger.Printf("Failed to marshal notification: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write SSE event
|
||||||
|
if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
|
||||||
|
// Write failed, connection likely closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
// Touch session to keep it alive
|
||||||
|
session.Touch()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDelete terminates a session.
|
||||||
|
func (t *HTTPTransport) handleDelete(w http.ResponseWriter, r *http.Request) {
|
||||||
|
sessionID := r.Header.Get("Mcp-Session-Id")
|
||||||
|
if sessionID == "" {
|
||||||
|
http.Error(w, "Session ID required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.sessions.Delete(sessionID) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
} else {
|
||||||
|
http.Error(w, "Session not found", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleOptions handles CORS preflight requests.
|
||||||
|
func (t *HTTPTransport) handleOptions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin != "" && t.isOriginAllowed(r) {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
|
||||||
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Accept, Mcp-Session-Id")
|
||||||
|
w.Header().Set("Access-Control-Expose-Headers", "Mcp-Session-Id")
|
||||||
|
w.Header().Set("Access-Control-Max-Age", "86400")
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isOriginAllowed checks if the request origin is allowed.
|
||||||
|
func (t *HTTPTransport) isOriginAllowed(r *http.Request) bool {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
|
||||||
|
// No Origin header (same-origin request) is always allowed
|
||||||
|
if origin == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no allowed origins configured, only allow localhost
|
||||||
|
if len(t.config.AllowedOrigins) == 0 {
|
||||||
|
return isLocalhostOrigin(origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check against allowed origins
|
||||||
|
for _, allowed := range t.config.AllowedOrigins {
|
||||||
|
if allowed == "*" || allowed == origin {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isLocalhostOrigin checks if the origin is a localhost address.
|
||||||
|
func isLocalhostOrigin(origin string) bool {
|
||||||
|
origin = strings.ToLower(origin)
|
||||||
|
|
||||||
|
// Check for localhost patterns (must be followed by :, /, or end of string)
|
||||||
|
localhostPatterns := []string{
|
||||||
|
"http://localhost",
|
||||||
|
"https://localhost",
|
||||||
|
"http://127.0.0.1",
|
||||||
|
"https://127.0.0.1",
|
||||||
|
"http://[::1]",
|
||||||
|
"https://[::1]",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pattern := range localhostPatterns {
|
||||||
|
if origin == pattern {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(origin, pattern+":") || strings.HasPrefix(origin, pattern+"/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
752
internal/mcp/transport_http_test.go
Normal file
752
internal/mcp/transport_http_test.go
Normal file
@@ -0,0 +1,752 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testHTTPTransport creates a transport with a test server
|
||||||
|
func testHTTPTransport(t *testing.T, config HTTPConfig) (*HTTPTransport, *httptest.Server) {
|
||||||
|
// Use a mock store
|
||||||
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
||||||
|
|
||||||
|
if config.SessionTTL == 0 {
|
||||||
|
config.SessionTTL = 30 * time.Minute
|
||||||
|
}
|
||||||
|
transport := NewHTTPTransport(server, config)
|
||||||
|
|
||||||
|
// Create test server
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
endpoint := config.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = "/mcp"
|
||||||
|
}
|
||||||
|
mux.HandleFunc(endpoint, transport.handleMCP)
|
||||||
|
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
ts.Close()
|
||||||
|
transport.sessions.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
return transport, ts
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportInitialize(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
// Send initialize request
|
||||||
|
initReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodInitialize,
|
||||||
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(initReq)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check session ID header
|
||||||
|
sessionID := resp.Header.Get("Mcp-Session-Id")
|
||||||
|
if sessionID == "" {
|
||||||
|
t.Error("Expected Mcp-Session-Id header")
|
||||||
|
}
|
||||||
|
if len(sessionID) != 32 {
|
||||||
|
t.Errorf("Session ID should be 32 chars, got %d", len(sessionID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check response body
|
||||||
|
var initResp Response
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&initResp); err != nil {
|
||||||
|
t.Fatalf("Failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if initResp.Error != nil {
|
||||||
|
t.Errorf("Initialize failed: %v", initResp.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportSessionRequired(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
// Send tools/list without session
|
||||||
|
listReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodToolsList,
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(listReq)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Errorf("Expected 400 without session, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportInvalidSession(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
// Send request with invalid session
|
||||||
|
listReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodToolsList,
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(listReq)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Mcp-Session-Id", "invalid-session-id")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("Expected 404 for invalid session, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportValidSession(t *testing.T) {
|
||||||
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
// Create session manually
|
||||||
|
session, _ := transport.sessions.Create()
|
||||||
|
|
||||||
|
// Send tools/list with valid session
|
||||||
|
listReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodToolsList,
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(listReq)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200 with valid session, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportNotificationAccepted(t *testing.T) {
|
||||||
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
session, _ := transport.sessions.Create()
|
||||||
|
|
||||||
|
// Send notification (no ID)
|
||||||
|
notification := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Method: MethodInitialized,
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(notification)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusAccepted {
|
||||||
|
t.Errorf("Expected 202 for notification, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session is marked as initialized
|
||||||
|
if !session.IsInitialized() {
|
||||||
|
t.Error("Session should be marked as initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportDeleteSession(t *testing.T) {
|
||||||
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
session, _ := transport.sessions.Create()
|
||||||
|
|
||||||
|
// Delete session
|
||||||
|
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
||||||
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
|
t.Errorf("Expected 204 for delete, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session is gone
|
||||||
|
if transport.sessions.Get(session.ID) != nil {
|
||||||
|
t.Error("Session should be deleted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportDeleteNonexistentSession(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("DELETE", ts.URL+"/mcp", nil)
|
||||||
|
req.Header.Set("Mcp-Session-Id", "nonexistent")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Errorf("Expected 404 for nonexistent session, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportOriginValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
allowedOrigins []string
|
||||||
|
origin string
|
||||||
|
expectAllowed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no origin header",
|
||||||
|
allowedOrigins: nil,
|
||||||
|
origin: "",
|
||||||
|
expectAllowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localhost allowed by default",
|
||||||
|
allowedOrigins: nil,
|
||||||
|
origin: "http://localhost:3000",
|
||||||
|
expectAllowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "127.0.0.1 allowed by default",
|
||||||
|
allowedOrigins: nil,
|
||||||
|
origin: "http://127.0.0.1:8080",
|
||||||
|
expectAllowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "external origin blocked by default",
|
||||||
|
allowedOrigins: nil,
|
||||||
|
origin: "http://evil.com",
|
||||||
|
expectAllowed: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit allow",
|
||||||
|
allowedOrigins: []string{"http://example.com"},
|
||||||
|
origin: "http://example.com",
|
||||||
|
expectAllowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit allow wildcard",
|
||||||
|
allowedOrigins: []string{"*"},
|
||||||
|
origin: "http://anything.com",
|
||||||
|
expectAllowed: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not in allowed list",
|
||||||
|
allowedOrigins: []string{"http://example.com"},
|
||||||
|
origin: "http://other.com",
|
||||||
|
expectAllowed: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
AllowedOrigins: tt.allowedOrigins,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use initialize since it doesn't require a session
|
||||||
|
initReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodInitialize,
|
||||||
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(initReq)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if tt.origin != "" {
|
||||||
|
req.Header.Set("Origin", tt.origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if tt.expectAllowed && resp.StatusCode == http.StatusForbidden {
|
||||||
|
t.Error("Expected request to be allowed but was forbidden")
|
||||||
|
}
|
||||||
|
if !tt.expectAllowed && resp.StatusCode != http.StatusForbidden {
|
||||||
|
t.Errorf("Expected request to be forbidden but got status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportSSERequiresAcceptHeader(t *testing.T) {
|
||||||
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
session, _ := transport.sessions.Create()
|
||||||
|
|
||||||
|
// GET without Accept: text/event-stream
|
||||||
|
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
||||||
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNotAcceptable {
|
||||||
|
t.Errorf("Expected 406 without Accept header, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportSSEStream(t *testing.T) {
|
||||||
|
transport, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
session, _ := transport.sessions.Create()
|
||||||
|
|
||||||
|
// Start SSE stream in goroutine
|
||||||
|
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
||||||
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType != "text/event-stream" {
|
||||||
|
t.Errorf("Expected Content-Type text/event-stream, got %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a notification
|
||||||
|
notification := &Response{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 42,
|
||||||
|
Result: map[string]string{"test": "data"},
|
||||||
|
}
|
||||||
|
session.SendNotification(notification)
|
||||||
|
|
||||||
|
// Read the SSE event
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, err := resp.Body.Read(buf)
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
t.Fatalf("Failed to read SSE event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := string(buf[:n])
|
||||||
|
if !strings.HasPrefix(data, "data: ") {
|
||||||
|
t.Errorf("Expected SSE data event, got: %s", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the JSON from the SSE event
|
||||||
|
jsonData := strings.TrimPrefix(strings.TrimSuffix(data, "\n\n"), "data: ")
|
||||||
|
var received Response
|
||||||
|
if err := json.Unmarshal([]byte(jsonData), &received); err != nil {
|
||||||
|
t.Fatalf("Failed to parse notification JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSON unmarshal converts numbers to float64, so compare as float64
|
||||||
|
receivedID, ok := received.ID.(float64)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected numeric ID, got %T", received.ID)
|
||||||
|
}
|
||||||
|
if int(receivedID) != 42 {
|
||||||
|
t.Errorf("Expected notification ID 42, got %v", receivedID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportSSEKeepalive(t *testing.T) {
|
||||||
|
transport, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
SSEKeepAlive: 50 * time.Millisecond, // Short interval for testing
|
||||||
|
})
|
||||||
|
|
||||||
|
session, _ := transport.sessions.Create()
|
||||||
|
|
||||||
|
// Start SSE stream
|
||||||
|
req, _ := http.NewRequest("GET", ts.URL+"/mcp", nil)
|
||||||
|
req.Header.Set("Mcp-Session-Id", session.ID)
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("Expected 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read with timeout - should receive keepalive within 100ms
|
||||||
|
buf := make([]byte, 256)
|
||||||
|
done := make(chan struct{})
|
||||||
|
var readData string
|
||||||
|
var readErr error
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
n, err := resp.Body.Read(buf)
|
||||||
|
readData = string(buf[:n])
|
||||||
|
readErr = err
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
if readErr != nil && readErr.Error() != "EOF" {
|
||||||
|
t.Fatalf("Read error: %v", readErr)
|
||||||
|
}
|
||||||
|
// Should receive SSE comment keepalive
|
||||||
|
if !strings.Contains(readData, ":keepalive") {
|
||||||
|
t.Errorf("Expected keepalive comment, got: %q", readData)
|
||||||
|
}
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
t.Error("Timeout waiting for keepalive")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportSSEKeepaliveDisabled(t *testing.T) {
|
||||||
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
||||||
|
config := HTTPConfig{
|
||||||
|
SSEKeepAlive: -1, // Explicitly disabled
|
||||||
|
}
|
||||||
|
transport := NewHTTPTransport(server, config)
|
||||||
|
defer transport.sessions.Stop()
|
||||||
|
|
||||||
|
// When SSEKeepAlive is negative, it should remain negative (disabled)
|
||||||
|
if transport.config.SSEKeepAlive != -1 {
|
||||||
|
t.Errorf("Expected SSEKeepAlive to remain -1 (disabled), got %v", transport.config.SSEKeepAlive)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportParseError(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
// Send invalid JSON
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader([]byte("not json")))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200 (with JSON-RPC error), got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var jsonResp Response
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&jsonResp); err != nil {
|
||||||
|
t.Fatalf("Failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if jsonResp.Error == nil {
|
||||||
|
t.Error("Expected JSON-RPC error for parse error")
|
||||||
|
}
|
||||||
|
if jsonResp.Error != nil && jsonResp.Error.Code != ParseError {
|
||||||
|
t.Errorf("Expected parse error code %d, got %d", ParseError, jsonResp.Error.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportMethodNotAllowed(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{})
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("PUT", ts.URL+"/mcp", nil)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("Expected 405, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportOptionsRequest(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
AllowedOrigins: []string{"http://example.com"},
|
||||||
|
})
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("OPTIONS", ts.URL+"/mcp", nil)
|
||||||
|
req.Header.Set("Origin", "http://example.com")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusNoContent {
|
||||||
|
t.Errorf("Expected 204, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Header.Get("Access-Control-Allow-Origin") != "http://example.com" {
|
||||||
|
t.Error("Expected CORS origin header")
|
||||||
|
}
|
||||||
|
if resp.Header.Get("Access-Control-Allow-Methods") == "" {
|
||||||
|
t.Error("Expected CORS methods header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportDefaultConfig(t *testing.T) {
|
||||||
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
||||||
|
transport := NewHTTPTransport(server, HTTPConfig{})
|
||||||
|
|
||||||
|
// Verify defaults are applied
|
||||||
|
if transport.config.Address != "127.0.0.1:8080" {
|
||||||
|
t.Errorf("Expected default address 127.0.0.1:8080, got %s", transport.config.Address)
|
||||||
|
}
|
||||||
|
if transport.config.Endpoint != "/mcp" {
|
||||||
|
t.Errorf("Expected default endpoint /mcp, got %s", transport.config.Endpoint)
|
||||||
|
}
|
||||||
|
if transport.config.SessionTTL != 30*time.Minute {
|
||||||
|
t.Errorf("Expected default session TTL 30m, got %v", transport.config.SessionTTL)
|
||||||
|
}
|
||||||
|
if transport.config.MaxRequestSize != DefaultMaxRequestSize {
|
||||||
|
t.Errorf("Expected default max request size %d, got %d", DefaultMaxRequestSize, transport.config.MaxRequestSize)
|
||||||
|
}
|
||||||
|
if transport.config.ReadTimeout != DefaultReadTimeout {
|
||||||
|
t.Errorf("Expected default read timeout %v, got %v", DefaultReadTimeout, transport.config.ReadTimeout)
|
||||||
|
}
|
||||||
|
if transport.config.WriteTimeout != DefaultWriteTimeout {
|
||||||
|
t.Errorf("Expected default write timeout %v, got %v", DefaultWriteTimeout, transport.config.WriteTimeout)
|
||||||
|
}
|
||||||
|
if transport.config.IdleTimeout != DefaultIdleTimeout {
|
||||||
|
t.Errorf("Expected default idle timeout %v, got %v", DefaultIdleTimeout, transport.config.IdleTimeout)
|
||||||
|
}
|
||||||
|
if transport.config.ReadHeaderTimeout != DefaultReadHeaderTimeout {
|
||||||
|
t.Errorf("Expected default read header timeout %v, got %v", DefaultReadHeaderTimeout, transport.config.ReadHeaderTimeout)
|
||||||
|
}
|
||||||
|
if transport.config.SSEKeepAlive != DefaultSSEKeepAlive {
|
||||||
|
t.Errorf("Expected default SSE keepalive %v, got %v", DefaultSSEKeepAlive, transport.config.SSEKeepAlive)
|
||||||
|
}
|
||||||
|
|
||||||
|
transport.sessions.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportCustomConfig(t *testing.T) {
|
||||||
|
server := NewServer(nil, log.New(io.Discard, "", 0))
|
||||||
|
config := HTTPConfig{
|
||||||
|
Address: "0.0.0.0:9090",
|
||||||
|
Endpoint: "/api/mcp",
|
||||||
|
SessionTTL: 1 * time.Hour,
|
||||||
|
MaxRequestSize: 5 << 20, // 5MB
|
||||||
|
ReadTimeout: 60 * time.Second,
|
||||||
|
WriteTimeout: 60 * time.Second,
|
||||||
|
IdleTimeout: 300 * time.Second,
|
||||||
|
ReadHeaderTimeout: 20 * time.Second,
|
||||||
|
}
|
||||||
|
transport := NewHTTPTransport(server, config)
|
||||||
|
|
||||||
|
// Verify custom values are preserved
|
||||||
|
if transport.config.Address != "0.0.0.0:9090" {
|
||||||
|
t.Errorf("Expected custom address, got %s", transport.config.Address)
|
||||||
|
}
|
||||||
|
if transport.config.Endpoint != "/api/mcp" {
|
||||||
|
t.Errorf("Expected custom endpoint, got %s", transport.config.Endpoint)
|
||||||
|
}
|
||||||
|
if transport.config.SessionTTL != 1*time.Hour {
|
||||||
|
t.Errorf("Expected custom session TTL, got %v", transport.config.SessionTTL)
|
||||||
|
}
|
||||||
|
if transport.config.MaxRequestSize != 5<<20 {
|
||||||
|
t.Errorf("Expected custom max request size, got %d", transport.config.MaxRequestSize)
|
||||||
|
}
|
||||||
|
if transport.config.ReadTimeout != 60*time.Second {
|
||||||
|
t.Errorf("Expected custom read timeout, got %v", transport.config.ReadTimeout)
|
||||||
|
}
|
||||||
|
if transport.config.WriteTimeout != 60*time.Second {
|
||||||
|
t.Errorf("Expected custom write timeout, got %v", transport.config.WriteTimeout)
|
||||||
|
}
|
||||||
|
if transport.config.IdleTimeout != 300*time.Second {
|
||||||
|
t.Errorf("Expected custom idle timeout, got %v", transport.config.IdleTimeout)
|
||||||
|
}
|
||||||
|
if transport.config.ReadHeaderTimeout != 20*time.Second {
|
||||||
|
t.Errorf("Expected custom read header timeout, got %v", transport.config.ReadHeaderTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
transport.sessions.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportRequestBodyTooLarge(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
MaxRequestSize: 100, // Very small limit for testing
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a request body larger than the limit
|
||||||
|
largeBody := make([]byte, 200)
|
||||||
|
for i := range largeBody {
|
||||||
|
largeBody[i] = 'x'
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(largeBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusRequestEntityTooLarge {
|
||||||
|
t.Errorf("Expected 413 for oversized request, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportSessionLimitReached(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
MaxSessions: 2, // Very low limit for testing
|
||||||
|
})
|
||||||
|
|
||||||
|
initReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodInitialize,
|
||||||
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(initReq)
|
||||||
|
|
||||||
|
// Create sessions up to the limit
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request %d failed: %v", i, err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Request %d: expected 200, got %d", i, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third request should fail with 503
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("Expected 503 when session limit reached, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTransportRequestBodyWithinLimit(t *testing.T) {
|
||||||
|
_, ts := testHTTPTransport(t, HTTPConfig{
|
||||||
|
MaxRequestSize: 10000, // Reasonable limit
|
||||||
|
})
|
||||||
|
|
||||||
|
// Send initialize request (should be well within limit)
|
||||||
|
initReq := Request{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: 1,
|
||||||
|
Method: MethodInitialize,
|
||||||
|
Params: json.RawMessage(`{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}`),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(initReq)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("POST", ts.URL+"/mcp", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Request failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected 200 for valid request within limit, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsLocalhostOrigin(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
origin string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"http://localhost", true},
|
||||||
|
{"http://localhost:3000", true},
|
||||||
|
{"https://localhost", true},
|
||||||
|
{"https://localhost:8443", true},
|
||||||
|
{"http://127.0.0.1", true},
|
||||||
|
{"http://127.0.0.1:8080", true},
|
||||||
|
{"https://127.0.0.1", true},
|
||||||
|
{"http://[::1]", true},
|
||||||
|
{"http://[::1]:8080", true},
|
||||||
|
{"https://[::1]", true},
|
||||||
|
{"http://example.com", false},
|
||||||
|
{"https://example.com", false},
|
||||||
|
{"http://localhost.evil.com", false},
|
||||||
|
{"http://192.168.1.1", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.origin, func(t *testing.T) {
|
||||||
|
result := isLocalhostOrigin(tt.origin)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("isLocalhostOrigin(%q) = %v, want %v", tt.origin, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
63
internal/mcp/transport_stdio.go
Normal file
63
internal/mcp/transport_stdio.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StdioTransport implements the MCP protocol over STDIO using line-delimited JSON-RPC.
|
||||||
|
type StdioTransport struct {
|
||||||
|
server *Server
|
||||||
|
reader io.Reader
|
||||||
|
writer io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStdioTransport creates a new STDIO transport.
|
||||||
|
func NewStdioTransport(server *Server, r io.Reader, w io.Writer) *StdioTransport {
|
||||||
|
return &StdioTransport{
|
||||||
|
server: server,
|
||||||
|
reader: r,
|
||||||
|
writer: w,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the STDIO transport, reading line-delimited JSON-RPC from the reader
|
||||||
|
// and writing responses to the writer.
|
||||||
|
func (t *StdioTransport) Run(ctx context.Context) error {
|
||||||
|
scanner := bufio.NewScanner(t.reader)
|
||||||
|
encoder := json.NewEncoder(t.writer)
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
line := scanner.Bytes()
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := t.server.HandleMessage(ctx, line)
|
||||||
|
if err != nil {
|
||||||
|
t.server.logger.Printf("Error handling message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp != nil {
|
||||||
|
if err := encoder.Encode(resp); err != nil {
|
||||||
|
return fmt.Errorf("failed to write response: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return fmt.Errorf("scanner error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -89,13 +89,56 @@ in
|
|||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
http = {
|
||||||
|
address = lib.mkOption {
|
||||||
|
type = lib.types.str;
|
||||||
|
default = "127.0.0.1:8080";
|
||||||
|
description = "HTTP listen address for the MCP server.";
|
||||||
|
};
|
||||||
|
|
||||||
|
endpoint = lib.mkOption {
|
||||||
|
type = lib.types.str;
|
||||||
|
default = "/mcp";
|
||||||
|
description = "HTTP endpoint path for MCP requests.";
|
||||||
|
};
|
||||||
|
|
||||||
|
allowedOrigins = lib.mkOption {
|
||||||
|
type = lib.types.listOf lib.types.str;
|
||||||
|
default = [ ];
|
||||||
|
example = [ "http://localhost:3000" "https://example.com" ];
|
||||||
|
description = ''
|
||||||
|
Allowed Origin headers for CORS.
|
||||||
|
Empty list means only localhost origins are allowed.
|
||||||
|
'';
|
||||||
|
};
|
||||||
|
|
||||||
|
sessionTTL = lib.mkOption {
|
||||||
|
type = lib.types.str;
|
||||||
|
default = "30m";
|
||||||
|
description = "Session TTL for HTTP transport (Go duration format).";
|
||||||
|
};
|
||||||
|
|
||||||
|
tls = {
|
||||||
|
enable = lib.mkEnableOption "TLS for HTTP transport";
|
||||||
|
|
||||||
|
certFile = lib.mkOption {
|
||||||
|
type = lib.types.nullOr lib.types.path;
|
||||||
|
default = null;
|
||||||
|
description = "Path to TLS certificate file.";
|
||||||
|
};
|
||||||
|
|
||||||
|
keyFile = lib.mkOption {
|
||||||
|
type = lib.types.nullOr lib.types.path;
|
||||||
|
default = null;
|
||||||
|
description = "Path to TLS private key file.";
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
openFirewall = lib.mkOption {
|
openFirewall = lib.mkOption {
|
||||||
type = lib.types.bool;
|
type = lib.types.bool;
|
||||||
default = false;
|
default = false;
|
||||||
description = ''
|
description = "Whether to open the firewall for the MCP HTTP server.";
|
||||||
Whether to open the firewall for the MCP server.
|
|
||||||
Note: MCP typically runs over stdio, so this is usually not needed.
|
|
||||||
'';
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -111,6 +154,10 @@ in
|
|||||||
assertion = cfg.database.connectionString == "" || cfg.database.connectionStringFile == null;
|
assertion = cfg.database.connectionString == "" || cfg.database.connectionStringFile == null;
|
||||||
message = "services.nixos-options-mcp.database: connectionString and connectionStringFile are mutually exclusive";
|
message = "services.nixos-options-mcp.database: connectionString and connectionStringFile are mutually exclusive";
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
assertion = !cfg.http.tls.enable || (cfg.http.tls.certFile != null && cfg.http.tls.keyFile != null);
|
||||||
|
message = "services.nixos-options-mcp.http.tls: both certFile and keyFile must be set when TLS is enabled";
|
||||||
|
}
|
||||||
];
|
];
|
||||||
|
|
||||||
users.users.${cfg.user} = lib.mkIf (cfg.user == "nixos-options-mcp") {
|
users.users.${cfg.user} = lib.mkIf (cfg.user == "nixos-options-mcp") {
|
||||||
@@ -145,6 +192,19 @@ in
|
|||||||
nixos-options index "${rev}" || true
|
nixos-options index "${rev}" || true
|
||||||
'') cfg.indexOnStart}
|
'') cfg.indexOnStart}
|
||||||
'';
|
'';
|
||||||
|
|
||||||
|
# Build HTTP transport flags
|
||||||
|
httpFlags = lib.concatStringsSep " " ([
|
||||||
|
"--transport http"
|
||||||
|
"--http-address '${cfg.http.address}'"
|
||||||
|
"--http-endpoint '${cfg.http.endpoint}'"
|
||||||
|
"--session-ttl '${cfg.http.sessionTTL}'"
|
||||||
|
] ++ lib.optionals (cfg.http.allowedOrigins != []) (
|
||||||
|
map (origin: "--allowed-origins '${origin}'") cfg.http.allowedOrigins
|
||||||
|
) ++ lib.optionals cfg.http.tls.enable [
|
||||||
|
"--tls-cert '${cfg.http.tls.certFile}'"
|
||||||
|
"--tls-key '${cfg.http.tls.keyFile}'"
|
||||||
|
]);
|
||||||
in
|
in
|
||||||
if useConnectionStringFile then ''
|
if useConnectionStringFile then ''
|
||||||
# Read database connection string from file
|
# Read database connection string from file
|
||||||
@@ -155,10 +215,10 @@ in
|
|||||||
export NIXOS_OPTIONS_DATABASE="$(cat "${cfg.database.connectionStringFile}")"
|
export NIXOS_OPTIONS_DATABASE="$(cat "${cfg.database.connectionStringFile}")"
|
||||||
|
|
||||||
${indexCommands}
|
${indexCommands}
|
||||||
exec nixos-options serve
|
exec nixos-options serve ${httpFlags}
|
||||||
'' else ''
|
'' else ''
|
||||||
${indexCommands}
|
${indexCommands}
|
||||||
exec nixos-options serve
|
exec nixos-options serve ${httpFlags}
|
||||||
'';
|
'';
|
||||||
|
|
||||||
serviceConfig = {
|
serviceConfig = {
|
||||||
@@ -188,5 +248,14 @@ in
|
|||||||
StateDirectory = lib.mkIf (cfg.dataDir == "/var/lib/nixos-options-mcp") "nixos-options-mcp";
|
StateDirectory = lib.mkIf (cfg.dataDir == "/var/lib/nixos-options-mcp") "nixos-options-mcp";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
# Open firewall for HTTP port if configured
|
||||||
|
networking.firewall = lib.mkIf cfg.openFirewall (let
|
||||||
|
# Extract port from address (format: "host:port" or ":port")
|
||||||
|
addressParts = lib.splitString ":" cfg.http.address;
|
||||||
|
port = lib.toInt (lib.last addressParts);
|
||||||
|
in {
|
||||||
|
allowedTCPPorts = [ port ];
|
||||||
|
});
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user