Use t.Fatal instead of t.Error when retrieved session is nil to prevent subsequent nil pointer dereference on retrieved.ID. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
338 lines
7.4 KiB
Go
338 lines
7.4 KiB
Go
package mcp
|
|
|
|
import (
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestNewSession(t *testing.T) {
|
|
session, err := NewSession()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
|
|
if session.ID == "" {
|
|
t.Error("Session ID should not be empty")
|
|
}
|
|
if len(session.ID) != 32 {
|
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(session.ID))
|
|
}
|
|
if session.Initialized {
|
|
t.Error("New session should not be initialized")
|
|
}
|
|
}
|
|
|
|
func TestSessionTouch(t *testing.T) {
|
|
session, _ := NewSession()
|
|
originalActivity := session.LastActivity
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
session.Touch()
|
|
|
|
if !session.LastActivity.After(originalActivity) {
|
|
t.Error("Touch should update LastActivity")
|
|
}
|
|
}
|
|
|
|
func TestSessionInitialized(t *testing.T) {
|
|
session, _ := NewSession()
|
|
|
|
if session.IsInitialized() {
|
|
t.Error("New session should not be initialized")
|
|
}
|
|
|
|
session.SetInitialized()
|
|
|
|
if !session.IsInitialized() {
|
|
t.Error("Session should be initialized after SetInitialized")
|
|
}
|
|
}
|
|
|
|
func TestSessionNotifications(t *testing.T) {
|
|
session, _ := NewSession()
|
|
defer session.Close()
|
|
|
|
notification := &Response{JSONRPC: "2.0", ID: 1}
|
|
|
|
if !session.SendNotification(notification) {
|
|
t.Error("SendNotification should return true on success")
|
|
}
|
|
|
|
select {
|
|
case received := <-session.Notifications():
|
|
if received.ID != notification.ID {
|
|
t.Error("Received notification should match sent")
|
|
}
|
|
case <-time.After(100 * time.Millisecond):
|
|
t.Error("Should receive notification")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreCreate(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
session, err := store.Create()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
|
|
if store.Count() != 1 {
|
|
t.Errorf("Store should have 1 session, got %d", store.Count())
|
|
}
|
|
|
|
// Verify we can retrieve it
|
|
retrieved := store.Get(session.ID)
|
|
if retrieved == nil {
|
|
t.Fatal("Should be able to retrieve created session")
|
|
}
|
|
if retrieved.ID != session.ID {
|
|
t.Error("Retrieved session ID should match")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreGet(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
// Get non-existent session
|
|
if store.Get("nonexistent") != nil {
|
|
t.Error("Should return nil for non-existent session")
|
|
}
|
|
|
|
// Create and retrieve
|
|
session, _ := store.Create()
|
|
retrieved := store.Get(session.ID)
|
|
if retrieved == nil {
|
|
t.Error("Should find created session")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreDelete(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
session, _ := store.Create()
|
|
if store.Count() != 1 {
|
|
t.Error("Should have 1 session after create")
|
|
}
|
|
|
|
if !store.Delete(session.ID) {
|
|
t.Error("Delete should return true for existing session")
|
|
}
|
|
|
|
if store.Count() != 0 {
|
|
t.Error("Should have 0 sessions after delete")
|
|
}
|
|
|
|
if store.Delete(session.ID) {
|
|
t.Error("Delete should return false for non-existent session")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreTTLExpiration(t *testing.T) {
|
|
ttl := 50 * time.Millisecond
|
|
store := NewSessionStore(ttl)
|
|
defer store.Stop()
|
|
|
|
session, _ := store.Create()
|
|
|
|
// Should be retrievable immediately
|
|
if store.Get(session.ID) == nil {
|
|
t.Error("Session should be retrievable immediately")
|
|
}
|
|
|
|
// Wait for expiration
|
|
time.Sleep(ttl + 10*time.Millisecond)
|
|
|
|
// Should not be retrievable after TTL
|
|
if store.Get(session.ID) != nil {
|
|
t.Error("Expired session should not be retrievable")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreTTLRefresh(t *testing.T) {
|
|
ttl := 100 * time.Millisecond
|
|
store := NewSessionStore(ttl)
|
|
defer store.Stop()
|
|
|
|
session, _ := store.Create()
|
|
|
|
// Touch the session before TTL expires
|
|
time.Sleep(60 * time.Millisecond)
|
|
session.Touch()
|
|
|
|
// Wait past original TTL but not past refreshed TTL
|
|
time.Sleep(60 * time.Millisecond)
|
|
|
|
// Should still be retrievable because we touched it
|
|
if store.Get(session.ID) == nil {
|
|
t.Error("Touched session should still be retrievable")
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreCleanup(t *testing.T) {
|
|
ttl := 50 * time.Millisecond
|
|
store := NewSessionStore(ttl)
|
|
defer store.Stop()
|
|
|
|
// Create multiple sessions
|
|
for i := 0; i < 5; i++ {
|
|
store.Create()
|
|
}
|
|
|
|
if store.Count() != 5 {
|
|
t.Errorf("Should have 5 sessions, got %d", store.Count())
|
|
}
|
|
|
|
// Wait for cleanup to run (runs at ttl/2 intervals)
|
|
time.Sleep(ttl + ttl/2 + 10*time.Millisecond)
|
|
|
|
// All sessions should be cleaned up
|
|
if store.Count() != 0 {
|
|
t.Errorf("All sessions should be cleaned up, got %d", store.Count())
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreConcurrency(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
var wg sync.WaitGroup
|
|
sessionIDs := make(chan string, 100)
|
|
|
|
// Create sessions concurrently
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
session, err := store.Create()
|
|
if err != nil {
|
|
t.Errorf("Failed to create session: %v", err)
|
|
return
|
|
}
|
|
sessionIDs <- session.ID
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
close(sessionIDs)
|
|
|
|
// Verify all sessions were created
|
|
if store.Count() != 50 {
|
|
t.Errorf("Should have 50 sessions, got %d", store.Count())
|
|
}
|
|
|
|
// Read and delete concurrently
|
|
var ids []string
|
|
for id := range sessionIDs {
|
|
ids = append(ids, id)
|
|
}
|
|
|
|
for _, id := range ids {
|
|
wg.Add(2)
|
|
go func(id string) {
|
|
defer wg.Done()
|
|
store.Get(id)
|
|
}(id)
|
|
go func(id string) {
|
|
defer wg.Done()
|
|
store.Delete(id)
|
|
}(id)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestSessionStoreMaxSessions(t *testing.T) {
|
|
maxSessions := 5
|
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
|
defer store.Stop()
|
|
|
|
// Create sessions up to limit
|
|
for i := 0; i < maxSessions; i++ {
|
|
_, err := store.Create()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session %d: %v", i, err)
|
|
}
|
|
}
|
|
|
|
if store.Count() != maxSessions {
|
|
t.Errorf("Expected %d sessions, got %d", maxSessions, store.Count())
|
|
}
|
|
|
|
// Try to create one more - should fail
|
|
_, err := store.Create()
|
|
if err != ErrTooManySessions {
|
|
t.Errorf("Expected ErrTooManySessions, got %v", err)
|
|
}
|
|
|
|
// Count should still be at max
|
|
if store.Count() != maxSessions {
|
|
t.Errorf("Expected %d sessions after failed create, got %d", maxSessions, store.Count())
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreMaxSessionsWithDeletion(t *testing.T) {
|
|
maxSessions := 3
|
|
store := NewSessionStoreWithLimit(30*time.Minute, maxSessions)
|
|
defer store.Stop()
|
|
|
|
// Fill up the store
|
|
sessions := make([]*Session, maxSessions)
|
|
for i := 0; i < maxSessions; i++ {
|
|
s, err := store.Create()
|
|
if err != nil {
|
|
t.Fatalf("Failed to create session: %v", err)
|
|
}
|
|
sessions[i] = s
|
|
}
|
|
|
|
// Should be full
|
|
_, err := store.Create()
|
|
if err != ErrTooManySessions {
|
|
t.Error("Expected ErrTooManySessions when full")
|
|
}
|
|
|
|
// Delete one session
|
|
store.Delete(sessions[0].ID)
|
|
|
|
// Should be able to create again
|
|
_, err = store.Create()
|
|
if err != nil {
|
|
t.Errorf("Should be able to create after deletion: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSessionStoreDefaultMaxSessions(t *testing.T) {
|
|
store := NewSessionStore(30 * time.Minute)
|
|
defer store.Stop()
|
|
|
|
// Just verify it uses the default (don't create 10000 sessions)
|
|
if store.maxSessions != DefaultMaxSessions {
|
|
t.Errorf("Expected default max sessions %d, got %d", DefaultMaxSessions, store.maxSessions)
|
|
}
|
|
}
|
|
|
|
func TestGenerateSessionID(t *testing.T) {
|
|
ids := make(map[string]bool)
|
|
|
|
// Generate 1000 IDs and ensure uniqueness
|
|
for i := 0; i < 1000; i++ {
|
|
id, err := generateSessionID()
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate session ID: %v", err)
|
|
}
|
|
|
|
if len(id) != 32 {
|
|
t.Errorf("Session ID should be 32 hex chars, got %d", len(id))
|
|
}
|
|
|
|
if ids[id] {
|
|
t.Error("Generated duplicate session ID")
|
|
}
|
|
ids[id] = true
|
|
}
|
|
}
|