This repository was archived on 2026-03-10. You can view its files and clone it, but you cannot open issues or pull requests, or push commits.
Files
labmcp/internal/mcp/session_test.go
Torjus Håkestad 684baf63da security: add maximum session limit to prevent memory exhaustion
Add configurable MaxSessions limit (default: 10000) to SessionStore.
When the limit is reached, new session creation returns ErrTooManySessions
and HTTP transport responds with 503 Service Unavailable.

This prevents attackers from exhausting server memory by creating
unlimited sessions through repeated initialize requests.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 22:07:51 +01:00

338 lines
7.4 KiB
Go

package mcp
import (
	"errors"
	"sync"
	"testing"
	"time"
)
// TestNewSession checks that a freshly created session carries a non-empty,
// 32-character ID and has not yet been marked as initialized.
func TestNewSession(t *testing.T) {
	s, err := NewSession()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Both ID checks run independently so an empty ID reports both problems.
	id := s.ID
	if id == "" {
		t.Error("Session ID should not be empty")
	}
	if n := len(id); n != 32 {
		t.Errorf("Session ID should be 32 hex chars, got %d", n)
	}

	if s.Initialized {
		t.Error("New session should not be initialized")
	}
}
// TestSessionTouch verifies that Touch moves a session's LastActivity
// timestamp strictly forward.
func TestSessionTouch(t *testing.T) {
	session, err := NewSession()
	if err != nil {
		// Fail cleanly instead of panicking on a nil session below.
		t.Fatalf("Failed to create session: %v", err)
	}
	originalActivity := session.LastActivity

	// Sleep briefly so the refreshed timestamp is observably later.
	time.Sleep(10 * time.Millisecond)
	session.Touch()

	if !session.LastActivity.After(originalActivity) {
		t.Error("Touch should update LastActivity")
	}
}
// TestSessionInitialized verifies that a new session reports uninitialized
// and that SetInitialized flips IsInitialized to true.
func TestSessionInitialized(t *testing.T) {
	session, err := NewSession()
	if err != nil {
		// Fail cleanly instead of panicking on a nil session below.
		t.Fatalf("Failed to create session: %v", err)
	}

	if session.IsInitialized() {
		t.Error("New session should not be initialized")
	}

	session.SetInitialized()
	if !session.IsInitialized() {
		t.Error("Session should be initialized after SetInitialized")
	}
}
// TestSessionNotifications verifies that a notification queued with
// SendNotification is delivered on the session's Notifications channel.
func TestSessionNotifications(t *testing.T) {
	session, err := NewSession()
	if err != nil {
		// Must stop here: the deferred Close would otherwise run on nil.
		t.Fatalf("Failed to create session: %v", err)
	}
	defer session.Close()

	notification := &Response{JSONRPC: "2.0", ID: 1}
	if !session.SendNotification(notification) {
		t.Error("SendNotification should return true on success")
	}

	select {
	case received, ok := <-session.Notifications():
		// A closed channel yields the zero value; report that explicitly
		// rather than failing on a field access.
		if !ok {
			t.Fatal("Notifications channel closed before delivering notification")
		}
		if received.ID != notification.ID {
			t.Errorf("Received notification ID %v should match sent ID %v", received.ID, notification.ID)
		}
	case <-time.After(100 * time.Millisecond):
		t.Error("Should receive notification")
	}
}
// TestSessionStoreCreate verifies that Create adds exactly one session to the
// store and that the session can be looked up again by its ID.
func TestSessionStoreCreate(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}
	if got := store.Count(); got != 1 {
		t.Errorf("Store should have 1 session, got %d", got)
	}

	// Verify we can retrieve it. This must be Fatal, not Error: the ID
	// comparison below would dereference a nil session and panic.
	retrieved := store.Get(session.ID)
	if retrieved == nil {
		t.Fatal("Should be able to retrieve created session")
	}
	if retrieved.ID != session.ID {
		t.Errorf("Retrieved session ID %q should match %q", retrieved.ID, session.ID)
	}
}
// TestSessionStoreGet verifies that Get returns nil for unknown IDs and a
// session for IDs produced by Create.
func TestSessionStoreGet(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	// Get non-existent session
	if store.Get("nonexistent") != nil {
		t.Error("Should return nil for non-existent session")
	}

	// Create and retrieve
	session, err := store.Create()
	if err != nil {
		// Fail cleanly instead of panicking on session.ID below.
		t.Fatalf("Failed to create session: %v", err)
	}
	if store.Get(session.ID) == nil {
		t.Error("Should find created session")
	}
}
// TestSessionStoreDelete verifies that Delete removes an existing session,
// reports true when it did so, and reports false for an ID that is gone.
func TestSessionStoreDelete(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}
	if got := store.Count(); got != 1 {
		t.Errorf("Should have 1 session after create, got %d", got)
	}

	if !store.Delete(session.ID) {
		t.Error("Delete should return true for existing session")
	}
	if got := store.Count(); got != 0 {
		t.Errorf("Should have 0 sessions after delete, got %d", got)
	}

	// Deleting the same ID a second time must report that nothing was removed.
	if store.Delete(session.ID) {
		t.Error("Delete should return false for non-existent session")
	}
}
// TestSessionStoreTTLExpiration verifies that Get stops returning a session
// once the store's TTL has elapsed without activity.
func TestSessionStoreTTLExpiration(t *testing.T) {
	ttl := 50 * time.Millisecond
	store := NewSessionStore(ttl)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Should be retrievable immediately
	if store.Get(session.ID) == nil {
		t.Error("Session should be retrievable immediately")
	}

	// Wait for expiration; the extra 10ms is slack past the TTL.
	time.Sleep(ttl + 10*time.Millisecond)

	// Should not be retrievable after TTL
	if store.Get(session.ID) != nil {
		t.Error("Expired session should not be retrievable")
	}
}
// TestSessionStoreTTLRefresh verifies that touching a session before its TTL
// elapses keeps it alive past the original expiry time.
func TestSessionStoreTTLRefresh(t *testing.T) {
	ttl := 100 * time.Millisecond
	store := NewSessionStore(ttl)
	defer store.Stop()

	session, err := store.Create()
	if err != nil {
		t.Fatalf("Failed to create session: %v", err)
	}

	// Touch the session before TTL expires
	time.Sleep(60 * time.Millisecond)
	session.Touch()

	// Wait past original TTL (~120ms total) but not past the refreshed one
	// (~160ms total).
	time.Sleep(60 * time.Millisecond)

	// Should still be retrievable because we touched it
	if store.Get(session.ID) == nil {
		t.Error("Touched session should still be retrievable")
	}
}
// TestSessionStoreCleanup verifies that expired sessions are eventually
// removed from the store by its background cleanup, without any Get calls.
func TestSessionStoreCleanup(t *testing.T) {
	ttl := 50 * time.Millisecond
	store := NewSessionStore(ttl)
	defer store.Stop()

	// Create multiple sessions
	const sessionCount = 5
	for i := 0; i < sessionCount; i++ {
		if _, err := store.Create(); err != nil {
			t.Fatalf("Failed to create session %d: %v", i, err)
		}
	}
	if got := store.Count(); got != sessionCount {
		t.Errorf("Should have %d sessions, got %d", sessionCount, got)
	}

	// Cleanup runs periodically (expected at ttl/2 intervals). Poll against a
	// bounded deadline instead of one fixed sleep with ~10ms of slack: a slow
	// scheduler on a loaded machine would otherwise make this test flaky,
	// while a cleanup loop that never runs still fails within 10*ttl.
	deadline := time.Now().Add(10 * ttl)
	for store.Count() != 0 && time.Now().Before(deadline) {
		time.Sleep(ttl / 5)
	}

	// All sessions should be cleaned up
	if got := store.Count(); got != 0 {
		t.Errorf("All sessions should be cleaned up, got %d", got)
	}
}
// TestSessionStoreConcurrency hammers Create, Get and Delete from many
// goroutines at once. Its main value is under the race detector (-race).
func TestSessionStoreConcurrency(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	const workers = 50
	var wg sync.WaitGroup
	created := make(chan string, 100)

	// Phase 1: create sessions concurrently, publishing each new ID.
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			s, err := store.Create()
			if err != nil {
				t.Errorf("Failed to create session: %v", err)
				return
			}
			created <- s.ID
		}()
	}
	wg.Wait()
	close(created)

	// Verify all sessions were created
	if got := store.Count(); got != workers {
		t.Errorf("Should have 50 sessions, got %d", got)
	}

	// Phase 2: collect the IDs, then race a Get against a Delete for each.
	ids := make([]string, 0, workers)
	for id := range created {
		ids = append(ids, id)
	}

	wg.Add(2 * len(ids))
	for _, id := range ids {
		go func(sid string) {
			defer wg.Done()
			store.Get(sid)
		}(id)
		go func(sid string) {
			defer wg.Done()
			store.Delete(sid)
		}(id)
	}
	wg.Wait()
}
// TestSessionStoreMaxSessions verifies that a store created with a limit
// accepts exactly that many sessions and rejects the next Create with
// ErrTooManySessions, leaving the count unchanged.
func TestSessionStoreMaxSessions(t *testing.T) {
	maxSessions := 5
	store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
	defer store.Stop()

	// Create sessions up to limit
	for i := 0; i < maxSessions; i++ {
		if _, err := store.Create(); err != nil {
			t.Fatalf("Failed to create session %d: %v", i, err)
		}
	}
	if got := store.Count(); got != maxSessions {
		t.Errorf("Expected %d sessions, got %d", maxSessions, got)
	}

	// Try to create one more - should fail. errors.Is keeps this check valid
	// even if Create wraps the sentinel with additional context.
	if _, err := store.Create(); !errors.Is(err, ErrTooManySessions) {
		t.Errorf("Expected ErrTooManySessions, got %v", err)
	}

	// Count should still be at max
	if got := store.Count(); got != maxSessions {
		t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, got)
	}
}
// TestSessionStoreMaxSessionsWithDeletion verifies that deleting a session
// from a full store frees a slot so Create succeeds again.
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
	maxSessions := 3
	store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
	defer store.Stop()

	// Fill up the store
	sessions := make([]*Session, maxSessions)
	for i := 0; i < maxSessions; i++ {
		s, err := store.Create()
		if err != nil {
			t.Fatalf("Failed to create session: %v", err)
		}
		sessions[i] = s
	}

	// Should be full. errors.Is tolerates a wrapped sentinel.
	if _, err := store.Create(); !errors.Is(err, ErrTooManySessions) {
		t.Errorf("Expected ErrTooManySessions when full, got %v", err)
	}

	// Delete one session. If this fails the store is still full and the
	// final assertion would be meaningless, so stop here.
	if !store.Delete(sessions[0].ID) {
		t.Fatalf("Delete should return true for existing session %q", sessions[0].ID)
	}

	// Should be able to create again
	if _, err := store.Create(); err != nil {
		t.Errorf("Should be able to create after deletion: %v", err)
	}
}
// TestSessionStoreDefaultMaxSessions checks that NewSessionStore falls back to
// DefaultMaxSessions. It reads the unexported field directly; filling the
// store to prove the limit would mean creating 10000 sessions.
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
	store := NewSessionStore(30 * time.Minute)
	defer store.Stop()

	got := store.maxSessions
	if got != DefaultMaxSessions {
		t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, got)
	}
}
// TestGenerateSessionID generates a batch of IDs and checks that each is
// exactly 32 hexadecimal characters and that none repeats.
func TestGenerateSessionID(t *testing.T) {
	const count = 1000
	seen := make(map[string]struct{}, count)

	for i := 0; i < count; i++ {
		id, err := generateSessionID()
		if err != nil {
			t.Fatalf("Failed to generate session ID: %v", err)
		}
		if len(id) != 32 {
			t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
		}

		// The length check alone does not prove the ID is hex; verify each
		// character too (either letter case is accepted).
		for _, c := range id {
			isHex := ('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')
			if !isHex {
				t.Errorf("Session ID %q contains non-hex character %q", id, c)
				break
			}
		}

		if _, dup := seen[id]; dup {
			t.Errorf("Generated duplicate session ID %q", id)
		}
		seen[id] = struct{}{}
	}
}