package storage

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	_ "modernc.org/sqlite"
)

// Timestamps written by this store (other than session events) are UTC
// RFC 3339 strings with one-second resolution. Because they are fixed-width
// and UTC, lexicographic order equals chronological order, which the
// retention and ordering queries below rely on.

// SQLiteStore implements Store using a SQLite database.
type SQLiteStore struct {
	db *sql.DB
}

// NewSQLiteStore opens or creates a SQLite database at the given path,
// runs pending migrations, and returns a ready-to-use store.
//
// Every connection is configured with WAL journaling, foreign-key
// enforcement and a 5000 ms busy timeout. dbPath may already carry its own
// query string (e.g. "file:app.db?mode=rwc"); the pragmas are appended to it.
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
	// Join onto an existing query string rather than starting a second one:
	// "path?a=b?_pragma=..." would fold the pragmas into the value of "a" and
	// they would silently not be applied.
	sep := "?"
	if strings.Contains(dbPath, "?") {
		sep = "&"
	}
	dsn := dbPath + sep + "_pragma=journal_mode(wal)&_pragma=foreign_keys(on)&_pragma=busy_timeout(5000)"

	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("opening database: %w", err)
	}
	// Only one connection is ever open, so all statements run serially.
	// Methods on this type must therefore never issue a second query while a
	// *sql.Rows is still open, or they would block forever.
	db.SetMaxOpenConns(1)

	if err := Migrate(db); err != nil {
		_ = db.Close() // the migration error is the one worth reporting
		return nil, fmt.Errorf("running migrations: %w", err)
	}
	return &SQLiteStore{db: db}, nil
}

// RecordLoginAttempt upserts a login attempt keyed by (username, password, ip).
// A new triple is inserted with count 1 and first_seen = last_seen = now;
// a repeated triple has its count incremented and its last_seen and country
// overwritten with the current values.
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
	now := time.Now().UTC().Format(time.RFC3339)
	// In DO UPDATE, a bare column name refers to the existing row and
	// excluded.<col> to the values this INSERT tried to write.
	_, err := s.db.ExecContext(ctx, `
		INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
		VALUES (?, ?, ?, ?, 1, ?, ?)
		ON CONFLICT(username, password, ip) DO UPDATE SET
			count = count + 1,
			last_seen = excluded.last_seen,
			country = excluded.country`,
		username, password, ip, country, now, now)
	if err != nil {
		return fmt.Errorf("recording login attempt: %w", err)
	}
	return nil
}

// CreateSession inserts a new session whose connected_at is the current UTC
// time and returns its ID, a freshly generated random UUID string.
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
	id := uuid.New().String()
	now := time.Now().UTC().Format(time.RFC3339)
	_, err := s.db.ExecContext(ctx, `
		INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
		VALUES (?, ?, ?, ?, ?, ?)`,
		id, ip, username, shellName, country, now)
	if err != nil {
		return "", fmt.Errorf("creating session: %w", err)
	}
	return id, nil
}

// EndSession records disconnectedAt (converted to UTC, truncated to whole
// seconds by the storage format) as the session's disconnect time.
// An unknown sessionID is not an error; nothing is updated.
func (s *SQLiteStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
	_, err := s.db.ExecContext(ctx, `
		UPDATE sessions SET disconnected_at = ?
		WHERE id = ?`,
		disconnectedAt.UTC().Format(time.RFC3339), sessionID)
	if err != nil {
		return fmt.Errorf("ending session: %w", err)
	}
	return nil
}

// UpdateHumanScore stores score as the session's human_score.
// An unknown sessionID is not an error; nothing is updated.
func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
	_, err := s.db.ExecContext(ctx, `
		UPDATE sessions SET human_score = ?
		WHERE id = ?`,
		score, sessionID)
	if err != nil {
		return fmt.Errorf("updating human score: %w", err)
	}
	return nil
}

// SetExecCommand stores command as the session's exec_command.
// An unknown sessionID is not an error; nothing is updated.
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
	_, err := s.db.ExecContext(ctx, `
		UPDATE sessions SET exec_command = ?
		WHERE id = ?`,
		command, sessionID)
	if err != nil {
		return fmt.Errorf("setting exec command: %w", err)
	}
	return nil
}

