package storage

import (
	"context"
	"path/filepath"
	"testing"
	"time"
)

// storeFactory returns a fresh, empty Store for a single (sub)test. Any
// teardown the backend needs is registered via t.Cleanup, so callers never
// have to release the store themselves.
type storeFactory func(t *testing.T) Store

// testStores runs f once per Store implementation ("SQLite" and "Memory") as
// named subtests, handing it a factory that builds a clean store of that kind.
func testStores(t *testing.T, f func(t *testing.T, newStore storeFactory)) {
	t.Helper()

	t.Run("SQLite", func(t *testing.T) {
		f(t, func(t *testing.T) Store {
			t.Helper()
			// Every call gets its own database file inside a per-test temp
			// dir, so stores created by different subtests never share state.
			dbPath := filepath.Join(t.TempDir(), "test.db")
			s, err := NewSQLiteStore(dbPath)
			if err != nil {
				t.Fatalf("creating SQLiteStore: %v", err)
			}
			// Registered after t.TempDir, so it runs before the temp dir is
			// removed (cleanups run in LIFO order).
			t.Cleanup(func() {
				if err := s.Close(); err != nil {
					t.Errorf("closing SQLiteStore: %v", err)
				}
			})
			return s
		})
	})

	t.Run("Memory", func(t *testing.T) {
		f(t, func(t *testing.T) Store {
			t.Helper()
			return NewMemoryStore()
		})
	})
}

// seedData populates store with the fixed dataset several tests assert on:
//
//   - 10 login attempts: root/toor x5 from 10.0.0.1, root/toor x3 from
//     10.0.0.2, admin/admin x2 from 10.0.0.1;
//   - 2 sessions: one from 10.0.0.1 (root, ended) and one from 10.0.0.2
//     (admin, still active).
func seedData(t *testing.T, store Store) {
	t.Helper()
	ctx := context.Background()

	attempts := []struct {
		user, pass, ip string
		n              int
	}{
		{"root", "toor", "10.0.0.1", 5},
		{"root", "toor", "10.0.0.2", 3},
		{"admin", "admin", "10.0.0.1", 2},
	}
	for _, a := range attempts {
		for range a.n {
			if err := store.RecordLoginAttempt(ctx, a.user, a.pass, a.ip, ""); err != nil {
				t.Fatalf("seeding attempt: %v", err)
			}
		}
	}

	// Sessions: one ended, one active.
	id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
	if err != nil {
		t.Fatalf("creating session: %v", err)
	}
	if err := store.EndSession(ctx, id1, time.Now()); err != nil {
		t.Fatalf("ending session: %v", err)
	}
	if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
		t.Fatalf("creating session: %v", err)
	}
}

func TestGetDashboardStats(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		t.Run("empty", func(t *testing.T) {
			store := newStore(t)
			ctx := context.Background()

			stats, err := store.GetDashboardStats(ctx)
			if err != nil {
				t.Fatalf("GetDashboardStats: %v", err)
			}
			if stats.TotalAttempts != 0 || stats.UniqueIPs != 0 || stats.TotalSessions != 0 || stats.ActiveSessions != 0 {
				t.Errorf("expected all zeros, got %+v", stats)
			}
		})

		t.Run("with data", func(t *testing.T) {
			store := newStore(t)
			seedData(t, store)
			ctx := context.Background()

			stats, err := store.GetDashboardStats(ctx)
			if err != nil {
				t.Fatalf("GetDashboardStats: %v", err)
			}
			// 5 + 3 + 2 = 10 total attempts
			if stats.TotalAttempts != 10 {
				t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
			}
			// 2 unique IPs: 10.0.0.1 and 10.0.0.2
			if stats.UniqueIPs != 2 {
				t.Errorf("UniqueIPs = %d, want 2", stats.UniqueIPs)
			}
			if stats.TotalSessions != 2 {
				t.Errorf("TotalSessions = %d, want 2", stats.TotalSessions)
			}
			if stats.ActiveSessions != 1 {
				t.Errorf("ActiveSessions = %d, want 1", stats.ActiveSessions)
			}
		})
	})
}

