After an unclean shutdown, sessions could be left with disconnected_at NULL, appearing permanently active. Add CloseActiveSessions to the Store interface and call it at startup to close any leftover sessions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
307 lines
6.7 KiB
Go
307 lines
6.7 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// MemoryStore is an in-memory implementation of Store for use in tests.
|
|
type MemoryStore struct {
|
|
mu sync.Mutex
|
|
LoginAttempts []LoginAttempt
|
|
Sessions map[string]*Session
|
|
SessionLogs []SessionLog
|
|
SessionEvents []SessionEvent
|
|
}
|
|
|
|
// NewMemoryStore returns a new empty MemoryStore.
|
|
func NewMemoryStore() *MemoryStore {
|
|
return &MemoryStore{
|
|
Sessions: make(map[string]*Session),
|
|
}
|
|
}
|
|
|
|
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
now := time.Now().UTC()
|
|
for i := range m.LoginAttempts {
|
|
a := &m.LoginAttempts[i]
|
|
if a.Username == username && a.Password == password && a.IP == ip {
|
|
a.Count++
|
|
a.LastSeen = now
|
|
return nil
|
|
}
|
|
}
|
|
|
|
m.LoginAttempts = append(m.LoginAttempts, LoginAttempt{
|
|
ID: int64(len(m.LoginAttempts) + 1),
|
|
Username: username,
|
|
Password: password,
|
|
IP: ip,
|
|
Count: 1,
|
|
FirstSeen: now,
|
|
LastSeen: now,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName string) (string, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
id := uuid.New().String()
|
|
now := time.Now().UTC()
|
|
m.Sessions[id] = &Session{
|
|
ID: id,
|
|
IP: ip,
|
|
Username: username,
|
|
ShellName: shellName,
|
|
ConnectedAt: now,
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (m *MemoryStore) EndSession(_ context.Context, sessionID string, disconnectedAt time.Time) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if s, ok := m.Sessions[sessionID]; ok {
|
|
t := disconnectedAt.UTC()
|
|
s.DisconnectedAt = &t
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, score float64) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if s, ok := m.Sessions[sessionID]; ok {
|
|
s.HumanScore = &score
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
m.SessionLogs = append(m.SessionLogs, SessionLog{
|
|
ID: int64(len(m.SessionLogs) + 1),
|
|
SessionID: sessionID,
|
|
Timestamp: time.Now().UTC(),
|
|
Input: input,
|
|
Output: output,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (m *MemoryStore) GetSession(_ context.Context, sessionID string) (*Session, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
s, ok := m.Sessions[sessionID]
|
|
if !ok {
|
|
return nil, nil
|
|
}
|
|
copy := *s
|
|
return ©, nil
|
|
}
|
|
|
|
func (m *MemoryStore) GetSessionLogs(_ context.Context, sessionID string) ([]SessionLog, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var logs []SessionLog
|
|
for _, l := range m.SessionLogs {
|
|
if l.SessionID == sessionID {
|
|
logs = append(logs, l)
|
|
}
|
|
}
|
|
sort.Slice(logs, func(i, j int) bool {
|
|
return logs[i].Timestamp.Before(logs[j].Timestamp)
|
|
})
|
|
return logs, nil
|
|
}
|
|
|
|
func (m *MemoryStore) AppendSessionEvents(_ context.Context, events []SessionEvent) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
m.SessionEvents = append(m.SessionEvents, events...)
|
|
return nil
|
|
}
|
|
|
|
func (m *MemoryStore) GetSessionEvents(_ context.Context, sessionID string) ([]SessionEvent, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var events []SessionEvent
|
|
for _, e := range m.SessionEvents {
|
|
if e.SessionID == sessionID {
|
|
events = append(events, e)
|
|
}
|
|
}
|
|
return events, nil
|
|
}
|
|
|
|
func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (int64, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var total int64
|
|
|
|
// Delete old login attempts.
|
|
kept := m.LoginAttempts[:0]
|
|
for _, a := range m.LoginAttempts {
|
|
if a.LastSeen.Before(cutoff) {
|
|
total++
|
|
} else {
|
|
kept = append(kept, a)
|
|
}
|
|
}
|
|
m.LoginAttempts = kept
|
|
|
|
// Delete old sessions and their logs.
|
|
for id, s := range m.Sessions {
|
|
if s.ConnectedAt.Before(cutoff) {
|
|
delete(m.Sessions, id)
|
|
total++
|
|
}
|
|
}
|
|
|
|
keptLogs := m.SessionLogs[:0]
|
|
for _, l := range m.SessionLogs {
|
|
if _, ok := m.Sessions[l.SessionID]; ok {
|
|
keptLogs = append(keptLogs, l)
|
|
} else {
|
|
total++
|
|
}
|
|
}
|
|
m.SessionLogs = keptLogs
|
|
|
|
keptEvents := m.SessionEvents[:0]
|
|
for _, e := range m.SessionEvents {
|
|
if _, ok := m.Sessions[e.SessionID]; ok {
|
|
keptEvents = append(keptEvents, e)
|
|
} else {
|
|
total++
|
|
}
|
|
}
|
|
m.SessionEvents = keptEvents
|
|
|
|
return total, nil
|
|
}
|
|
|
|
func (m *MemoryStore) GetDashboardStats(_ context.Context) (*DashboardStats, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
stats := &DashboardStats{}
|
|
ips := make(map[string]struct{})
|
|
for _, a := range m.LoginAttempts {
|
|
stats.TotalAttempts += int64(a.Count)
|
|
ips[a.IP] = struct{}{}
|
|
}
|
|
stats.UniqueIPs = int64(len(ips))
|
|
stats.TotalSessions = int64(len(m.Sessions))
|
|
for _, s := range m.Sessions {
|
|
if s.DisconnectedAt == nil {
|
|
stats.ActiveSessions++
|
|
}
|
|
}
|
|
return stats, nil
|
|
}
|
|
|
|
func (m *MemoryStore) GetTopUsernames(_ context.Context, limit int) ([]TopEntry, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.topN("username", limit), nil
|
|
}
|
|
|
|
func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.topN("password", limit), nil
|
|
}
|
|
|
|
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return m.topN("ip", limit), nil
|
|
}
|
|
|
|
// topN aggregates login attempts by the given field and returns the top N. Must be called with m.mu held.
|
|
func (m *MemoryStore) topN(field string, limit int) []TopEntry {
|
|
counts := make(map[string]int64)
|
|
for _, a := range m.LoginAttempts {
|
|
var key string
|
|
switch field {
|
|
case "username":
|
|
key = a.Username
|
|
case "password":
|
|
key = a.Password
|
|
case "ip":
|
|
key = a.IP
|
|
}
|
|
counts[key] += int64(a.Count)
|
|
}
|
|
|
|
entries := make([]TopEntry, 0, len(counts))
|
|
for k, v := range counts {
|
|
entries = append(entries, TopEntry{Value: k, Count: v})
|
|
}
|
|
sort.Slice(entries, func(i, j int) bool {
|
|
return entries[i].Count > entries[j].Count
|
|
})
|
|
if limit > 0 && len(entries) > limit {
|
|
entries = entries[:limit]
|
|
}
|
|
return entries
|
|
}
|
|
|
|
func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly bool) ([]Session, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var sessions []Session
|
|
for _, s := range m.Sessions {
|
|
if activeOnly && s.DisconnectedAt != nil {
|
|
continue
|
|
}
|
|
sessions = append(sessions, *s)
|
|
}
|
|
sort.Slice(sessions, func(i, j int) bool {
|
|
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
|
})
|
|
if limit > 0 && len(sessions) > limit {
|
|
sessions = sessions[:limit]
|
|
}
|
|
return sessions, nil
|
|
}
|
|
|
|
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var count int64
|
|
t := disconnectedAt.UTC()
|
|
for _, s := range m.Sessions {
|
|
if s.DisconnectedAt == nil {
|
|
s.DisconnectedAt = &t
|
|
count++
|
|
}
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func (m *MemoryStore) Close() error {
|
|
return nil
|
|
}
|