package storage

import (
	"context"
	"sort"
	"sync"
	"time"

	"github.com/google/uuid"
)

// MemoryStore is an in-memory implementation of Store for use in tests.
//
// All methods are guarded by an internal mutex. The exported fields are
// exposed so tests can inspect or seed state directly; doing so while other
// goroutines call methods is a data race.
type MemoryStore struct {
	mu            sync.Mutex
	LoginAttempts []LoginAttempt
	Sessions      map[string]*Session
	SessionLogs   []SessionLog
	SessionEvents []SessionEvent

	// lastAttemptID and lastLogID remember the highest IDs handed out so far.
	// IDs are derived from these rather than from slice lengths so they are
	// never reused after DeleteRecordsBefore shrinks the slices.
	lastAttemptID int64
	lastLogID     int64
}

// NewMemoryStore returns a new empty MemoryStore.
func NewMemoryStore() *MemoryStore {
	return &MemoryStore{
		Sessions: make(map[string]*Session),
	}
}

// RecordLoginAttempt aggregates attempts by (username, password, ip). A repeat
// of an existing combination increments Count, refreshes LastSeen and
// overwrites Country; a new combination is appended with Count 1 and a fresh ID.
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip, country string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now().UTC()
	for i := range m.LoginAttempts {
		a := &m.LoginAttempts[i]
		if a.Username == username && a.Password == password && a.IP == ip {
			a.Count++
			a.LastSeen = now
			a.Country = country
			return nil
		}
	}

	var newest int64
	if n := len(m.LoginAttempts); n > 0 {
		newest = m.LoginAttempts[n-1].ID
	}
	m.LoginAttempts = append(m.LoginAttempts, LoginAttempt{
		ID:        nextSeqID(&m.lastAttemptID, newest),
		Username:  username,
		Password:  password,
		IP:        ip,
		Country:   country,
		Count:     1,
		FirstSeen: now,
		LastSeen:  now,
	})
	return nil
}

// CreateSession stores a new session keyed by a random UUID, with ConnectedAt
// set to the current UTC time, and returns the new session ID.
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName, country string) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Allow a zero-value MemoryStore to be used without NewMemoryStore.
	if m.Sessions == nil {
		m.Sessions = make(map[string]*Session)
	}

	id := uuid.New().String()
	m.Sessions[id] = &Session{
		ID:          id,
		IP:          ip,
		Country:     country,
		Username:    username,
		ShellName:   shellName,
		ConnectedAt: time.Now().UTC(),
	}
	return id, nil
}

// EndSession sets DisconnectedAt (normalized to UTC) on the given session.
// Unknown session IDs are silently ignored.
func (m *MemoryStore) EndSession(_ context.Context, sessionID string, disconnectedAt time.Time) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if s, ok := m.Sessions[sessionID]; ok {
		t := disconnectedAt.UTC()
		s.DisconnectedAt = &t
	}
	return nil
}

// UpdateHumanScore sets HumanScore on the given session. Unknown session IDs
// are silently ignored.
func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, score float64) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if s, ok := m.Sessions[sessionID]; ok {
		s.HumanScore = &score
	}
	return nil
}

// SetExecCommand sets ExecCommand on the given session. Unknown session IDs
// are silently ignored.
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if s, ok := m.Sessions[sessionID]; ok {
		s.ExecCommand = &command
	}
	return nil
}

// AppendSessionLog records one input/output exchange for a session, stamped
// with the current UTC time and a fresh ID. The session ID is not validated.
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	var newest int64
	if n := len(m.SessionLogs); n > 0 {
		newest = m.SessionLogs[n-1].ID
	}
	m.SessionLogs = append(m.SessionLogs, SessionLog{
		ID:        nextSeqID(&m.lastLogID, newest),
		SessionID: sessionID,
		Timestamp: time.Now().UTC(),
		Input:     input,
		Output:    output,
	})
	return nil
}

// GetSession returns a copy of the session with the given ID, or (nil, nil)
// if it does not exist. The copy shares no memory with the store.
func (m *MemoryStore) GetSession(_ context.Context, sessionID string) (*Session, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	s, ok := m.Sessions[sessionID]
	if !ok {
		return nil, nil
	}
	out := detachSession(s)
	return &out, nil
}