func TestGetTopUsernames(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		t.Run("empty", func(t *testing.T) {
			store := newStore(t)
			entries, err := store.GetTopUsernames(context.Background(), 10)
			if err != nil {
				t.Fatalf("GetTopUsernames: %v", err)
			}
			if len(entries) != 0 {
				t.Errorf("expected empty, got %v", entries)
			}
		})

		t.Run("with data", func(t *testing.T) {
			store := newStore(t)
			seedData(t, store)

			entries, err := store.GetTopUsernames(context.Background(), 10)
			if err != nil {
				t.Fatalf("GetTopUsernames: %v", err)
			}
			if len(entries) != 2 {
				t.Fatalf("len = %d, want 2", len(entries))
			}
			// root: 5 + 3 = 8, admin: 2
			if entries[0].Value != "root" || entries[0].Count != 8 {
				t.Errorf("entries[0] = %+v, want root/8", entries[0])
			}
			if entries[1].Value != "admin" || entries[1].Count != 2 {
				t.Errorf("entries[1] = %+v, want admin/2", entries[1])
			}
		})

		t.Run("limit", func(t *testing.T) {
			store := newStore(t)
			seedData(t, store)

			entries, err := store.GetTopUsernames(context.Background(), 1)
			if err != nil {
				t.Fatalf("GetTopUsernames: %v", err)
			}
			if len(entries) != 1 {
				t.Fatalf("len = %d, want 1", len(entries))
			}
		})
	})
}

func TestGetTopPasswords(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		store := newStore(t)
		seedData(t, store)

		entries, err := store.GetTopPasswords(context.Background(), 10)
		if err != nil {
			t.Fatalf("GetTopPasswords: %v", err)
		}
		if len(entries) != 2 {
			t.Fatalf("len = %d, want 2", len(entries))
		}
		// toor: 8, admin: 2
		if entries[0].Value != "toor" || entries[0].Count != 8 {
			t.Errorf("entries[0] = %+v, want toor/8", entries[0])
		}
		if entries[1].Value != "admin" || entries[1].Count != 2 {
			t.Errorf("entries[1] = %+v, want admin/2", entries[1])
		}
	})
}

func TestGetTopIPs(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		store := newStore(t)
		seedData(t, store)

		entries, err := store.GetTopIPs(context.Background(), 10)
		if err != nil {
			t.Fatalf("GetTopIPs: %v", err)
		}
		if len(entries) != 2 {
			t.Fatalf("len = %d, want 2", len(entries))
		}
		// 10.0.0.1: 5 + 2 = 7, 10.0.0.2: 3
		if entries[0].Value != "10.0.0.1" || entries[0].Count != 7 {
			t.Errorf("entries[0] = %+v, want 10.0.0.1/7", entries[0])
		}
		if entries[1].Value != "10.0.0.2" || entries[1].Count != 3 {
			t.Errorf("entries[1] = %+v, want 10.0.0.2/3", entries[1])
		}
	})
}

func TestGetSession(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		t.Run("not found", func(t *testing.T) {
			store := newStore(t)
			s, err := store.GetSession(context.Background(), "nonexistent")
			if err != nil {
				t.Fatalf("GetSession: %v", err)
			}
			if s != nil {
				t.Errorf("expected nil, got %+v", s)
			}
		})

		t.Run("found", func(t *testing.T) {
			store := newStore(t)
			ctx := context.Background()

			id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
			if err != nil {
				t.Fatalf("CreateSession: %v", err)
			}
			s, err := store.GetSession(ctx, id)
			if err != nil {
				t.Fatalf("GetSession: %v", err)
			}
			if s == nil {
				t.Fatal("expected session, got nil")
			}
			if s.ID != id || s.IP != "10.0.0.1" || s.Username != "root" || s.ShellName != "bash" {
				t.Errorf("unexpected session: %+v", s)
			}
		})
	})
}

func TestGetSessionLogs(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		store := newStore(t)
		ctx := context.Background()

		id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
		if err != nil {
			t.Fatalf("CreateSession: %v", err)
		}
		if err := store.AppendSessionLog(ctx, id, "ls", "file1\nfile2"); err != nil {
			t.Fatalf("AppendSessionLog: %v", err)
		}
		if err := store.AppendSessionLog(ctx, id, "pwd", "/home/root"); err != nil {
			t.Fatalf("AppendSessionLog: %v", err)
		}

		logs, err := store.GetSessionLogs(ctx, id)
		if err != nil {
			t.Fatalf("GetSessionLogs: %v", err)
		}
		if len(logs) != 2 {
			t.Fatalf("len = %d, want 2", len(logs))
		}
		if logs[0].Input != "ls" {
			t.Errorf("logs[0].Input = %q, want %q", logs[0].Input, "ls")
		}
		if logs[1].Input != "pwd" {
			t.Errorf("logs[1].Input = %q, want %q", logs[1].Input, "pwd")
		}
	})
}

