This repository has been archived on 2026-03-09. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
oubliette/internal/storage/sqlite.go
Torjus Håkestad 0133d956a5 feat: capture SSH exec commands (PLAN.md 4.4)
Bots often send commands via `ssh user@host <command>` (exec request)
rather than requesting an interactive shell. These were previously
rejected silently. Now exec commands are captured, stored on the session
record, and displayed in the web UI session detail page.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:43:11 +01:00

436 lines
12 KiB
Go

package storage
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
_ "modernc.org/sqlite"
)
// SQLiteStore is the SQLite-backed implementation of Store. Obtain one via
// NewSQLiteStore; the zero value carries no database handle and is unusable.
type SQLiteStore struct {
	// db is the connection pool; NewSQLiteStore caps it at one open connection.
	db *sql.DB
}
// NewSQLiteStore opens the SQLite database at dbPath (creating it if it does
// not exist), applies any pending migrations, and returns the ready store.
// If migrations fail, the freshly opened handle is closed before returning.
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
	// Connection pragmas: WAL journaling, foreign-key enforcement, and a
	// 5000 ms busy timeout for lock contention.
	const pragmas = "?_pragma=journal_mode(wal)&_pragma=foreign_keys(on)&_pragma=busy_timeout(5000)"

	db, openErr := sql.Open("sqlite", dbPath+pragmas)
	if openErr != nil {
		return nil, fmt.Errorf("opening database: %w", openErr)
	}
	// Limit the pool to a single connection so all access goes through one
	// SQLite handle.
	db.SetMaxOpenConns(1)

	if migErr := Migrate(db); migErr != nil {
		_ = db.Close()
		return nil, fmt.Errorf("running migrations: %w", migErr)
	}
	store := &SQLiteStore{db: db}
	return store, nil
}
// RecordLoginAttempt upserts one login attempt keyed on (username, password, ip).
// A new combination is inserted with count 1 and first_seen/last_seen set to
// the current UTC time; a repeated combination has its count incremented and
// its last_seen and country overwritten with the latest values.
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
	const upsert = `
INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
VALUES (?, ?, ?, ?, 1, ?, ?)
ON CONFLICT(username, password, ip) DO UPDATE SET
count = count + 1,
last_seen = ?,
country = ?`

	seenAt := time.Now().UTC().Format(time.RFC3339)
	_, execErr := s.db.ExecContext(ctx, upsert,
		// INSERT values
		username, password, ip, country, seenAt, seenAt,
		// DO UPDATE assignments
		seenAt, country,
	)
	if execErr == nil {
		return nil
	}
	return fmt.Errorf("recording login attempt: %w", execErr)
}
// CreateSession inserts a new session row, stamped with the current UTC time
// as connected_at, and returns the freshly generated UUID that identifies it.
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
	sessionID := uuid.New().String()
	startedAt := time.Now().UTC().Format(time.RFC3339)

	const insert = `
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
VALUES (?, ?, ?, ?, ?, ?)`
	if _, err := s.db.ExecContext(ctx, insert, sessionID, ip, username, shellName, country, startedAt); err != nil {
		return "", fmt.Errorf("creating session: %w", err)
	}
	return sessionID, nil
}
// EndSession sets the session's disconnected_at to disconnectedAt (converted
// to UTC, second precision). An ID that matches no row is not reported as an
// error.
func (s *SQLiteStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
	endedAt := disconnectedAt.UTC().Format(time.RFC3339)
	_, err := s.db.ExecContext(ctx, `
UPDATE sessions SET disconnected_at = ? WHERE id = ?`, endedAt, sessionID)
	if err == nil {
		return nil
	}
	return fmt.Errorf("ending session: %w", err)
}
// UpdateHumanScore stores score in the human_score column of the given
// session. An ID that matches no row is not reported as an error.
func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
	const update = `
UPDATE sessions SET human_score = ? WHERE id = ?`
	if _, err := s.db.ExecContext(ctx, update, score, sessionID); err != nil {
		return fmt.Errorf("updating human score: %w", err)
	}
	return nil
}
// SetExecCommand records command in the exec_command column of the given
// session, overwriting any previous value. An ID that matches no row is not
// reported as an error.
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
	const update = `
UPDATE sessions SET exec_command = ? WHERE id = ?`
	_, err := s.db.ExecContext(ctx, update, command, sessionID)
	if err == nil {
		return nil
	}
	return fmt.Errorf("setting exec command: %w", err)
}
// AppendSessionLog stores one input/output exchange for a session, stamped
// with the current UTC time at one-second (RFC 3339) resolution.
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
	const insert = `
INSERT INTO session_logs (session_id, timestamp, input, output)
VALUES (?, ?, ?, ?)`
	loggedAt := time.Now().UTC().Format(time.RFC3339)
	if _, err := s.db.ExecContext(ctx, insert, sessionID, loggedAt, input, output); err != nil {
		return fmt.Errorf("appending session log: %w", err)
	}
	return nil
}
// GetSession loads a single session by ID. It returns (nil, nil) when no such
// session exists. Nullable columns (disconnected_at, human_score,
// exec_command) map to nil pointers when NULL. Timestamps that fail to parse
// as RFC 3339 are left as the zero time.
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
	const q = `
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
FROM sessions WHERE id = ?`

	var out Session
	var connRaw string
	var discRaw sql.NullString
	var score sql.NullFloat64
	var execCmd sql.NullString

	row := s.db.QueryRowContext(ctx, q, sessionID)
	switch err := row.Scan(
		&out.ID, &out.IP, &out.Country, &out.Username, &out.ShellName,
		&connRaw, &discRaw, &score, &execCmd,
	); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("querying session: %w", err)
	}

	out.ConnectedAt, _ = time.Parse(time.RFC3339, connRaw)
	if discRaw.Valid {
		endedAt, _ := time.Parse(time.RFC3339, discRaw.String)
		out.DisconnectedAt = &endedAt
	}
	if score.Valid {
		v := score.Float64
		out.HumanScore = &v
	}
	if execCmd.Valid {
		cmd := execCmd.String
		out.ExecCommand = &cmd
	}
	return &out, nil
}
// GetSessionLogs returns the input/output log entries recorded for a session,
// oldest first. It returns a nil slice and nil error when the session has no
// logs. A timestamp that fails to parse as RFC 3339 is left as the zero time.
func (s *SQLiteStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
	// AppendSessionLog stamps entries at one-second resolution, so several
	// entries can share a timestamp and SQLite would return those ties in an
	// unspecified order. Breaking ties on id, which the database assigns on
	// insert (presumably monotonically), keeps same-second entries in the
	// order they were written.
	rows, err := s.db.QueryContext(ctx, `
SELECT id, session_id, timestamp, input, output
FROM session_logs WHERE session_id = ?
ORDER BY timestamp, id`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("querying session logs: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var logs []SessionLog
	for rows.Next() {
		var l SessionLog
		var ts string
		if err := rows.Scan(&l.ID, &l.SessionID, &ts, &l.Input, &l.Output); err != nil {
			return nil, fmt.Errorf("scanning session log: %w", err)
		}
		// Timestamps are written by this store in RFC 3339; a malformed
		// value degrades to the zero time rather than failing the listing.
		l.Timestamp, _ = time.Parse(time.RFC3339, ts)
		logs = append(logs, l)
	}
	return logs, rows.Err()
}
// AppendSessionEvents inserts a batch of session events in a single
// transaction: either every event is stored or, on any error, none are.
// Timestamps are stored in UTC as RFC 3339 with nanoseconds. An empty batch
// is a no-op and does not touch the database.
func (s *SQLiteStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
	if len(events) == 0 {
		return nil
	}
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}
	// Undoes partial inserts on any early return. After a successful Commit
	// this is a no-op whose error (sql.ErrTxDone) is intentionally discarded.
	defer func() { _ = tx.Rollback() }()

	// Prepare once and reuse the statement for every row in the batch.
	stmt, err := tx.PrepareContext(ctx, `
INSERT INTO session_events (session_id, timestamp, direction, data)
VALUES (?, ?, ?, ?)`)
	if err != nil {
		return fmt.Errorf("preparing statement: %w", err)
	}
	defer func() { _ = stmt.Close() }()

	for _, e := range events {
		ts := e.Timestamp.UTC().Format(time.RFC3339Nano)
		if _, err := stmt.ExecContext(ctx, e.SessionID, ts, e.Direction, e.Data); err != nil {
			return fmt.Errorf("inserting session event: %w", err)
		}
	}

	// Wrap the commit error the same way DeleteRecordsBefore does, so callers
	// can tell which step failed.
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
// GetSessionEvents returns every recorded event for a session, ordered by the
// events' id column. Timestamps are decoded as RFC 3339 with nanoseconds; a
// value that fails to parse is left as the zero time. Returns a nil slice when
// the session has no events.
func (s *SQLiteStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
	const q = `
SELECT session_id, timestamp, direction, data
FROM session_events WHERE session_id = ?
ORDER BY id`
	rows, qErr := s.db.QueryContext(ctx, q, sessionID)
	if qErr != nil {
		return nil, fmt.Errorf("querying session events: %w", qErr)
	}
	defer func() { _ = rows.Close() }()

	var out []SessionEvent
	for rows.Next() {
		var ev SessionEvent
		var rawTS string
		if scanErr := rows.Scan(&ev.SessionID, &rawTS, &ev.Direction, &ev.Data); scanErr != nil {
			return nil, fmt.Errorf("scanning session event: %w", scanErr)
		}
		ev.Timestamp, _ = time.Parse(time.RFC3339Nano, rawTS)
		out = append(out, ev)
	}
	return out, rows.Err()
}
// DeleteRecordsBefore purges data older than cutoff inside one transaction.
// It removes the events and logs of sessions that connected before cutoff,
// those sessions themselves, and login attempts last seen before cutoff.
// It returns the total number of rows removed across all four tables. On any
// error the transaction is rolled back and 0 is returned.
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
	threshold := cutoff.UTC().Format(time.RFC3339)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return 0, fmt.Errorf("begin transaction: %w", err)
	}
	// Rolls back partial deletes on early return; a no-op after Commit.
	defer tx.Rollback()

	// Order matters: the events and logs statements locate their rows via a
	// subquery on sessions, so they must run before sessions are deleted.
	purges := []struct {
		label string
		stmt  string
	}{
		{"session events", `
DELETE FROM session_events WHERE session_id IN (
SELECT id FROM sessions WHERE connected_at < ?
)`},
		{"session logs", `
DELETE FROM session_logs WHERE session_id IN (
SELECT id FROM sessions WHERE connected_at < ?
)`},
		{"sessions", `DELETE FROM sessions WHERE connected_at < ?`},
		{"login attempts", `DELETE FROM login_attempts WHERE last_seen < ?`},
	}

	var removed int64
	for _, p := range purges {
		res, execErr := tx.ExecContext(ctx, p.stmt, threshold)
		if execErr != nil {
			return 0, fmt.Errorf("deleting %s: %w", p.label, execErr)
		}
		// RowsAffected errors are deliberately ignored; the count is
		// informational only.
		n, _ := res.RowsAffected()
		removed += n
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return 0, fmt.Errorf("commit transaction: %w", commitErr)
	}
	return removed, nil
}
// GetDashboardStats gathers the headline counters for the dashboard:
//   - total login attempts (sum of per-row counts, 0 when the table is empty)
//   - distinct attacking IPs
//   - total sessions
//   - sessions still open (no disconnect time recorded)
//
// Each counter comes from its own query; the first failure aborts the call.
func (s *SQLiteStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
	var ds DashboardStats

	attemptsRow := s.db.QueryRowContext(ctx, `
SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip)
FROM login_attempts`)
	if err := attemptsRow.Scan(&ds.TotalAttempts, &ds.UniqueIPs); err != nil {
		return nil, fmt.Errorf("querying attempt stats: %w", err)
	}

	totalRow := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM sessions`)
	if err := totalRow.Scan(&ds.TotalSessions); err != nil {
		return nil, fmt.Errorf("querying total sessions: %w", err)
	}

	activeRow := s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sessions WHERE disconnected_at IS NULL`)
	if err := activeRow.Scan(&ds.ActiveSessions); err != nil {
		return nil, fmt.Errorf("querying active sessions: %w", err)
	}

	return &ds, nil
}
// GetTopUsernames returns up to limit usernames ranked by total login
// attempts, highest first.
func (s *SQLiteStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
	const column = "username"
	return s.queryTopN(ctx, column, limit)
}
// GetTopPasswords returns up to limit passwords ranked by total login
// attempts, highest first.
func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
	const column = "password"
	return s.queryTopN(ctx, column, limit)
}
// GetTopIPs returns up to limit source IPs ranked by total login attempts
// (summed over all username/password pairs from that IP), highest first.
// Each entry's Country is a bare column under GROUP BY ip. If an IP's rows
// disagree on country, SQLite picks the value from an arbitrary row in the group.
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
	const q = `
SELECT ip, country, SUM(count) AS total
FROM login_attempts
GROUP BY ip
ORDER BY total DESC
LIMIT ?`
	rows, qErr := s.db.QueryContext(ctx, q, limit)
	if qErr != nil {
		return nil, fmt.Errorf("querying top IPs: %w", qErr)
	}
	defer func() { _ = rows.Close() }()

	var ranked []TopEntry
	for rows.Next() {
		var entry TopEntry
		if scanErr := rows.Scan(&entry.Value, &entry.Country, &entry.Count); scanErr != nil {
			return nil, fmt.Errorf("scanning top IPs: %w", scanErr)
		}
		ranked = append(ranked, entry)
	}
	return ranked, rows.Err()
}
// GetTopCountries returns up to limit countries ranked by total login
// attempts, highest first. Rows with an empty country are excluded.
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
	const q = `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC
LIMIT ?`
	rows, qErr := s.db.QueryContext(ctx, q, limit)
	if qErr != nil {
		return nil, fmt.Errorf("querying top countries: %w", qErr)
	}
	defer func() { _ = rows.Close() }()

	var ranked []TopEntry
	for rows.Next() {
		var entry TopEntry
		if scanErr := rows.Scan(&entry.Value, &entry.Count); scanErr != nil {
			return nil, fmt.Errorf("scanning top countries: %w", scanErr)
		}
		ranked = append(ranked, entry)
	}
	return ranked, rows.Err()
}
// queryTopN ranks the distinct values of a login_attempts column by their
// summed attempt count, highest first, returning at most limit entries.
// Only "username", "password" and "ip" are accepted; any other column name
// yields an error without touching the database.
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
	// column is spliced into the SQL text (identifiers cannot be bound as
	// parameters), so it must pass this fixed allow-list first.
	if column != "username" && column != "password" && column != "ip" {
		return nil, fmt.Errorf("invalid column: %s", column)
	}

	q := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts
