package storage

import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	_ "modernc.org/sqlite" // registers the "sqlite" database/sql driver
)

// SQLiteStore implements Store using a SQLite database.
//
// Every timestamp this file writes is formatted with time.RFC3339 in UTC.
// Because the layout and zone are fixed, lexicographic string comparison in
// SQL (<, ORDER BY) agrees with chronological order for those values.
type SQLiteStore struct {
	db *sql.DB
}

// NewSQLiteStore opens or creates a SQLite database at the given path,
// runs pending migrations, and returns a ready-to-use store.
//
// The DSN enables WAL journaling, foreign-key enforcement and a 5000 ms busy
// timeout. If dbPath already contains a query string (for example
// "file:test.db?mode=memory"), the pragmas are appended to it rather than
// starting a second one.
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
	sep := "?"
	if strings.Contains(dbPath, "?") {
		sep = "&"
	}
	dsn := dbPath + sep + "_pragma=journal_mode(wal)&_pragma=foreign_keys(on)&_pragma=busy_timeout(5000)"

	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("opening database: %w", err)
	}
	// A single pooled connection serializes every statement issued through
	// this store. NOTE(review): presumably chosen to avoid SQLite writer
	// contention; it also means an open transaction blocks all other calls.
	db.SetMaxOpenConns(1)

	if err := Migrate(db); err != nil {
		_ = db.Close() // best effort; the migration error is the one worth reporting
		return nil, fmt.Errorf("running migrations: %w", err)
	}
	return &SQLiteStore{db: db}, nil
}

// RecordLoginAttempt counts one login attempt for the (username, password, ip)
// triple. The first attempt inserts a row with count 1 and first_seen =
// last_seen = now; subsequent attempts increment count and advance last_seen,
// leaving first_seen untouched.
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip string) error {
	now := time.Now().UTC().Format(time.RFC3339)
	_, err := s.db.ExecContext(ctx, `
		INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen)
		VALUES (?, ?, ?, 1, ?, ?)
		ON CONFLICT(username, password, ip) DO UPDATE SET
			count = count + 1,
			last_seen = excluded.last_seen`,
		username, password, ip, now, now)
	if err != nil {
		return fmt.Errorf("recording login attempt: %w", err)
	}
	return nil
}

// CreateSession inserts a new session row stamped with the current UTC time
// and returns its freshly generated UUID.
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName string) (string, error) {
	id := uuid.New().String()
	now := time.Now().UTC().Format(time.RFC3339)
	_, err := s.db.ExecContext(ctx, `
		INSERT INTO sessions (id, ip, username, shell_name, connected_at)
		VALUES (?, ?, ?, ?, ?)`,
		id, ip, username, shellName, now)
	if err != nil {
		return "", fmt.Errorf("creating session: %w", err)
	}
	return id, nil
}

// EndSession sets disconnected_at (converted to UTC) on the given session.
// An unknown sessionID updates no rows and is not reported as an error.
func (s *SQLiteStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
	_, err := s.db.ExecContext(ctx, `
		UPDATE sessions SET disconnected_at = ?
		WHERE id = ?`,
		disconnectedAt.UTC().Format(time.RFC3339), sessionID)
	if err != nil {
		return fmt.Errorf("ending session: %w", err)
	}
	return nil
}

// UpdateHumanScore stores score as the session's human_score. An unknown
// sessionID updates no rows and is not reported as an error.
func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
	_, err := s.db.ExecContext(ctx, `
		UPDATE sessions SET human_score = ?
		WHERE id = ?`,
		score, sessionID)
	if err != nil {
		return fmt.Errorf("updating human score: %w", err)
	}
	return nil
}

// AppendSessionLog records one input/output exchange for a session, stamped
// with the current UTC time.
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
	now := time.Now().UTC().Format(time.RFC3339)
	_, err := s.db.ExecContext(ctx, `
		INSERT INTO session_logs (session_id, timestamp, input, output)
		VALUES (?, ?, ?, ?)`,
		sessionID, now, input, output)
	if err != nil {
		return fmt.Errorf("appending session log: %w", err)
	}
	return nil
}

