Add support for running the MCP server over HTTP with Server-Sent Events (SSE), following the MCP Streamable HTTP specification, alongside the existing STDIO transport.

New features:
- Transport abstraction via a `Transport` interface
- HTTP transport with session management
- SSE support for server-initiated notifications
- CORS protection with configurable allowed origins
- Optional TLS support
- CLI flags for HTTP configuration (`--transport`, `--http-address`, etc.)
- NixOS module options for the HTTP transport

The HTTP transport implements:
- `POST /mcp`: JSON-RPC requests with session management
- `GET /mcp`: SSE stream for server notifications
- `DELETE /mcp`: session termination
- Origin validation (localhost-only by default)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
268 lines
5.6 KiB
Go
268 lines
5.6 KiB
Go
package mcp
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestNewSession(t *testing.T) {
|
|
session, err := NewSession()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
|
|
if session.ID == "" {
|
|
t.Error("Session ID should not be empty")
|
|
}
|
|
if len(session.ID) != 32 {
|
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID))
|
|
}
|
|
if session.Initialized {
|
|
t.Error("New session should not be initialized")
|
|
}
|
|
}
|
|
|
|
// TestSessionTouch verifies that Touch moves the session's LastActivity
// timestamp forward.
func TestSessionTouch(t *testing.T) {
	session, err := NewSession()
	if err != nil {
		// Fail clearly instead of panicking on a nil session below.
		t.Fatalf("Failed to create session: %v", err)
	}
	before := session.LastActivity

	// Let the clock visibly advance so Touch cannot record the same instant.
	time.Sleep(10 * time.Millisecond)
	session.Touch()

	if after := session.LastActivity; !after.After(before) {
		t.Errorf("Touch should update LastActivity: before %v, after %v", before, after)
	}
}
|
|
|
|
// TestSessionInitialized checks that a fresh session reports
// IsInitialized() == false, and true after SetInitialized is called.
func TestSessionInitialized(t *testing.T) {
	session, err := NewSession()
	if err != nil {
		// Fail clearly instead of panicking on a nil session below.
		t.Fatalf("Failed to create session: %v", err)
	}

	if session.IsInitialized() {
		t.Error("New session should not be initialized")
	}

	session.SetInitialized()

	if !session.IsInitialized() {
		t.Error("Session should be initialized after SetInitialized")
	}
}
|
|
|
|
// TestSessionNotifications checks that a message passed to SendNotification
// is delivered on the session's Notifications channel within 100ms.
func TestSessionNotifications(t *testing.T) {
	session, err := NewSession()
	if err != nil {
		// Fail clearly; the deferred Close would otherwise panic on nil.
		t.Fatalf("Failed to create session: %v", err)
	}
	defer session.Close()

	notification := &Response{JSONRPC: "2.0", ID: 1}

	// Without a successful send the receive below can only time out.
	if !session.SendNotification(notification) {
		t.Fatal("SendNotification should return true on success")
	}

	select {
	case received, ok := <-session.Notifications():
		// A closed channel yields the zero value; don't dereference it.
		if !ok {
			t.Fatal("Notifications channel closed before delivering the notification")
		}
		if received.ID != notification.ID {
			t.Errorf("Received notification ID %v should match sent ID %v", received.ID, notification.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("Should receive notification")
	}
}
|
|
|
|
// TestSessionStoreCreate checks that Create adds exactly one session to the
// store and that the same session can be looked up by its ID.
func TestSessionStoreCreate(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	if n := store.Count(); n != 1 {
		t.Errorf("Store should have 1 session, got %d", n)
	}

	// Verify we can retrieve it. Stop on a miss: the ID comparison below
	// would otherwise dereference a nil session and panic.
	retrieved := store.Get(session.ID)
	if retrieved == nil {
		t.Fatal("Should be able to retrieve created session")
	}
	if retrieved.ID != session.ID {
		t.Errorf("Retrieved session ID %q should match %q", retrieved.ID, session.ID)
	}
}
|
|
|
|
// TestSessionStoreGet checks that Get returns nil for an unknown ID and a
// non-nil session for an ID produced by Create.
func TestSessionStoreGet(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	// Get non-existent session
	if store.Get("nonexistent") != nil {
		t.Error("Should return nil for non-existent session")
	}

	// Create and retrieve
	session, err := store.Create()
	if err != nil {
		// Fail clearly instead of panicking on a nil session below.
		t.Fatalf("Failed to create session: %v", err)
	}
	if store.Get(session.ID) == nil {
		t.Error("Should find created session")
	}
}
|
|
|
|
// TestSessionStoreDelete checks that Delete removes an existing session,
// returning true, and returns false when the ID is no longer present.
func TestSessionStoreDelete(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		// Fail clearly instead of panicking on a nil session below.
		t.Fatalf("Failed to create session: %v", err)
	}
	if n := store.Count(); n != 1 {
		t.Errorf("Should have 1 session after create, got %d", n)
	}

	if !store.Delete(session.ID) {
		t.Error("Delete should return true for existing session")
	}

	if n := store.Count(); n != 0 {
		t.Errorf("Should have 0 sessions after delete, got %d", n)
	}

	// A second delete of the same ID must report that nothing was removed.
	if store.Delete(session.ID) {
		t.Error("Delete should return false for non-existent session")
	}
}
|
|
|
|
// TestSessionStoreTTLExpiration checks that Get stops returning a session
// once more than the store's TTL has elapsed since it was created.
func TestSessionStoreTTLExpiration(t *testing.T) {
	ttl := 50 * time.Millisecond
	store := NewSessionStore(ttl)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Should be retrievable immediately. If it is not, the expiry assertion
	// below would pass vacuously, so stop here.
	if store.Get(session.ID) == nil {
		t.Fatal("Session should be retrievable immediately")
	}

	// Wait for expiration
	time.Sleep(ttl + 10*time.Millisecond)

	// Should not be retrievable after TTL
	if store.Get(session.ID) != nil {
		t.Error("Expired session should not be retrievable")
	}
}
|
|
|
|
// TestSessionStoreTTLRefresh verifies that Touch resets a session's TTL: a
// session touched partway through its lifetime must still be retrievable
// after its original deadline has passed.
func TestSessionStoreTTLRefresh(t *testing.T) {
	ttl := 200 * time.Millisecond
	// Each wait is 60% of the TTL. Two waits together pass the original
	// deadline, while each alone leaves a 40% margin for scheduler jitter.
	step := ttl * 3 / 5

	store := NewSessionStore(ttl)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		// Fail clearly instead of panicking on a nil session below.
		t.Fatalf("Failed to create session: %v", err)
	}

	// Touch the session before TTL expires
	time.Sleep(step)
	session.Touch()

	// Wait past original TTL but not past refreshed TTL
	time.Sleep(step)

	// Should still be retrievable because we touched it
	if store.Get(session.ID) == nil {
		t.Error("Touched session should still be retrievable")
	}
}
|
|
|
|
func TestSessionStoreCleanup(t *testing.T) {
|
|
ttl := 50 * time.Millisecond
|
|
store := NewSessionStore(ttl)
|
|
defer store.Stop()
|
|
|
|
// Create multiple sessions
|
|
for i := 0; i < 5; i++ {
|
|
store.Create()
|
|
}
|
|
|
|
if store.Count() != 5 {
|
|
t.Errorf("Should have 5 sessions, got %d", store.Count())
|
|
}
|
|
|
|
// Wait for cleanup to run (runs at ttl/2 intervals)
|
|
time.Sleep(ttl + ttl/2 + 10*time.Millisecond)
|
|
|
|
// All sessions should be cleaned up
|
|
if store.Count() != 0 {
|
|
t.Errorf("All sessions should be cleaned up, got %d", store.Count())
|
|
}
|
|
}
|
|
|
|
// TestSessionStoreConcurrency exercises Create, Get and Delete from many
// goroutines at once (run with -race to detect data races). It then checks
// that every created session was counted and that every one was deleted.
func TestSessionStoreConcurrency(t *testing.T) {
	const numSessions = 50

	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	var wg sync.WaitGroup
	// Buffered to numSessions so no creator blocks on send before wg.Wait.
	sessionIDs := make(chan string, numSessions)

	// Create sessions concurrently. t.Errorf, unlike t.Fatalf, may be
	// called from goroutines other than the test goroutine.
	for i := 0; i < numSessions; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			session, err := store.Create()
			if err != nil {
				t.Errorf("Failed to create session: %v", err)
				return
			}
			sessionIDs <- session.ID
		}()
	}

	wg.Wait()
	close(sessionIDs)

	// Verify all sessions were created
	if n := store.Count(); n != numSessions {
		t.Errorf("Should have %d sessions, got %d", numSessions, n)
	}

	ids := make([]string, 0, numSessions)
	for id := range sessionIDs {
		ids = append(ids, id)
	}

	// Read and delete concurrently: each ID gets one reader and one deleter.
	// Each ID is deleted exactly once, so every Delete must succeed.
	for _, id := range ids {
		wg.Add(2)
		go func(id string) {
			defer wg.Done()
			store.Get(id)
		}(id)
		go func(id string) {
			defer wg.Done()
			if !store.Delete(id) {
				t.Errorf("Delete should return true for existing session %q", id)
			}
		}(id)
	}

	wg.Wait()

	if n := store.Count(); n != 0 {
		t.Errorf("Should have 0 sessions after concurrent delete, got %d", n)
	}
}
|
|
|
|
func TestGenerateSessionID(t *testing.T) {
|
|
ids := make(map[string]bool)
|
|
|
|
// Generate 1000 IDs and ensure uniqueness
|
|
for i := 0; i < 1000; i++ {
|
|
id, err := generateSessionID()
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate session ID: %v", err)
|
|
}
|
|
|
|
if len(id) != 32 {
|
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
|
|
}
|
|
|
|
if ids[id] {
|
|
t.Error("Generated duplicate session ID")
|
|
}
|
|
ids[id] = true
|
|
}
|
|
}
|