package storage import ( "context" "path/filepath" "testing" "time" ) func newTestStore(t *testing.T) *SQLiteStore { t.Helper() dbPath := filepath.Join(t.TempDir(), "test.db") store, err := NewSQLiteStore(dbPath) if err != nil { t.Fatalf("creating store: %v", err) } t.Cleanup(func() { store.Close() }) return store } func TestRecordLoginAttempt(t *testing.T) { store := newTestStore(t) ctx := context.Background() // First attempt creates a new record. if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil { t.Fatalf("first attempt: %v", err) } // Second attempt with same credentials increments count. if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil { t.Fatalf("second attempt: %v", err) } // Different IP is a separate record. if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil { t.Fatalf("different IP: %v", err) } // Verify counts. var count int err := store.db.QueryRow(`SELECT count FROM login_attempts WHERE username = 'root' AND password = 'toor' AND ip = '10.0.0.1'`).Scan(&count) if err != nil { t.Fatalf("query: %v", err) } if count != 2 { t.Errorf("count = %d, want 2", count) } // Verify total rows. var total int err = store.db.QueryRow(`SELECT COUNT(*) FROM login_attempts`).Scan(&total) if err != nil { t.Fatalf("query total: %v", err) } if total != 2 { t.Errorf("total rows = %d, want 2", total) } } func TestCreateAndEndSession(t *testing.T) { store := newTestStore(t) ctx := context.Background() id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "") if err != nil { t.Fatalf("creating session: %v", err) } if id == "" { t.Fatal("session ID is empty") } // Verify session exists. var username string err = store.db.QueryRow(`SELECT username FROM sessions WHERE id = ?`, id).Scan(&username) if err != nil { t.Fatalf("query session: %v", err) } if username != "root" { t.Errorf("username = %q, want %q", username, "root") } // End session. now := time.Now() if err := store.EndSession(ctx, id, now); err != nil { t.Fatalf("ending session: %v", err) } var disconnectedAt string err = store.db.QueryRow(`SELECT disconnected_at FROM sessions WHERE id = ?`, id).Scan(&disconnectedAt) if err != nil { t.Fatalf("query disconnected_at: %v", err) } if disconnectedAt == "" { t.Error("disconnected_at is empty after EndSession") } } func TestUpdateHumanScore(t *testing.T) { store := newTestStore(t) ctx := context.Background() id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "") if err != nil { t.Fatalf("creating session: %v", err) } if err := store.UpdateHumanScore(ctx, id, 0.85); err != nil { t.Fatalf("updating score: %v", err) } var score float64 err = store.db.QueryRow(`SELECT human_score FROM sessions WHERE id = ?`, id).Scan(&score) if err != nil { t.Fatalf("query score: %v", err) } if score != 0.85 { t.Errorf("score = %f, want 0.85", score) } } func TestAppendSessionLog(t *testing.T) { store := newTestStore(t) ctx := context.Background() id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "") if err != nil { t.Fatalf("creating session: %v", err) } if err := store.AppendSessionLog(ctx, id, "ls -la", ""); err != nil { t.Fatalf("append log: %v", err) } if err := store.AppendSessionLog(ctx, id, "", "total 4\ndrwxr-xr-x"); err != nil { t.Fatalf("append log output: %v", err) } var count int err = store.db.QueryRow(`SELECT COUNT(*) FROM session_logs WHERE session_id = ?`, id).Scan(&count) if err != nil { t.Fatalf("query logs: %v", err) } if count != 2 { t.Errorf("log count = %d, want 2", count) } } func TestDeleteRecordsBefore(t *testing.T) { store := newTestStore(t) ctx := context.Background() // Insert an old login attempt. oldTime := time.Now().AddDate(0, 0, -100).UTC().Format(time.RFC3339) _, err := store.db.Exec(` INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen) VALUES ('old', 'old', '1.1.1.1', 1, ?, ?)`, oldTime, oldTime) if err != nil { t.Fatalf("insert old attempt: %v", err) } // Insert a recent login attempt. if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil { t.Fatalf("insert recent attempt: %v", err) } // Insert an old session with a log entry. _, err = store.db.Exec(` INSERT INTO sessions (id, ip, username, shell_name, connected_at) VALUES ('old-session', '1.1.1.1', 'old', '', ?)`, oldTime) if err != nil { t.Fatalf("insert old session: %v", err) } _, err = store.db.Exec(` INSERT INTO session_logs (session_id, timestamp, input, output) VALUES ('old-session', ?, 'ls', '')`, oldTime) if err != nil { t.Fatalf("insert old log: %v", err) } // Insert a recent session. if _, err := store.CreateSession(ctx, "2.2.2.2", "new", "", ""); err != nil { t.Fatalf("insert recent session: %v", err) } // Delete records older than 30 days. cutoff := time.Now().AddDate(0, 0, -30) deleted, err := store.DeleteRecordsBefore(ctx, cutoff) if err != nil { t.Fatalf("delete: %v", err) } if deleted != 3 { t.Errorf("deleted = %d, want 3 (1 attempt + 1 session + 1 log)", deleted) } // Verify recent records remain. var count int store.db.QueryRow(`SELECT COUNT(*) FROM login_attempts`).Scan(&count) if count != 1 { t.Errorf("remaining attempts = %d, want 1", count) } store.db.QueryRow(`SELECT COUNT(*) FROM sessions`).Scan(&count) if count != 1 { t.Errorf("remaining sessions = %d, want 1", count) } } func TestNewSQLiteStoreCreatesFile(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "test.db") store, err := NewSQLiteStore(dbPath) if err != nil { t.Fatalf("creating store: %v", err) } defer store.Close() // Verify we can use the store. ctx := context.Background() if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1", ""); err != nil { t.Fatalf("recording attempt: %v", err) } }