func TestSessionEvents(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		t.Run("empty", func(t *testing.T) {
			store := newStore(t)
			events, err := store.GetSessionEvents(context.Background(), "nonexistent")
			if err != nil {
				t.Fatalf("GetSessionEvents: %v", err)
			}
			if len(events) != 0 {
				t.Errorf("expected empty, got %d", len(events))
			}
		})

		t.Run("append and retrieve", func(t *testing.T) {
			store := newStore(t)
			ctx := context.Background()

			id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
			if err != nil {
				t.Fatalf("CreateSession: %v", err)
			}

			now := time.Now().UTC()
			events := []SessionEvent{
				{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")},
				{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")},
				{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")},
			}
			if err := store.AppendSessionEvents(ctx, events); err != nil {
				t.Fatalf("AppendSessionEvents: %v", err)
			}

			got, err := store.GetSessionEvents(ctx, id)
			if err != nil {
				t.Fatalf("GetSessionEvents: %v", err)
			}
			if len(got) != 3 {
				t.Fatalf("len = %d, want 3", len(got))
			}
			if got[0].Direction != 0 || string(got[0].Data) != "ls\n" {
				t.Errorf("got[0] = %+v", got[0])
			}
			if got[1].Direction != 1 || string(got[1].Data) != "file1\nfile2\n" {
				t.Errorf("got[1] = %+v", got[1])
			}
			if got[2].Direction != 0 || string(got[2].Data) != "pwd\n" {
				t.Errorf("got[2] = %+v", got[2])
			}
		})

		t.Run("append empty", func(t *testing.T) {
			store := newStore(t)
			if err := store.AppendSessionEvents(context.Background(), nil); err != nil {
				t.Fatalf("AppendSessionEvents(nil): %v", err)
			}
		})
	})
}

func TestCloseActiveSessions(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		t.Run("no active sessions", func(t *testing.T) {
			store := newStore(t)
			ctx := context.Background()

			n, err := store.CloseActiveSessions(ctx, time.Now())
			if err != nil {
				t.Fatalf("CloseActiveSessions: %v", err)
			}
			if n != 0 {
				t.Errorf("closed %d, want 0", n)
			}
		})

		t.Run("closes only active sessions", func(t *testing.T) {
			store := newStore(t)
			ctx := context.Background()

			// Create 3 sessions: end one, leave two active. Setup errors are
			// fatal so a broken fixture is not misreported as a wrong count.
			id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
			if err != nil {
				t.Fatalf("CreateSession: %v", err)
			}
			if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
				t.Fatalf("CreateSession: %v", err)
			}
			if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
				t.Fatalf("CreateSession: %v", err)
			}
			if err := store.EndSession(ctx, id1, time.Now()); err != nil {
				t.Fatalf("EndSession: %v", err)
			}

			n, err := store.CloseActiveSessions(ctx, time.Now())
			if err != nil {
				t.Fatalf("CloseActiveSessions: %v", err)
			}
			if n != 2 {
				t.Errorf("closed %d, want 2", n)
			}

			// Verify no active sessions remain.
			active, err := store.GetRecentSessions(ctx, 10, true)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			if len(active) != 0 {
				t.Errorf("active sessions = %d, want 0", len(active))
			}

			// Closing must mark sessions as ended, not drop them.
			all, err := store.GetRecentSessions(ctx, 10, false)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			if len(all) != 3 {
				t.Fatalf("total sessions = %d, want 3", len(all))
			}
			for _, s := range all {
				if s.DisconnectedAt == nil {
					t.Errorf("session %v still has nil DisconnectedAt", s.ID)
				}
			}
		})
	})
}

func TestGetRecentSessions(t *testing.T) {
	testStores(t, func(t *testing.T, newStore storeFactory) {
		t.Run("empty", func(t *testing.T) {
			store := newStore(t)
			sessions, err := store.GetRecentSessions(context.Background(), 10, false)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			if len(sessions) != 0 {
				t.Errorf("expected empty, got %d", len(sessions))
			}
		})

		t.Run("all sessions", func(t *testing.T) {
			store := newStore(t)
			seedData(t, store)

			sessions, err := store.GetRecentSessions(context.Background(), 10, false)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			if len(sessions) != 2 {
				t.Fatalf("len = %d, want 2", len(sessions))
			}
		})

		t.Run("active only", func(t *testing.T) {
			store := newStore(t)
			seedData(t, store)

			sessions, err := store.GetRecentSessions(context.Background(), 10, true)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			if len(sessions) != 1 {
				t.Fatalf("len = %d, want 1", len(sessions))
			}
			if sessions[0].DisconnectedAt != nil {
				t.Error("active session should have nil DisconnectedAt")
			}
		})

		t.Run("limit", func(t *testing.T) {
			store := newStore(t)
			seedData(t, store)

			sessions, err := store.GetRecentSessions(context.Background(), 1, false)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			if len(sessions) != 1 {
				t.Fatalf("len = %d, want 1", len(sessions))
			}
		})
	})
}