// AppendSessionLog appends one input/output pair to a session's log,
// timestamped with the current UTC time (one-second resolution).
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
	now := time.Now().UTC().Format(time.RFC3339)
	_, err := s.db.ExecContext(ctx, `
		INSERT INTO session_logs (session_id, timestamp, input, output)
		VALUES (?, ?, ?, ?)`,
		sessionID, now, input, output)
	if err != nil {
		return fmt.Errorf("appending session log: %w", err)
	}
	return nil
}

// GetSession returns the session with the given ID, or (nil, nil) when no
// such session exists. EventCount is not populated.
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	row := s.db.QueryRowContext(ctx, `
		SELECT id, ip, country, username, shell_name, connected_at,
		       disconnected_at, human_score, exec_command
		FROM sessions WHERE id = ?`, sessionID)

	var sess Session
	err := scanSQLiteSession(row.Scan, &sess)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("querying session: %w", err)
	}
	return &sess, nil
}

// scanSQLiteSession scans one row whose leading columns are, in order:
// id, ip, country, username, shell_name, connected_at, disconnected_at,
// human_score, exec_command. Any extra destinations are filled from the
// columns that follow. NULL disconnected_at, human_score and exec_command
// leave the corresponding *Session pointer fields nil.
//
// scan is the Scan method of a *sql.Row or *sql.Rows; its error is returned
// unwrapped so callers can test for sql.ErrNoRows.
//
// Timestamps are parsed leniently: a value that is not RFC 3339 yields the
// zero time instead of an error.
func scanSQLiteSession(scan func(dest ...any) error, sess *Session, extra ...any) error {
	var (
		connectedAt    string
		disconnectedAt sql.NullString
		humanScore     sql.NullFloat64
		execCommand    sql.NullString
	)
	dest := append([]any{
		&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
		&connectedAt, &disconnectedAt, &humanScore, &execCommand,
	}, extra...)
	if err := scan(dest...); err != nil {
		return err
	}

	sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
	if disconnectedAt.Valid {
		t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
		sess.DisconnectedAt = &t
	}
	if humanScore.Valid {
		score := humanScore.Float64
		sess.HumanScore = &score
	}
	if execCommand.Valid {
		cmd := execCommand.String
		sess.ExecCommand = &cmd
	}
	return nil
}