// GetSessionLogs returns the logs for a session ordered by Timestamp
// ascending. Logs with equal timestamps keep their insertion order.
func (m *MemoryStore) GetSessionLogs(_ context.Context, sessionID string) ([]SessionLog, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var logs []SessionLog
	for _, l := range m.SessionLogs {
		if l.SessionID == sessionID {
			logs = append(logs, l)
		}
	}
	// Stable sort: timestamps can tie, and insertion order is the best
	// tie-break available.
	sort.SliceStable(logs, func(i, j int) bool {
		return logs[i].Timestamp.Before(logs[j].Timestamp)
	})
	return logs, nil
}

// AppendSessionEvents appends the given events in order.
func (m *MemoryStore) AppendSessionEvents(_ context.Context, events []SessionEvent) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.SessionEvents = append(m.SessionEvents, events...)
	return nil
}

// GetSessionEvents returns the events for a session in insertion order.
func (m *MemoryStore) GetSessionEvents(_ context.Context, sessionID string) ([]SessionEvent, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var events []SessionEvent
	for _, e := range m.SessionEvents {
		if e.SessionID == sessionID {
			events = append(events, e)
		}
	}
	return events, nil
}

// DeleteRecordsBefore removes records older than cutoff and returns how many
// were removed in total:
//   - login attempts whose LastSeen is before cutoff;
//   - sessions whose ConnectedAt is before cutoff (active or not), together
//     with the logs and events belonging to those sessions.
func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (int64, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var total int64

	kept := m.LoginAttempts[:0]
	for _, a := range m.LoginAttempts {
		if a.LastSeen.Before(cutoff) {
			total++
			continue
		}
		kept = append(kept, a)
	}
	m.LoginAttempts = kept

	// Remember exactly which sessions were dropped so only their logs and
	// events are removed, not unrelated records.
	removed := make(map[string]struct{})
	for id, s := range m.Sessions {
		if s.ConnectedAt.Before(cutoff) {
			delete(m.Sessions, id) // deleting during range is allowed
			removed[id] = struct{}{}
			total++
		}
	}
	if len(removed) == 0 {
		return total, nil
	}

	keptLogs := m.SessionLogs[:0]
	for _, l := range m.SessionLogs {
		if _, gone := removed[l.SessionID]; gone {
			total++
			continue
		}
		keptLogs = append(keptLogs, l)
	}
	m.SessionLogs = keptLogs

	keptEvents := m.SessionEvents[:0]
	for _, e := range m.SessionEvents {
		if _, gone := removed[e.SessionID]; gone {
			total++
			continue
		}
		keptEvents = append(keptEvents, e)
	}
	m.SessionEvents = keptEvents

	return total, nil
}

// GetDashboardStats summarizes the store:
//   - TotalAttempts is the sum of all attempt Counts;
//   - UniqueIPs is the number of distinct attempt IPs;
//   - TotalSessions is the number of sessions;
//   - ActiveSessions is the number of sessions without DisconnectedAt.
func (m *MemoryStore) GetDashboardStats(_ context.Context) (*DashboardStats, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	stats := &DashboardStats{}
	ips := make(map[string]struct{})
	for _, a := range m.LoginAttempts {
		stats.TotalAttempts += int64(a.Count)
		ips[a.IP] = struct{}{}
	}
	stats.UniqueIPs = int64(len(ips))
	stats.TotalSessions = int64(len(m.Sessions))
	for _, s := range m.Sessions {
		if s.DisconnectedAt == nil {
			stats.ActiveSessions++
		}
	}
	return stats, nil
}

// GetTopUsernames returns usernames ranked by total attempt Count. A
// non-positive limit returns all of them.
func (m *MemoryStore) GetTopUsernames(_ context.Context, limit int) ([]TopEntry, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.topN("username", limit), nil
}

// GetTopPasswords returns passwords ranked by total attempt Count. A
// non-positive limit returns all of them.
func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.topN("password", limit), nil
}

// GetTopIPs returns IPs ranked by total attempt Count. Each entry's Country is
// the last non-empty country recorded for that IP. A non-positive limit
// returns all of them.
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	type ipInfo struct {
		count   int64
		country string
	}
	agg := make(map[string]*ipInfo)
	for _, a := range m.LoginAttempts {
		info, ok := agg[a.IP]
		if !ok {
			info = &ipInfo{}
			agg[a.IP] = info
		}
		info.count += int64(a.Count)
		if a.Country != "" {
			info.country = a.Country
		}
	}

	entries := make([]TopEntry, 0, len(agg))
	for ip, info := range agg {
		entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
	}
	return rankTopEntries(entries, limit), nil
}