// DeleteRecordsBefore removes, in a single transaction, all sessions that
// connected before cutoff (together with their logs) and all login-attempt
// rows last seen before cutoff. It returns the total number of rows deleted
// across all three tables; on any error the transaction is rolled back and
// 0 is returned.
//
// NOTE(review): sessions are selected by connected_at only, so a session that
// is still connected (disconnected_at IS NULL) but started before cutoff is
// deleted too — confirm this is intended.
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
	cutoffStr := cutoff.UTC().Format(time.RFC3339)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("begin transaction: %w", err)
	}
	defer func() { _ = tx.Rollback() }() // no-op after a successful Commit

	// Order matters: session logs reference sessions, so they must go first
	// while the parent rows still exist to be matched by the subquery.
	steps := []struct {
		desc  string
		query string
	}{
		{"deleting session logs", `
			DELETE FROM session_logs WHERE session_id IN (
				SELECT id FROM sessions WHERE connected_at < ?
			)`},
		{"deleting sessions", `DELETE FROM sessions WHERE connected_at < ?`},
		{"deleting login attempts", `DELETE FROM login_attempts WHERE last_seen < ?`},
	}

	var total int64
	for _, step := range steps {
		res, err := tx.ExecContext(ctx, step.query, cutoffStr)
		if err != nil {
			return 0, fmt.Errorf("%s: %w", step.desc, err)
		}
		n, err := res.RowsAffected()
		if err != nil {
			return 0, fmt.Errorf("%s: counting affected rows: %w", step.desc, err)
		}
		total += n
	}

	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("commit transaction: %w", err)
	}
	return total, nil
}

// GetDashboardStats returns aggregate counters: total login attempts (sum of
// per-triple counts), distinct attacking IPs, total sessions, and sessions
// that have no disconnected_at yet.
func (s *SQLiteStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
	stats := &DashboardStats{}

	err := s.db.QueryRowContext(ctx, `
		SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip)
		FROM login_attempts`).Scan(&stats.TotalAttempts, &stats.UniqueIPs)
	if err != nil {
		return nil, fmt.Errorf("querying attempt stats: %w", err)
	}

	// One statement so total and active come from the same snapshot.
	// COUNT(col) skips NULLs, so the difference is the number of rows whose
	// disconnected_at IS NULL.
	err = s.db.QueryRowContext(ctx, `
		SELECT COUNT(*), COUNT(*) - COUNT(disconnected_at)
		FROM sessions`).Scan(&stats.TotalSessions, &stats.ActiveSessions)
	if err != nil {
		return nil, fmt.Errorf("querying session stats: %w", err)
	}

	return stats, nil
}

// GetTopUsernames returns up to limit usernames ranked by total attempt count.
func (s *SQLiteStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
	return s.queryTopN(ctx, "username", limit)
}

// GetTopPasswords returns up to limit passwords ranked by total attempt count.
func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
	return s.queryTopN(ctx, "password", limit)
}

// GetTopIPs returns up to limit source IPs ranked by total attempt count.
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
	return s.queryTopN(ctx, "ip", limit)
}

// queryTopN groups login_attempts by column and returns up to limit values
// with their summed counts, highest first. Column names cannot be bound as
// SQL parameters, so column is checked against a fixed whitelist before it is
// interpolated into the query text.
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
	switch column {
	case "username", "password", "ip":
	default:
		return nil, fmt.Errorf("querying top entries: unsupported column %q", column)
	}

	query := fmt.Sprintf(`
		SELECT %s, SUM(count) AS total
		FROM login_attempts
		GROUP BY %s
		ORDER BY total DESC
		LIMIT ?`, column, column)

	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, fmt.Errorf("querying top %s: %w", column, err)
	}
	defer func() { _ = rows.Close() }()

	var entries []TopEntry
	for rows.Next() {
		var e TopEntry
		if err := rows.Scan(&e.Value, &e.Count); err != nil {
			return nil, fmt.Errorf("scanning top %s: %w", column, err)
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}

// GetRecentSessions returns up to limit sessions, newest connection first.
// When activeOnly is true only sessions without a disconnected_at are returned.
//
// Timestamps are parsed leniently: a value that is not valid RFC 3339 yields
// the zero time.Time instead of failing the whole listing.
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
	query := `SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score FROM sessions`
	if activeOnly {
		query += ` WHERE disconnected_at IS NULL`
	}
	query += ` ORDER BY connected_at DESC LIMIT ?`

	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, fmt.Errorf("querying recent sessions: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var sessions []Session
	for rows.Next() {
		var (
			sess           Session
			connectedAt    string
			disconnectedAt sql.NullString
			humanScore     sql.NullFloat64
		)
		if err := rows.Scan(&sess.ID, &sess.IP, &sess.Username, &sess.ShellName,
			&connectedAt, &disconnectedAt, &humanScore); err != nil {
			return nil, fmt.Errorf("scanning session: %w", err)
		}

		// Deliberately best-effort: see the lenient-parsing note above.
		sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
		if disconnectedAt.Valid {
			t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
			sess.DisconnectedAt = &t
		}
		if humanScore.Valid {
			score := humanScore.Float64
			sess.HumanScore = &score
		}
		sessions = append(sessions, sess)
	}
	return sessions, rows.Err()
}

// Close releases the underlying database handle.
func (s *SQLiteStore) Close() error {
	return s.db.Close()
}