Add configurable MaxSessions limit (default: 10000) to SessionStore. When the limit is reached, new session creation returns ErrTooManySessions and the HTTP transport responds with 503 Service Unavailable. This prevents attackers from exhausting server memory by creating unlimited sessions through repeated initialize requests.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
338 lines
7.4 KiB
Go
338 lines
7.4 KiB
Go
package mcp
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestNewSession(t *testing.T) {
|
|
session, err := NewSession()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
|
|
if session.ID == "" {
|
|
t.Error("Session ID should not be empty")
|
|
}
|
|
if len(session.ID) != 32 {
|
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID))
|
|
}
|
|
if session.Initialized {
|
|
t.Error("New session should not be initialized")
|
|
}
|
|
}
|
|
|
|
func TestSessionTouch(t *testing.T) {
|
|
session, _ := NewSession()
|
|
originalActivity := session.LastActivity
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
session.Touch()
|
|
|
|
if !session.LastActivity.After(originalActivity) {
|
|
t.Error("Touch should update LastActivity")
|
|
}
|
|
}
|
|
|
|
func TestSessionInitialized(t *testing.T) {
|
|
session, _ := NewSession()
|
|
|
|
if session.IsInitialized() {
|
|
t.Error("New session should not be initialized")
|
|
}
|
|
|
|
session.SetInitialized()
|
|
|
|
if !session.IsInitialized() {
|
|
t.Error("Session should be initialized after SetInitialized")
|
|
}
|
|
}
|
|
|
|
func TestSessionNotifications(t *testing.T) {
|
|
session, _ := NewSession()
|
|
defer session.Close()
|
|
|
|
notification := &Response{JSONRPC: "2.0", ID: 1}
|
|
|
|
if !session.SendNotification(notification) {
|
|
t.Error("SendNotification should return true on success")
|
|
}
|
|
|
|
select {
|
|
case received := <-session.Notifications():
|
|
if received.ID != notification.ID {
|
|
t.Error("Received notification should match sent")
|
|
}
|
|
case <-time.After(100 * time.Millisecond):
|
|
t.Error("Should receive notification")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreCreate(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
session, err := store.Create()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
|
|
if store.Count() != 1 {
|
|
t.Errorf("Store should have 1 session, got %d", store.Count())
|
|
}
|
|
|
|
// Verify we can retrieve it
|
|
retrieved := store.Get(session.ID)
|
|
if retrieved == nil {
|
|
t.Error("Should be able to retrieve created session")
|
|
}
|
|
if retrieved.ID != session.ID {
|
|
t.Error("Retrieved session ID should match")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreGet(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
// Get non-existent session
|
|
if store.Get("nonexistent") != nil {
|
|
t.Error("Should return nil for non-existent session")
|
|
}
|
|
|
|
// Create and retrieve
|
|
session, _ := store.Create()
|
|
retrieved := store.Get(session.ID)
|
|
if retrieved == nil {
|
|
t.Error("Should find created session")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreDelete(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
session, _ := store.Create()
|
|
if store.Count() != 1 {
|
|
t.Error("Should have 1 session after create")
|
|
}
|
|
|
|
if !store.Delete(session.ID) {
|
|
t.Error("Delete should return true for existing session")
|
|
}
|
|
|
|
if store.Count() != 0 {
|
|
t.Error("Should have 0 sessions after delete")
|
|
}
|
|
|
|
if store.Delete(session.ID) {
|
|
t.Error("Delete should return false for non-existent session")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreTTLExpiration(t *testing.T) {
|
|
ttl := 50 * time.Millisecond
|
|
store := NewSessionStore(ttl)
|
|
defer store.Stop()
|
|
|
|
session, _ := store.Create()
|
|
|
|
// Should be retrievable immediately
|
|
if store.Get(session.ID) == nil {
|
|
t.Error("Session should be retrievable immediately")
|
|
}
|
|
|
|
// Wait for expiration
|
|
time.Sleep(ttl + 10*time.Millisecond)
|
|
|
|
// Should not be retrievable after TTL
|
|
if store.Get(session.ID) != nil {
|
|
t.Error("Expired session should not be retrievable")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreTTLRefresh(t *testing.T) {
|
|
ttl := 100 * time.Millisecond
|
|
store := NewSessionStore(ttl)
|
|
defer store.Stop()
|
|
|
|
session, _ := store.Create()
|
|
|
|
// Touch the session before TTL expires
|
|
time.Sleep(60 * time.Millisecond)
|
|
session.Touch()
|
|
|
|
// Wait past original TTL but not past refreshed TTL
|
|
time.Sleep(60 * time.Millisecond)
|
|
|
|
// Should still be retrievable because we touched it
|
|
if store.Get(session.ID) == nil {
|
|
t.Error("Touched session should still be retrievable")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreCleanup(t *testing.T) {
|
|
ttl := 50 * time.Millisecond
|
|
store := NewSessionStore(ttl)
|
|
defer store.Stop()
|
|
|
|
// Create multiple sessions
|
|
for i := 0; i < 5; i++ {
|
|
store.Create()
|
|
}
|
|
|
|
if store.Count() != 5 {
|
|
t.Errorf("Should have 5 sessions, got %d", store.Count())
|
|
}
|
|
|
|
// Wait for cleanup to run (runs at ttl/2 intervals)
|
|
time.Sleep(ttl + ttl/2 + 10*time.Millisecond)
|
|
|
|
// All sessions should be cleaned up
|
|
if store.Count() != 0 {
|
|
t.Errorf("All sessions should be cleaned up, got %d", store.Count())
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreConcurrency(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
var wg sync.WaitGroup
|
|
sessionIDs := make(chan string, 100)
|
|
|
|
// Create sessions concurrently
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
session, err := store.Create()
|
|
if err != nil {
|
|
t.Errorf("Failed to create session: %v", err)
|
|
return
|
|
}
|
|
sessionIDs <- session.ID
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(sessionIDs)
|
|
|
|
// Verify all sessions were created
|
|
if store.Count() != 50 {
|
|
t.Errorf("Should have 50 sessions, got %d", store.Count())
|
|
}
|
|
|
|
// Read and delete concurrently
|
|
var ids []string
|
|
for id := range sessionIDs {
|
|
ids = append(ids, id)
|
|
}
|
|
|
|
for _, id := range ids {
|
|
wg.Add(2)
|
|
go func(id string) {
|
|
defer wg.Done()
|
|
store.Get(id)
|
|
}(id)
|
|
go func(id string) {
|
|
defer wg.Done()
|
|
store.Delete(id)
|
|
}(id)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestSessionStoreMaxSessions(t *testing.T) {
|
|
maxSessions := 5
|
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
|
defer store.Stop()
|
|
|
|
// Create sessions up to limit
|
|
for i := 0; i < maxSessions; i++ {
|
|
_, err := store.Create()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session %d: %v", i, err)
|
|
}
|
|
}
|
|
|
|
if store.Count() != maxSessions {
|
|
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
|
|
}
|
|
|
|
// Try to create one more - should fail
|
|
_, err := store.Create()
|
|
if err != ErrTooManySessions {
|
|
t.Errorf("Expected ErrTooManySessions, got %v", err)
|
|
}
|
|
|
|
// Count should still be at max
|
|
if store.Count() != maxSessions {
|
|
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
|
|
maxSessions := 3
|
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
|
defer store.Stop()
|
|
|
|
// Fill up the store
|
|
sessions := make([]*Session, maxSessions)
|
|
for i := 0; i < maxSessions; i++ {
|
|
s, err := store.Create()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
sessions[i] = s
|
|
}
|
|
|
|
// Should be full
|
|
_, err := store.Create()
|
|
if err != ErrTooManySessions {
|
|
t.Error("Expected ErrTooManySessions when full")
|
|
}
|
|
|
|
// Delete one session
|
|
store.Delete(sessions[0].ID)
|
|
|
|
// Should be able to create again
|
|
_, err = store.Create()
|
|
if err != nil {
|
|
t.Errorf("Should be able to create after deletion: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
// Just verify it uses the default (don't create 10000 sessions)
|
|
if store.maxSessions != DefaultMaxSessions {
|
|
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
|
|
}
|
|
}
|
|
|
|
func TestGenerateSessionID(t *testing.T) {
|
|
ids := make(map[string]bool)
|
|
|
|
// Generate 1000 IDs and ensure uniqueness
|
|
for i := 0; i < 1000; i++ {
|
|
id, err := generateSessionID()
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate session ID: %v", err)
|
|
}
|
|
|
|
if len(id) != 32 {
|
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
|
|
}
|
|
|
|
if ids[id] {
|
|
t.Error("Generated duplicate session ID")
|
|
}
|
|
ids[id] = true
|
|
}
|
|
}
|