// GetTopCountries returns countries ranked by total attempt Count, ignoring
// attempts with no country. A non-positive limit returns all of them.
func (m *MemoryStore) GetTopCountries(_ context.Context, limit int) ([]TopEntry, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	counts := make(map[string]int64)
	for _, a := range m.LoginAttempts {
		if a.Country == "" {
			continue
		}
		counts[a.Country] += int64(a.Count)
	}

	entries := make([]TopEntry, 0, len(counts))
	for k, v := range counts {
		entries = append(entries, TopEntry{Value: k, Count: v})
	}
	return rankTopEntries(entries, limit), nil
}

// topN aggregates login attempts by the given field ("username", "password"
// or "ip") and returns the top N. Any other field name aggregates everything
// under the empty key. Must be called with m.mu held.
func (m *MemoryStore) topN(field string, limit int) []TopEntry {
	counts := make(map[string]int64)
	for _, a := range m.LoginAttempts {
		var key string
		switch field {
		case "username":
			key = a.Username
		case "password":
			key = a.Password
		case "ip":
			key = a.IP
		}
		counts[key] += int64(a.Count)
	}

	entries := make([]TopEntry, 0, len(counts))
	for k, v := range counts {
		entries = append(entries, TopEntry{Value: k, Count: v})
	}
	return rankTopEntries(entries, limit)
}

// GetRecentSessions returns copies of sessions, newest ConnectedAt first. Ties
// are broken by ascending ID. When activeOnly is set, sessions with a
// DisconnectedAt are skipped. A non-positive limit returns all matches.
func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly bool) ([]Session, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var sessions []Session
	for _, s := range m.Sessions {
		if activeOnly && s.DisconnectedAt != nil {
			continue
		}
		sessions = append(sessions, detachSession(s))
	}
	// Map iteration order is random, so a full tie-break is needed for
	// deterministic results (and deterministic truncation).
	sort.Slice(sessions, func(i, j int) bool {
		if !sessions[i].ConnectedAt.Equal(sessions[j].ConnectedAt) {
			return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
		}
		return sessions[i].ID < sessions[j].ID
	})
	if limit > 0 && len(sessions) > limit {
		sessions = sessions[:limit]
	}
	return sessions, nil
}

// CloseActiveSessions sets DisconnectedAt (normalized to UTC) on every session
// that does not have one yet, and returns how many sessions were closed.
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var count int64
	t := disconnectedAt.UTC()
	for _, s := range m.Sessions {
		if s.DisconnectedAt == nil {
			s.DisconnectedAt = &t
			count++
		}
	}
	return count, nil
}

// Close is a no-op; MemoryStore holds no external resources.
func (m *MemoryStore) Close() error {
	return nil
}

// nextSeqID returns an ID greater than both *last and lastExisting, and stores
// it in *last. lastExisting is the ID of the newest stored record; taking it
// into account keeps IDs unique even if a test seeded the slice directly.
func nextSeqID(last *int64, lastExisting int64) int64 {
	if lastExisting > *last {
		*last = lastExisting
	}
	*last++
	return *last
}

// rankTopEntries sorts entries by descending Count, breaking ties by ascending
// Value so results do not depend on map iteration order. It then truncates
// the result to limit when limit is positive. The input slice is sorted in place.
func rankTopEntries(entries []TopEntry, limit int) []TopEntry {
	sort.Slice(entries, func(i, j int) bool {
		if entries[i].Count != entries[j].Count {
			return entries[i].Count > entries[j].Count
		}
		return entries[i].Value < entries[j].Value
	})
	if limit > 0 && len(entries) > limit {
		entries = entries[:limit]
	}
	return entries
}

// detachSession returns a copy of *s whose pointer fields (DisconnectedAt,
// HumanScore, ExecCommand) point at fresh values. Callers mutating the result
// therefore cannot change the store's state through it.
func detachSession(s *Session) Session {
	out := *s
	if s.DisconnectedAt != nil {
		t := *s.DisconnectedAt
		out.DisconnectedAt = &t
	}
	if s.HumanScore != nil {
		v := *s.HumanScore
		out.HumanScore = &v
	}
	if s.ExecCommand != nil {
		v := *s.ExecCommand
		out.ExecCommand = &v
	}
	return out
}