Bots often send commands via `ssh user@host <command>` (exec request) rather than requesting an interactive shell. These were previously rejected silently. Now exec commands are captured, stored on the session record, and displayed in the web UI session detail page. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
436 lines
12 KiB
Go
436 lines
12 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
// SQLiteStore implements Store using a SQLite database.
|
|
type SQLiteStore struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewSQLiteStore opens or creates a SQLite database at the given path,
|
|
// runs pending migrations, and returns a ready-to-use store.
|
|
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
|
|
dsn := dbPath + "?_pragma=journal_mode(wal)&_pragma=foreign_keys(on)&_pragma=busy_timeout(5000)"
|
|
db, err := sql.Open("sqlite", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("opening database: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(1)
|
|
|
|
if err := Migrate(db); err != nil {
|
|
db.Close()
|
|
return nil, fmt.Errorf("running migrations: %w", err)
|
|
}
|
|
|
|
return &SQLiteStore{db: db}, nil
|
|
}
|
|
|
|
// RecordLoginAttempt upserts the login_attempts row for the
// (username, password, ip) triple.
//
// The first sighting inserts count = 1 with first_seen = last_seen = now.
// Each later sighting increments count, moves last_seen to now and overwrites
// country with the value passed in. Timestamps are UTC RFC 3339 strings.
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
	seenAt := time.Now().UTC().Format(time.RFC3339)

	// excluded.<col> is the value the INSERT attempted to write. Referencing it
	// in the DO UPDATE clause means each value is bound exactly once, instead of
	// repeating now/country as extra positional arguments that must be kept in
	// sync with the placeholder order. count on the right-hand side is the
	// existing row's value.
	_, err := s.db.ExecContext(ctx, `
		INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
		VALUES (?, ?, ?, ?, 1, ?, ?)
		ON CONFLICT(username, password, ip) DO UPDATE SET
			count = count + 1,
			last_seen = excluded.last_seen,
			country = excluded.country`,
		username, password, ip, country, seenAt, seenAt)
	if err != nil {
		return fmt.Errorf("recording login attempt: %w", err)
	}
	return nil
}
|
|
|
|
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
|
id := uuid.New().String()
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
_, err := s.db.ExecContext(ctx, `
|
|
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
id, ip, username, shellName, country, now)
|
|
if err != nil {
|
|
return "", fmt.Errorf("creating session: %w", err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (s *SQLiteStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
|
_, err := s.db.ExecContext(ctx, `
|
|
UPDATE sessions SET disconnected_at = ? WHERE id = ?`,
|
|
disconnectedAt.UTC().Format(time.RFC3339), sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("ending session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
|
_, err := s.db.ExecContext(ctx, `
|
|
UPDATE sessions SET human_score = ? WHERE id = ?`,
|
|
score, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("updating human score: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
|
_, err := s.db.ExecContext(ctx, `
|
|
UPDATE sessions SET exec_command = ? WHERE id = ?`,
|
|
command, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("setting exec command: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AppendSessionLog adds one input/output exchange to session_logs for
// sessionID.
//
// The entry is stamped with the current UTC time at whole-second (RFC 3339)
// precision, so entries logged within the same second share a timestamp.
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
	loggedAt := time.Now().UTC().Format(time.RFC3339)

	if _, err := s.db.ExecContext(ctx, `
		INSERT INTO session_logs (session_id, timestamp, input, output)
		VALUES (?, ?, ?, ?)`,
		sessionID, loggedAt, input, output,
	); err != nil {
		return fmt.Errorf("appending session log: %w", err)
	}
	return nil
}
|
|
|
|
// GetSession loads a single session by ID.
//
// It returns (nil, nil) when no row matches. Nullable columns
// (disconnected_at, human_score, exec_command) become nil pointers on the
// returned Session. Timestamps that fail to parse as RFC 3339 are left as the
// zero time rather than reported as errors.
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	var (
		sess           Session
		connectedAt    string
		disconnectedAt sql.NullString
		humanScore     sql.NullFloat64
		execCommand    sql.NullString
	)

	row := s.db.QueryRowContext(ctx, `
		SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
		FROM sessions WHERE id = ?`, sessionID)
	err := row.Scan(
		&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
		&connectedAt, &disconnectedAt, &humanScore, &execCommand,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("querying session: %w", err)
	}

	sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
	if disconnectedAt.Valid {
		endedAt, _ := time.Parse(time.RFC3339, disconnectedAt.String)
		sess.DisconnectedAt = &endedAt
	}
	if humanScore.Valid {
		score := humanScore.Float64
		sess.HumanScore = &score
	}
	if execCommand.Valid {
		command := execCommand.String
		sess.ExecCommand = &command
	}
	return &sess, nil
}
|
|
|
|
// GetSessionLogs returns every session_logs entry for sessionID, oldest
// first. It returns a nil slice when there are none.
//
// Rows are ordered by timestamp and then by id. AppendSessionLog stamps
// entries with whole-second RFC 3339 precision, so several entries written in
// the same second tie on timestamp, and SQLite may return tied rows in any
// order. The id tie-breaker keeps a transcript in the order it was written.
// This assumes id is a monotonically increasing integer key (schema not
// visible here; GetSessionEvents relies on the same property) — confirm in
// the migrations.
//
// Timestamps that fail to parse are left as the zero time.
func (s *SQLiteStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
	rows, err := s.db.QueryContext(ctx, `
		SELECT id, session_id, timestamp, input, output
		FROM session_logs WHERE session_id = ?
		ORDER BY timestamp, id`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("querying session logs: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var logs []SessionLog
	for rows.Next() {
		var (
			l  SessionLog
			ts string
		)
		if err := rows.Scan(&l.ID, &l.SessionID, &ts, &l.Input, &l.Output); err != nil {
			return nil, fmt.Errorf("scanning session log: %w", err)
		}
		l.Timestamp, _ = time.Parse(time.RFC3339, ts)
		logs = append(logs, l)
	}
	return logs, rows.Err()
}
|
|
|
|
// AppendSessionEvents inserts all events into session_events inside a single
// transaction. Either every event is stored or, on any error, none are.
//
// An empty or nil slice returns nil immediately without touching the
// database. Timestamps are written as UTC RFC 3339 with nanoseconds.
// Transaction errors are wrapped with the same "begin transaction" /
// "commit transaction" prefixes DeleteRecordsBefore uses, so callers see
// uniform context whichever bulk operation failed.
func (s *SQLiteStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
	if len(events) == 0 {
		return nil
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Discards the partial batch on any early return. After a successful
	// Commit it only returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	stmt, err := tx.PrepareContext(ctx, `
		INSERT INTO session_events (session_id, timestamp, direction, data)
		VALUES (?, ?, ?, ?)`)
	if err != nil {
		return fmt.Errorf("preparing statement: %w", err)
	}
	defer func() { _ = stmt.Close() }()

	for _, e := range events {
		ts := e.Timestamp.UTC().Format(time.RFC3339Nano)
		if _, err := stmt.ExecContext(ctx, e.SessionID, ts, e.Direction, e.Data); err != nil {
			return fmt.Errorf("inserting session event: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
|
|
|
|
// GetSessionEvents returns every recorded event for sessionID, ordered by the
// id column. id is presumably the auto-incrementing row key, i.e. insertion
// order (schema not shown here — confirm). It returns a nil slice when the
// session has no events.
//
// Timestamps are parsed as RFC 3339 with optional fractional seconds; a value
// that fails to parse is left as the zero time.
func (s *SQLiteStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
	rows, err := s.db.QueryContext(ctx, `
		SELECT session_id, timestamp, direction, data
		FROM session_events WHERE session_id = ?
		ORDER BY id`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("querying session events: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var out []SessionEvent
	for rows.Next() {
		var (
			ev  SessionEvent
			raw string
		)
		if scanErr := rows.Scan(&ev.SessionID, &raw, &ev.Direction, &ev.Data); scanErr != nil {
			return nil, fmt.Errorf("scanning session event: %w", scanErr)
		}
		ev.Timestamp, _ = time.Parse(time.RFC3339Nano, raw)
		out = append(out, ev)
	}
	return out, rows.Err()
}
|
|
|
|
// DeleteRecordsBefore purges data older than cutoff in a single transaction
// and returns the total number of rows removed across all tables. If any
// statement fails, nothing is deleted.
//
// Sessions are selected by connected_at < cutoff. Their session_events and
// session_logs rows are removed first, then the sessions themselves.
// login_attempts rows are selected independently by last_seen < cutoff.
// Comparisons are on UTC RFC 3339 strings, the same format used when
// connected_at and last_seen are written.
//
// NOTE(review): the sessions filter looks only at connected_at, so a session
// still open (disconnected_at IS NULL) that connected before cutoff is
// deleted too. Confirm that is intended.
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
	cutoffStr := cutoff.UTC().Format(time.RFC3339)

	// Child tables go before sessions. foreign_keys is enabled in the DSN, and
	// these tables presumably reference sessions.id (schema not shown here).
	purges := []struct {
		stmt string
		desc string // prefix for the wrapped error if this step fails
	}{
		{
			stmt: `
				DELETE FROM session_events WHERE session_id IN (
					SELECT id FROM sessions WHERE connected_at < ?
				)`,
			desc: "deleting session events",
		},
		{
			stmt: `
				DELETE FROM session_logs WHERE session_id IN (
					SELECT id FROM sessions WHERE connected_at < ?
				)`,
			desc: "deleting session logs",
		},
		{
			stmt: `DELETE FROM sessions WHERE connected_at < ?`,
			desc: "deleting sessions",
		},
		{
			stmt: `DELETE FROM login_attempts WHERE last_seen < ?`,
			desc: "deleting login attempts",
		},
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("begin transaction: %w", err)
	}
	// Undoes partial deletes on any early return; a no-op after Commit.
	defer func() { _ = tx.Rollback() }()

	var removed int64
	for _, p := range purges {
		res, execErr := tx.ExecContext(ctx, p.stmt, cutoffStr)
		if execErr != nil {
			return 0, fmt.Errorf("%s: %w", p.desc, execErr)
		}
		// The count is informational, so RowsAffected errors are ignored.
		n, _ := res.RowsAffected()
		removed += n
	}

	if err := tx.Commit(); err != nil {
		return 0, fmt.Errorf("commit transaction: %w", err)
	}
	return removed, nil
}
|
|
|
|
// GetDashboardStats gathers the headline counters for the dashboard:
//   - TotalAttempts: sum of login_attempts.count (0 when the table is empty)
//   - UniqueIPs: number of distinct source IPs in login_attempts
//   - TotalSessions: number of sessions rows
//   - ActiveSessions: sessions with no disconnected_at yet
//
// The queries run one after another without a shared transaction, so the
// figures are not guaranteed to come from the same snapshot.
func (s *SQLiteStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
	out := &DashboardStats{}

	steps := []struct {
		query string
		dest  []any
		desc  string // prefix for the wrapped error if this query fails
	}{
		{
			query: `
				SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip)
				FROM login_attempts`,
			dest: []any{&out.TotalAttempts, &out.UniqueIPs},
			desc: "querying attempt stats",
		},
		{
			query: `SELECT COUNT(*) FROM sessions`,
			dest:  []any{&out.TotalSessions},
			desc:  "querying total sessions",
		},
		{
			query: `
				SELECT COUNT(*) FROM sessions WHERE disconnected_at IS NULL`,
			dest: []any{&out.ActiveSessions},
			desc: "querying active sessions",
		},
	}

	for _, st := range steps {
		if err := s.db.QueryRowContext(ctx, st.query).Scan(st.dest...); err != nil {
			return nil, fmt.Errorf("%s: %w", st.desc, err)
		}
	}
	return out, nil
}
|
|
|
|
func (s *SQLiteStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
|
return s.queryTopN(ctx, "username", limit)
|
|
}
|
|
|
|
func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
|
return s.queryTopN(ctx, "password", limit)
|
|
}
|
|
|
|
// GetTopIPs returns up to limit source IPs ranked by total login attempts
// (SUM(count) across every username/password row for that IP), highest
// first. Each entry carries the IP in Value and a country code in Country.
//
// One IP can have several login_attempts rows whose country values differ.
// RecordLoginAttempt overwrites country per row, and empty countries do occur
// (GetTopCountries filters them out). Selecting bare `country` in this
// GROUP BY would let SQLite take it from an arbitrary row. MAX(country) makes
// the choice deterministic and prefers any non-empty code over "".
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
	rows, err := s.db.QueryContext(ctx, `
		SELECT ip, MAX(country), SUM(count) AS total
		FROM login_attempts
		GROUP BY ip
		ORDER BY total DESC
		LIMIT ?`, limit)
	if err != nil {
		return nil, fmt.Errorf("querying top IPs: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var entries []TopEntry
	for rows.Next() {
		var e TopEntry
		if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
			return nil, fmt.Errorf("scanning top IPs: %w", err)
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
|
|
|
|
// GetTopCountries returns up to limit country values ranked by total login
// attempts, highest first. Rows with an empty country are excluded. The
// country is returned in TopEntry.Value; TopEntry.Country stays at its zero
// value.
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
	rows, err := s.db.QueryContext(ctx, `
		SELECT country, SUM(count) AS total
		FROM login_attempts
		WHERE country != ''
		GROUP BY country
		ORDER BY total DESC
		LIMIT ?`, limit)
	if err != nil {
		return nil, fmt.Errorf("querying top countries: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var ranked []TopEntry
	for rows.Next() {
		var entry TopEntry
		if scanErr := rows.Scan(&entry.Value, &entry.Count); scanErr != nil {
			return nil, fmt.Errorf("scanning top countries: %w", scanErr)
		}
		ranked = append(ranked, entry)
	}
	return ranked, rows.Err()
}
|
|
|
|
// queryTopN ranks the distinct values of one login_attempts column by total
// attempt count (SUM(count)), highest first, returning at most limit entries
// with Value and Count populated.
//
// SQL identifiers cannot be bound as parameters, so column is formatted
// straight into the query text. It is therefore checked against a fixed
// allowlist first; any other name returns an "invalid column" error without
// touching the database.
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
	allowed := column == "username" || column == "password" || column == "ip"
	if !allowed {
		return nil, fmt.Errorf("invalid column: %s", column)
	}

	query := fmt.Sprintf(`
		SELECT %[1]s, SUM(count) AS total
		FROM login_attempts
		GROUP BY %[1]s
		ORDER BY total DESC
		LIMIT ?`, column)

	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, fmt.Errorf("querying top %s: %w", column, err)
	}
	defer func() { _ = rows.Close() }()

	var ranked []TopEntry
	for rows.Next() {
		var entry TopEntry
		if scanErr := rows.Scan(&entry.Value, &entry.Count); scanErr != nil {
			return nil, fmt.Errorf("scanning top %s: %w", column, scanErr)
		}
		ranked = append(ranked, entry)
	}
	return ranked, rows.Err()
}
|
|
|
|
// GetRecentSessions returns up to limit sessions, most recent connected_at
// first. When activeOnly is true, only sessions with no disconnected_at value
// are included.
//
// Nullable columns (disconnected_at, human_score, exec_command) become nil
// pointers on Session. Timestamps that fail to parse as RFC 3339 are left as
// the zero time rather than reported as errors. It returns a nil slice when
// nothing matches.
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
	query := `SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command FROM sessions`
	if activeOnly {
		query += ` WHERE disconnected_at IS NULL`
	}
	query += ` ORDER BY connected_at DESC LIMIT ?`

	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, fmt.Errorf("querying recent sessions: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var sessions []Session
	for rows.Next() {
		// The row variable is named sess, not s, so it does not shadow the
		// *SQLiteStore receiver inside the loop.
		var (
			sess           Session
			connectedAt    string
			disconnectedAt sql.NullString
			humanScore     sql.NullFloat64
			execCommand    sql.NullString
		)
		if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
			&connectedAt, &disconnectedAt, &humanScore, &execCommand); err != nil {
			return nil, fmt.Errorf("scanning session: %w", err)
		}

		sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
		if disconnectedAt.Valid {
			t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
			sess.DisconnectedAt = &t
		}
		// humanScore and execCommand are declared per iteration, so these
		// pointers never alias a later row's values.
		if humanScore.Valid {
			sess.HumanScore = &humanScore.Float64
		}
		if execCommand.Valid {
			sess.ExecCommand = &execCommand.String
		}
		sessions = append(sessions, sess)
	}
	return sessions, rows.Err()
}
|
|
|
|
// CloseActiveSessions sets disconnected_at, normalized to UTC RFC 3339, on
// every session that does not have one yet. It returns the number of rows
// updated, as reported by the driver.
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
	stamp := disconnectedAt.UTC().Format(time.RFC3339)

	res, err := s.db.ExecContext(ctx, `
		UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`, stamp)
	if err != nil {
		return 0, fmt.Errorf("closing active sessions: %w", err)
	}
	closed, err := res.RowsAffected()
	return closed, err
}
|
|
|
|
func (s *SQLiteStore) Close() error {
|
|
return s.db.Close()
|
|
}
|