// GetSessionLogs returns a session's log entries ordered by timestamp.
// It returns a nil slice when the session has no entries.
func (s *SQLiteStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
	// timestamp has one-second resolution, so entries written within the
	// same second tie; id breaks the tie. This assumes id is an increasing
	// integer key, as GetSessionEvents also relies on for its ordering.
	rows, err := s.db.QueryContext(ctx, `
		SELECT id, session_id, timestamp, input, output
		FROM session_logs
		WHERE session_id = ?
		ORDER BY timestamp, id`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("querying session logs: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var logs []SessionLog
	for rows.Next() {
		var l SessionLog
		var ts string
		if err := rows.Scan(&l.ID, &l.SessionID, &ts, &l.Input, &l.Output); err != nil {
			return nil, fmt.Errorf("scanning session log: %w", err)
		}
		// Lenient: an unparseable timestamp yields the zero time.
		l.Timestamp, _ = time.Parse(time.RFC3339, ts)
		logs = append(logs, l)
	}
	return logs, rows.Err()
}

// AppendSessionEvents inserts events in a single transaction: either all of
// them are stored or none are. Timestamps are stored in UTC with nanosecond
// precision (RFC 3339 Nano). An empty slice is a no-op.
func (s *SQLiteStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
	if len(events) == 0 {
		return nil
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// After a successful Commit this returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	stmt, err := tx.PrepareContext(ctx, `
		INSERT INTO session_events (session_id, timestamp, direction, data)
		VALUES (?, ?, ?, ?)`)
	if err != nil {
		return fmt.Errorf("preparing statement: %w", err)
	}
	defer func() { _ = stmt.Close() }()

	for _, e := range events {
		_, err := stmt.ExecContext(ctx, e.SessionID,
			e.Timestamp.UTC().Format(time.RFC3339Nano), e.Direction, e.Data)
		if err != nil {
			return fmt.Errorf("inserting session event: %w", err)
		}
	}
	return tx.Commit()
}

// GetSessionEvents returns a session's events ordered by id.
// It returns a nil slice when the session has no events.
func (s *SQLiteStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
	rows, err := s.db.QueryContext(ctx, `
		SELECT session_id, timestamp, direction, data
		FROM session_events
		WHERE session_id = ?
		ORDER BY id`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("querying session events: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var events []SessionEvent
	for rows.Next() {
		var e SessionEvent
		var ts string
		if err := rows.Scan(&e.SessionID, &ts, &e.Direction, &e.Data); err != nil {
			return nil, fmt.Errorf("scanning session event: %w", err)
		}
		// Lenient: an unparseable timestamp yields the zero time.
		e.Timestamp, _ = time.Parse(time.RFC3339Nano, ts)
		events = append(events, e)
	}
	return events, rows.Err()
}

// DeleteRecordsBefore removes, in one transaction, every session that
// connected before cutoff together with its events and logs, plus every
// login attempt last seen before cutoff. It returns the total number of rows
// deleted across all four tables; on error nothing is deleted.
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
	// Stored timestamps are fixed-width UTC RFC 3339 strings, so a plain
	// string comparison against cutoffStr is a chronological comparison.
	cutoffStr := cutoff.UTC().Format(time.RFC3339)

	// Child rows go first so that deleting their sessions cannot trip a
	// foreign-key constraint (foreign_keys is switched on in NewSQLiteStore),
	// should the schema declare one.
	steps := []struct {
		what  string
		query string
	}{
		{"session events", `
			DELETE FROM session_events WHERE session_id IN (
				SELECT id FROM sessions WHERE connected_at < ?
			)`},
		{"session logs", `
			DELETE FROM session_logs WHERE session_id IN (
				SELECT id FROM sessions WHERE connected_at < ?
			)`},
		{"sessions", `DELETE FROM sessions WHERE connected_at < ?`},
		{"login attempts", `DELETE FROM login_attempts WHERE last_seen < ?`},
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("begin transaction: %w", err)
	}
	// After a successful Commit this returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	var total int64
	for _, step := range steps {
		res, err := tx.ExecContext(ctx, step.query, cutoffStr)
		if err != nil {
			return 0, fmt.Errorf("deleting %s: %w", step.what, err)
		}
		n, err := res.RowsAffected()
		if err != nil {
			return 0, fmt.Errorf("counting deleted %s: %w", step.what, err)
		}
		total += n
	}

	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("commit transaction: %w", err)
	}
	return total, nil
}

// GetDashboardStats returns aggregate counters. TotalAttempts is the sum of
// all login-attempt counts, UniqueIPs the number of distinct attempting IPs,
// TotalSessions the number of sessions, and ActiveSessions the number of
// sessions with no disconnect time. The counters come from separate queries
// and are not read from a single snapshot.
func (s *SQLiteStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
	stats := &DashboardStats{}

	err := s.db.QueryRowContext(ctx, `
		SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip)
		FROM login_attempts`).Scan(&stats.TotalAttempts, &stats.UniqueIPs)
	if err != nil {
		return nil, fmt.Errorf("querying attempt stats: %w", err)
	}

	err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM sessions`).Scan(&stats.TotalSessions)
	if err != nil {
		return nil, fmt.Errorf("querying total sessions: %w", err)
	}

	err = s.db.QueryRowContext(ctx, `
		SELECT COUNT(*) FROM sessions
		WHERE disconnected_at IS NULL`).Scan(&stats.ActiveSessions)
	if err != nil {
		return nil, fmt.Errorf("querying active sessions: %w", err)
	}

	return stats, nil
}

// GetTopUsernames returns up to limit usernames ranked by total attempt count.
func (s *SQLiteStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
	return s.queryTopN(ctx, "username", limit)
}

// GetTopPasswords returns up to limit passwords ranked by total attempt count.
func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
	return s.queryTopN(ctx, "password", limit)
}