GROUP BY %s
ORDER BY total DESC
LIMIT ?`, column, column)

	rows, qErr := s.db.QueryContext(ctx, q, limit)
	if qErr != nil {
		return nil, fmt.Errorf("querying top %s: %w", column, qErr)
	}
	defer func() { _ = rows.Close() }()

	var ranked []TopEntry
	for rows.Next() {
		var entry TopEntry
		if scanErr := rows.Scan(&entry.Value, &entry.Count); scanErr != nil {
			return nil, fmt.Errorf("scanning top %s: %w", column, scanErr)
		}
		ranked = append(ranked, entry)
	}
	return ranked, rows.Err()
}
// GetRecentSessions returns up to limit sessions, most recently connected
// first. When activeOnly is true, only sessions with no disconnect time are
// included. Nullable columns (disconnected_at, human_score, exec_command) map
// to nil pointers when NULL. Timestamps that fail to parse as RFC 3339 are
// left as the zero time.
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
	query := `SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command FROM sessions`
	if activeOnly {
		query += ` WHERE disconnected_at IS NULL`
	}
	query += ` ORDER BY connected_at DESC LIMIT ?`

	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, fmt.Errorf("querying recent sessions: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var sessions []Session
	for rows.Next() {
		// Named sess (not s) so the *SQLiteStore receiver is not shadowed
		// inside the loop.
		var sess Session
		var connectedAt string
		var disconnectedAt sql.NullString
		var humanScore sql.NullFloat64
		var execCommand sql.NullString
		if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
			&connectedAt, &disconnectedAt, &humanScore, &execCommand); err != nil {
			return nil, fmt.Errorf("scanning session: %w", err)
		}
		sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
		if disconnectedAt.Valid {
			t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
			sess.DisconnectedAt = &t
		}
		// The Null* holders are declared per iteration, so each pointer below
		// refers to distinct storage.
		if humanScore.Valid {
			sess.HumanScore = &humanScore.Float64
		}
		if execCommand.Valid {
			sess.ExecCommand = &execCommand.String
		}
		sessions = append(sessions, sess)
	}
	return sessions, rows.Err()
}
// CloseActiveSessions sets disconnected_at to disconnectedAt (in UTC, second
// precision) on every session that has no disconnect time yet. It reports how
// many sessions were closed, as returned by RowsAffected.
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
	stamp := disconnectedAt.UTC().Format(time.RFC3339)
	result, execErr := s.db.ExecContext(ctx, `
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`, stamp)
	if execErr != nil {
		return 0, fmt.Errorf("closing active sessions: %w", execErr)
	}
	return result.RowsAffected()
}
// Close releases the underlying database handle. Queries issued through the
// store afterwards will fail.
func (s *SQLiteStore) Close() error {
	closeErr := s.db.Close()
	return closeErr
}