Add visual indicators to session tables (replay badge when events exist, exec badge for exec sessions) and a new "Top Exec Commands" table on the dashboard. Includes EventCount field on Session, GetTopExecCommands on Store interface, and truncateCommand template function. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
460 lines
13 KiB
Go
460 lines
13 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// SQLiteStore implements Store using a SQLite database.
|
|
type SQLiteStore struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewSQLiteStore opens or creates a SQLite database at the given path,
|
|
// runs pending migrations, and returns a ready-to-use store.
|
|
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
|
|
dsn := dbPath + "?_pragma=journal_mode(wal)&_pragma=foreign_keys(on)&_pragma=busy_timeout(5000)"
|
|
db, err := sql.Open("sqlite", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("opening database: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(1)
|
|
|
|
if err := Migrate(db); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("running migrations: %w", err)
|
|
}
|
|
|
|
return &SQLiteStore{db: db}, nil
|
|
}
|
|
|
|
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
_, err := s.db.ExecContext(ctx, `
|
|
INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
|
|
VALUES (?, ?, ?, ?, 1, ?, ?)
|
|
ON CONFLICT(username, password, ip) DO UPDATE SET
|
|
count = count + 1,
|
|
last_seen = ?,
|
|
country = ?`,
|
|
username, password, ip, country, now, now, now, country)
|
|
if err != nil {
|
|
return fmt.Errorf("recording login attempt: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
|
id := uuid.New().String()
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
_, err := s.db.ExecContext(ctx, `
|
|
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
id, ip, username, shellName, country, now)
|
|
if err != nil {
|
|
return "", fmt.Errorf("creating session: %w", err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (s *SQLiteStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
|
_, err := s.db.ExecContext(ctx, `
|
|
UPDATE sessions SET disconnected_at = ? WHERE id = ?`,
|
|
disconnectedAt.UTC().Format(time.RFC3339), sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("ending session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
|
_, err := s.db.ExecContext(ctx, `
|
|
UPDATE sessions SET human_score = ? WHERE id = ?`,
|
|
score, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("updating human score: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
|
_, err := s.db.ExecContext(ctx, `
|
|
UPDATE sessions SET exec_command = ? WHERE id = ?`,
|
|
command, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("setting exec command: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
_, err := s.db.ExecContext(ctx, `
|
|
INSERT INTO session_logs (session_id, timestamp, input, output)
|
|
VALUES (?, ?, ?, ?)`,
|
|
sessionID, now, input, output)
|
|
if err != nil {
|
|
return fmt.Errorf("appending session log: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
|
var sess Session
|
|
var connectedAt string
|
|
var disconnectedAt sql.NullString
|
|
var humanScore sql.NullFloat64
|
|
var execCommand sql.NullString
|
|
|
|
err := s.db.QueryRowContext(ctx, `
|
|
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
|
|
FROM sessions WHERE id = ?`, sessionID).Scan(
|
|
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
|
|
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
|
|
)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying session: %w", err)
|
|
}
|
|
|
|
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
|
if disconnectedAt.Valid {
|
|
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
|
sess.DisconnectedAt = &t
|
|
}
|
|
if humanScore.Valid {
|
|
sess.HumanScore = &humanScore.Float64
|
|
}
|
|
if execCommand.Valid {
|
|
sess.ExecCommand = &execCommand.String
|
|
}
|
|
return &sess, nil
|
|
}
|
|
|
|
func (s *SQLiteStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT id, session_id, timestamp, input, output
|
|
FROM session_logs WHERE session_id = ?
|
|
ORDER BY timestamp`, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying session logs: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var logs []SessionLog
|
|
for rows.Next() {
|
|
var l SessionLog
|
|
var ts string
|
|
if err := rows.Scan(&l.ID, &l.SessionID, &ts, &l.Input, &l.Output); err != nil {
|
|
return nil, fmt.Errorf("scanning session log: %w", err)
|
|
}
|
|
l.Timestamp, _ = time.Parse(time.RFC3339, ts)
|
|
logs = append(logs, l)
|
|
}
|
|
return logs, rows.Err()
|
|
}
|
|
|
|
func (s *SQLiteStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
|
if len(events) == 0 {
|
|
return nil
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
stmt, err := tx.PrepareContext(ctx, `
|
|
INSERT INTO session_events (session_id, timestamp, direction, data)
|
|
VALUES (?, ?, ?, ?)`)
|
|
if err != nil {
|
|
return fmt.Errorf("preparing statement: %w", err)
|
|
}
|
|
defer stmt.Close()
|
|
|
|
for _, e := range events {
|
|
_, err := stmt.ExecContext(ctx, e.SessionID, e.Timestamp.UTC().Format(time.RFC3339Nano), e.Direction, e.Data)
|
|
if err != nil {
|
|
return fmt.Errorf("inserting session event: %w", err)
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (s *SQLiteStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT session_id, timestamp, direction, data
|
|
FROM session_events WHERE session_id = ?
|
|
ORDER BY id`, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying session events: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var events []SessionEvent
|
|
for rows.Next() {
|
|
var e SessionEvent
|
|
var ts string
|
|
if err := rows.Scan(&e.SessionID, &ts, &e.Direction, &e.Data); err != nil {
|
|
return nil, fmt.Errorf("scanning session event: %w", err)
|
|
}
|
|
e.Timestamp, _ = time.Parse(time.RFC3339Nano, ts)
|
|
events = append(events, e)
|
|
}
|
|
return events, rows.Err()
|
|
}
|
|
|
|
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
|
cutoffStr := cutoff.UTC().Format(time.RFC3339)
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("begin transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
var total int64
|
|
|
|
// Delete session events for old sessions.
|
|
res, err := tx.ExecContext(ctx, `
|
|
DELETE FROM session_events WHERE session_id IN (
|
|
SELECT id FROM sessions WHERE connected_at < ?
|
|
)`, cutoffStr)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("deleting session events: %w", err)
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
total += n
|
|
|
|
// Delete session logs for old sessions.
|
|
res, err = tx.ExecContext(ctx, `
|
|
DELETE FROM session_logs WHERE session_id IN (
|
|
SELECT id FROM sessions WHERE connected_at < ?
|
|
)`, cutoffStr)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("deleting session logs: %w", err)
|
|
}
|
|
n, _ = res.RowsAffected()
|
|
total += n
|
|
|
|
// Delete old sessions.
|
|
res, err = tx.ExecContext(ctx, `DELETE FROM sessions WHERE connected_at < ?`, cutoffStr)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("deleting sessions: %w", err)
|
|
}
|
|
n, _ = res.RowsAffected()
|
|
total += n
|
|
|
|
// Delete old login attempts.
|
|
res, err = tx.ExecContext(ctx, `DELETE FROM login_attempts WHERE last_seen < ?`, cutoffStr)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("deleting login attempts: %w", err)
|
|
}
|
|
n, _ = res.RowsAffected()
|
|
total += n
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return 0, fmt.Errorf("commit transaction: %w", err)
|
|
}
|
|
|
|
return total, nil
|
|
}
|
|
|
|
func (s *SQLiteStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
|
stats := &DashboardStats{}
|
|
|
|
err := s.db.QueryRowContext(ctx, `
|
|
SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip)
|
|
FROM login_attempts`).Scan(&stats.TotalAttempts, &stats.UniqueIPs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying attempt stats: %w", err)
|
|
}
|
|
|
|
err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM sessions`).Scan(&stats.TotalSessions)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying total sessions: %w", err)
|
|
}
|
|
|
|
err = s.db.QueryRowContext(ctx, `
|
|
SELECT COUNT(*) FROM sessions WHERE disconnected_at IS NULL`).Scan(&stats.ActiveSessions)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying active sessions: %w", err)
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
func (s *SQLiteStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
|
return s.queryTopN(ctx, "username", limit)
|
|
}
|
|
|
|
func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
|
return s.queryTopN(ctx, "password", limit)
|
|
}
|
|
|
|
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT ip, country, SUM(count) AS total
|
|
FROM login_attempts
|
|
GROUP BY ip
|
|
ORDER BY total DESC
|
|
LIMIT ?`, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying top IPs: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var entries []TopEntry
|
|
for rows.Next() {
|
|
var e TopEntry
|
|
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
|
return nil, fmt.Errorf("scanning top IPs: %w", err)
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT country, SUM(count) AS total
|
|
FROM login_attempts
|
|
WHERE country != ''
|
|
GROUP BY country
|
|
ORDER BY total DESC
|
|
LIMIT ?`, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying top countries: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var entries []TopEntry
|
|
for rows.Next() {
|
|
var e TopEntry
|
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
|
return nil, fmt.Errorf("scanning top countries: %w", err)
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
|
|
switch column {
|
|
case "username", "password", "ip":
|
|
// valid columns
|
|
default:
|
|
return nil, fmt.Errorf("invalid column: %s", column)
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
SELECT %s, SUM(count) AS total
|
|
FROM login_attempts
|
|
GROUP BY %s
|
|
ORDER BY total DESC
|
|
LIMIT ?`, column, column)
|
|
|
|
rows, err := s.db.QueryContext(ctx, query, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying top %s: %w", column, err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var entries []TopEntry
|
|
for rows.Next() {
|
|
var e TopEntry
|
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
|
return nil, fmt.Errorf("scanning top %s: %w", column, err)
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
|
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
|
|
if activeOnly {
|
|
query += ` WHERE s.disconnected_at IS NULL`
|
|
}
|
|
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
|
|
|
|
rows, err := s.db.QueryContext(ctx, query, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying recent sessions: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var sessions []Session
|
|
for rows.Next() {
|
|
var s Session
|
|
var connectedAt string
|
|
var disconnectedAt sql.NullString
|
|
var humanScore sql.NullFloat64
|
|
var execCommand sql.NullString
|
|
if err := rows.Scan(&s.ID, &s.IP, &s.Country, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &s.EventCount); err != nil {
|
|
return nil, fmt.Errorf("scanning session: %w", err)
|
|
}
|
|
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
|
if disconnectedAt.Valid {
|
|
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
|
s.DisconnectedAt = &t
|
|
}
|
|
if humanScore.Valid {
|
|
s.HumanScore = &humanScore.Float64
|
|
}
|
|
if execCommand.Valid {
|
|
s.ExecCommand = &execCommand.String
|
|
}
|
|
sessions = append(sessions, s)
|
|
}
|
|
return sessions, rows.Err()
|
|
}
|
|
|
|
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT exec_command, COUNT(*) as total
|
|
FROM sessions
|
|
WHERE exec_command IS NOT NULL
|
|
GROUP BY exec_command
|
|
ORDER BY total DESC
|
|
LIMIT ?`, limit)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying top exec commands: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var entries []TopEntry
|
|
for rows.Next() {
|
|
var e TopEntry
|
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
|
return nil, fmt.Errorf("scanning top exec commands: %w", err)
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
|
res, err := s.db.ExecContext(ctx, `
|
|
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
|
|
disconnectedAt.UTC().Format(time.RFC3339))
|
|
if err != nil {
|
|
return 0, fmt.Errorf("closing active sessions: %w", err)
|
|
}
|
|
return res.RowsAffected()
|
|
}
|
|
|
|
func (s *SQLiteStore) Close() error {
|
|
return s.db.Close()
|
|
}
|