// GetTopIPs returns up to limit source IPs ranked by total attempt count,
// each with a country. An IP's rows may record different countries; MAX
// picks one deterministically (and prefers a non-empty value over "").
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
	// A bare, non-aggregated column in a SUM/GROUP BY query would be taken
	// from an arbitrary row, hence MAX(country).
	rows, err := s.db.QueryContext(ctx, `
		SELECT ip, MAX(country), SUM(count) AS total
		FROM login_attempts
		GROUP BY ip
		ORDER BY total DESC
		LIMIT ?`, limit)
	if err != nil {
		return nil, fmt.Errorf("querying top IPs: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var entries []TopEntry
	for rows.Next() {
		var e TopEntry
		if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
			return nil, fmt.Errorf("scanning top IPs: %w", err)
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}

// GetTopCountries returns up to limit countries ranked by total attempt
// count. Attempts with an empty (or NULL) country are excluded.
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
	return s.queryTopEntries(ctx, "countries", `
		SELECT country, SUM(count) AS total
		FROM login_attempts
		WHERE country != ''
		GROUP BY country
		ORDER BY total DESC
		LIMIT ?`, limit)
}

// queryTopN ranks the distinct values of one login_attempts column by total
// attempt count. column is interpolated into the SQL text, so it is checked
// against a fixed allow-list first.
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
	switch column {
	case "username", "password", "ip":
		// valid columns
	default:
		return nil, fmt.Errorf("invalid column: %s", column)
	}

	query := fmt.Sprintf(`
		SELECT %s, SUM(count) AS total
		FROM login_attempts
		GROUP BY %s
		ORDER BY total DESC
		LIMIT ?`, column, column)
	return s.queryTopEntries(ctx, column, query, limit)
}

// queryTopEntries runs a query yielding (value, count) rows and collects
// them as TopEntry values. label names the statistic in error messages
// ("querying top <label>: ...", "scanning top <label>: ...").
func (s *SQLiteStore) queryTopEntries(ctx context.Context, label, query string, args ...any) ([]TopEntry, error) {
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("querying top %s: %w", label, err)
	}
	defer func() { _ = rows.Close() }()

	var entries []TopEntry
	for rows.Next() {
		var e TopEntry
		if err := rows.Scan(&e.Value, &e.Count); err != nil {
			return nil, fmt.Errorf("scanning top %s: %w", label, err)
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}

// GetRecentSessions returns up to limit sessions, newest connection first,
// each with EventCount set to its number of recorded events. When activeOnly
// is true, only sessions with no disconnect time are returned.
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
	query := `
		SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at,
		       s.disconnected_at, s.human_score, s.exec_command,
		       COUNT(e.id) AS event_count
		FROM sessions s
		LEFT JOIN session_events e ON s.id = e.session_id`
	if activeOnly {
		query += `
		WHERE s.disconnected_at IS NULL`
	}
	query += `
		GROUP BY s.id
		ORDER BY s.connected_at DESC
		LIMIT ?`

	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, fmt.Errorf("querying recent sessions: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var sessions []Session
	for rows.Next() {
		var sess Session
		if err := scanSQLiteSession(rows.Scan, &sess, &sess.EventCount); err != nil {
			return nil, fmt.Errorf("scanning session: %w", err)
		}
		sessions = append(sessions, sess)
	}
	return sessions, rows.Err()
}

// GetTopExecCommands returns up to limit exec commands ranked by the number
// of sessions that ran them. Sessions without an exec command are excluded.
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
	return s.queryTopEntries(ctx, "exec commands", `
		SELECT exec_command, COUNT(*) AS total
		FROM sessions
		WHERE exec_command IS NOT NULL
		GROUP BY exec_command
		ORDER BY total DESC
		LIMIT ?`, limit)
}

// CloseActiveSessions sets disconnectedAt (in UTC) as the disconnect time of
// every session that has none and returns how many sessions were updated.
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
	res, err := s.db.ExecContext(ctx, `
		UPDATE sessions SET disconnected_at = ?
		WHERE disconnected_at IS NULL`,
		disconnectedAt.UTC().Format(time.RFC3339))
	if err != nil {
		return 0, fmt.Errorf("closing active sessions: %w", err)
	}
	return res.RowsAffected()
}

// Close closes the underlying database handle.
func (s *SQLiteStore) Close() error {
	return s.db.Close()
}