Compare commits

..

17 Commits

Author SHA1 Message Date
1b28f10ca8 refactor: migrate module path from git.t-juice.club to code.t-juice.club
Update Go module path and all import references to reflect the migration
from Gitea (git.t-juice.club) to Forgejo (code.t-juice.club).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:51:23 +01:00
664e79fce6 feat: add Prometheus metrics for Store queries
Add InstrumentedStore decorator that wraps any Store and records
per-method query duration histograms and error counters. Wired into
main.go so all storage consumers get automatic observability.

Bump version to 0.18.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:29:51 +01:00
c74313c195 fix: resolve linting issues in roomba shell
Replace unnecessary fmt.Sprintf calls with string literals, use
slices.Contains instead of manual loop, and use compound assignment
operator.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:18:25 +01:00
9783ae5865 fix: prevent context canceled errors in web dashboard
Detach DB queries from HTTP request context so HTMX polling doesn't
cancel in-flight queries when the browser aborts previous XHRs. Add
indexes on login_attempts and sessions to speed up frequent dashboard
queries. Bump version to 0.17.1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:16:49 +01:00
62de222488 feat: add tetris shell (Tetris game TUI)
Full-screen Tetris game using Bubbletea with title screen, ghost piece,
lock delay, NES-style scoring, configurable difficulty (easy/normal/hard),
and honeypot event logging. Bumps version to 0.17.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 00:59:46 +01:00
c9d143d84b feat: add roomba shell (iRobot Roomba j7+ vacuum emulator)
New novelty shell emulating RoombaOS with cleaning, scheduling,
diagnostics, floor map, and humorous history entries. Bump version
to 0.16.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 14:06:59 +01:00
d18a904ed5 chore: bump version to 0.15.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:13:50 +01:00
cb7be28f42 feat: add server-side session filtering with input bytes and human score
Replace client-side session table filtering with server-side filtering
via a new /fragments/recent-sessions htmx endpoint. Add InputBytes column
to session tables, Human score > 0 checkbox filter, and Sort by Input
Bytes option to help identify sessions with actual shell interaction.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:12:51 +01:00
0908b43724 chore: bump version to 0.14.2
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
52310f588d fix: highlight all polygons on hover for multi-path countries
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
b52216bd2f chore: bump version to 0.14.1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:21:01 +01:00
2bc83a17dd fix: handle SVG group elements in world map for multi-path countries
The SVG world map uses <g> group elements for countries with complex
shapes (US, CN, RU, GB, etc.), but the JS only queried <path> elements,
causing 36 countries to be missing from the map. Also removes the SVG
<title> element that was overriding the custom tooltip.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:20:23 +01:00
faf6e2abd7 docs: mark 4.1 and 4.4 as completed in PLAN.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:34:17 +01:00
0a4eac188a chore: bump version to 0.14.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:31:53 +01:00
7c90c9ed4a feat: add charts, world map, and filters to web dashboard
Add Chart.js line/bar charts for attack trends (attempts over time,
hourly pattern), an SVG world map choropleth colored by attack origin
country, and a collapsible filter form (date range, IP, country,
username) that narrows both charts and top-N tables.

New store methods: GetAttemptsOverTime, GetHourlyPattern, GetCountryStats,
and filtered variants of dashboard stats/top-N queries. New JSON API
endpoints at /api/charts/* and an htmx fragment at
/fragments/dashboard-content for filtered table updates.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:27:15 +01:00
8a631af0d2 fix: prevent dashboard top-grid cards from overflowing horizontally
Increase minimum column width from 280px to 380px so the 3-column Top
IPs table fits without clipping. Add overflow/min-width safety net for
narrow viewports.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 21:25:20 +01:00
40fda3420c feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash
meta-commands, SQL statement handling with multi-line buffering, and
canned responses for common queries. Add username-based shell routing
via [shell.username_routes] config (second priority after credential-
specific shell, before random selection). Bump version to 0.13.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:58:34 +01:00
58 changed files with 5368 additions and 143 deletions

37
PLAN.md
View File

@@ -171,7 +171,20 @@ Goal: Add the entertaining shell implementations.
### 3.5 Banking TUI Shell ✅
- 80s-style green-on-black bank terminal
### 3.6 Other Shell Ideas (Future)
### 3.6 PostgreSQL psql Shell ✅
- Simulates psql interactive terminal with `db_name` and `pg_version` config
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
- SQL statement handling with multi-line buffering (semicolon-terminated)
- Canned responses for common queries (SELECT version(), current_database(), etc.)
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
### 3.7 Roomba Shell ✅
- iRobot Roomba j7+ vacuum robot interface
- Status, cleaning, scheduling, diagnostics, floor map
- Humorous history entries (cat encounters, sock tangles, sticky substances)
### 3.8 Other Shell Ideas (Future)
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
- **ELIZA therapist:** every response is a therapy question
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
@@ -183,11 +196,11 @@ Goal: Add the entertaining shell implementations.
Goal: Make the web UI great and add operational niceties.
### 4.1 Enhanced Web UI
- GeoIP lookups and world map visualization of attack sources
- Charts: attempts over time, hourly patterns, credential trends
- Session detail view with full command log
- Filtering and search
### 4.1 Enhanced Web UI
- GeoIP lookups and world map visualization of attack sources
- Charts: attempts over time, hourly patterns, credential trends
- Session detail view with full command log
- Filtering and search
### 4.2 Operational ✅
- Prometheus metrics endpoint ✅
@@ -201,15 +214,15 @@ Goal: Make the web UI great and add operational niceties.
- Store country/city with each attempt ✅
- Aggregate stats by country ✅
### 4.4 Capture SSH Exec Commands
### 4.4 Capture SSH Exec Commands
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`)
- Parse the command string from the exec payload
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration
- Store the command on the session record before closing the channel
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`)
- Parse the command string from the exec payload
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration
- Store the command on the session record before closing the channel
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
- Surface exec commands in the web UI (session detail view)
- Surface exec commands in the web UI (session detail view)
#### 4.4.1 Fake Exec Output
Return plausible fake output for exec commands to encourage bots to interact further.

View File

@@ -34,7 +34,8 @@ Key settings:
- `auth.accept_after` — accept login after N failures per IP (default `10`)
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation)
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
- `storage.db_path` — SQLite database path (default `oubliette.db`)
- `storage.retention_days` — auto-prune records older than N days (default `90`)
- `storage.retention_interval` — how often to run retention (default `1h`)
@@ -43,6 +44,7 @@ Key settings:
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
- `web.enabled` — enable the web dashboard (default `false`)
- `web.listen_addr` — web dashboard listen address (default `:8080`)
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)

View File

@@ -13,14 +13,14 @@ import (
"syscall"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/server"
"git.t-juice.club/torjus/oubliette/internal/storage"
"git.t-juice.club/torjus/oubliette/internal/web"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/server"
"code.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/web"
)
const Version = "0.12.0"
const Version = "0.18.0"
func main() {
if err := run(); err != nil {
@@ -76,12 +76,13 @@ func run() error {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
m := metrics.New(Version)
m.RegisterStoreCollector(store)
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
m.RegisterStoreCollector(instrumentedStore)
srv, err := server.New(*cfg, store, logger, m)
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
srv, err := server.New(*cfg, instrumentedStore, logger, m)
if err != nil {
return fmt.Errorf("create server: %w", err)
}
@@ -95,7 +96,7 @@ func run() error {
metricsHandler = m.Handler()
}
webHandler, err := web.NewServer(store, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
if err != nil {
return fmt.Errorf("create web server: %w", err)
}

4
go.mod
View File

@@ -1,4 +1,4 @@
module git.t-juice.club/torjus/oubliette
module code.t-juice.club/torjus/oubliette
go 1.25.5
@@ -9,6 +9,7 @@ require (
github.com/google/uuid v1.6.0
github.com/oschwald/maxminddb-golang v1.13.1
github.com/prometheus/client_golang v1.23.2
github.com/prometheus/client_model v0.6.2
golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.45.0
)
@@ -33,7 +34,6 @@ require (
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect

View File

@@ -5,7 +5,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
const (

View File

@@ -6,7 +6,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {

View File

@@ -28,10 +28,11 @@ type WebConfig struct {
}
type ShellConfig struct {
Hostname string `toml:"hostname"`
Banner string `toml:"banner"`
FakeUser string `toml:"fake_user"`
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
Hostname string `toml:"hostname"`
Banner string `toml:"banner"`
FakeUser string `toml:"fake_user"`
UsernameRoutes map[string]string `toml:"username_routes"`
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
}
type StorageConfig struct {
@@ -165,9 +166,10 @@ func applyDefaults(cfg *Config) {
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
var knownShellKeys = map[string]bool{
"hostname": true,
"banner": true,
"fake_user": true,
"hostname": true,
"banner": true,
"fake_user": true,
"username_routes": true,
}
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.

View File

@@ -313,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
}
}
func TestLoadUsernameRoutes(t *testing.T) {
content := `
[shell]
hostname = "myhost"
[shell.username_routes]
postgres = "psql"
admin = "bash"
[shell.bash]
custom_key = "value"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.UsernameRoutes == nil {
t.Fatal("UsernameRoutes should not be nil")
}
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
}
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
}
// username_routes should NOT appear in the Shells map.
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
t.Error("username_routes should not appear in Shells map")
}
// bash should still appear in Shells map.
if _, ok := cfg.Shell.Shells["bash"]; !ok {
t.Error("Shells[\"bash\"] should still be present")
}
}
func writeTemp(t *testing.T, content string) string {
t.Helper()
path := filepath.Join(t.TempDir(), "config.toml")

View File

@@ -4,7 +4,7 @@ import (
"context"
"net/http"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -25,6 +25,8 @@ type Metrics struct {
SessionDuration prometheus.Histogram
ExecCommandsTotal prometheus.Counter
BuildInfo *prometheus.GaugeVec
StorageQueryDuration *prometheus.HistogramVec
StorageQueryErrors *prometheus.CounterVec
}
// New creates a new Metrics instance with all collectors registered.
@@ -79,6 +81,15 @@ func New(version string) *Metrics {
Name: "oubliette_build_info",
Help: "Build information. Always 1.",
}, []string{"version"}),
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "oubliette_storage_query_duration_seconds",
Help: "Duration of storage query calls in seconds.",
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"method"}),
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_storage_query_errors_total",
Help: "Total storage query errors.",
}, []string{"method"}),
}
reg.MustRegister(
@@ -95,6 +106,8 @@ func New(version string) *Metrics {
m.SessionDuration,
m.ExecCommandsTotal,
m.BuildInfo,
m.StorageQueryDuration,
m.StorageQueryErrors,
)
m.BuildInfo.WithLabelValues(version).Set(1)

View File

@@ -8,7 +8,7 @@ import (
"strings"
"testing"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestNew(t *testing.T) {

View File

@@ -10,7 +10,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
// Event types.

View File

@@ -10,7 +10,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
func testSession() SessionInfo {

View File

@@ -12,19 +12,22 @@ import (
"os"
"time"
"git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/detection"
"git.t-juice.club/torjus/oubliette/internal/geoip"
"git.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/notify"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/shell/adventure"
"git.t-juice.club/torjus/oubliette/internal/shell/banking"
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/auth"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/detection"
"code.t-juice.club/torjus/oubliette/internal/geoip"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/notify"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
@@ -58,6 +61,15 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
return nil, fmt.Errorf("registering cisco shell: %w", err)
}
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
return nil, fmt.Errorf("registering psql shell: %w", err)
}
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
return nil, fmt.Errorf("registering roomba shell: %w", err)
}
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
return nil, fmt.Errorf("registering tetris shell: %w", err)
}
geo, err := geoip.New()
if err != nil {
@@ -185,6 +197,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
}
}
// Second priority: username-based route.
if selectedShell == nil {
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
sh, found := s.shellRegistry.Get(shellName)
if found {
selectedShell = sh
} else {
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
}
}
}
// Lowest priority: random selection.
if selectedShell == nil {
var err error
selectedShell, err = s.shellRegistry.Select()

View File

@@ -11,9 +11,10 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/auth"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
@@ -300,6 +301,89 @@ func TestIntegrationSSHConnect(t *testing.T) {
}
})
// Test username route: add username_routes so that "postgres" gets psql shell.
t.Run("username_route", func(t *testing.T) {
// Reconfigure with username routes.
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
// Need to get the "postgres" user in via static creds or threshold.
// Use static creds for simplicity.
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
config.Credential{Username: "postgres", Password: "postgres"},
)
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
defer func() {
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
}()
clientCfg := &ssh.ClientConfig{
User: "postgres",
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
client, err := ssh.Dial("tcp", addr, clientCfg)
if err != nil {
t.Fatalf("SSH dial: %v", err)
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Fatalf("new session: %v", err)
}
defer session.Close()
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
t.Fatalf("request pty: %v", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("stdin pipe: %v", err)
}
var output bytes.Buffer
session.Stdout = &output
if err := session.Shell(); err != nil {
t.Fatalf("shell: %v", err)
}
// Wait for the psql banner.
time.Sleep(500 * time.Millisecond)
// Send \q to quit.
stdin.Write([]byte(`\q` + "\r"))
time.Sleep(200 * time.Millisecond)
session.Wait()
out := output.String()
if !strings.Contains(out, "psql") {
t.Errorf("output should contain psql banner, got: %s", out)
}
// Verify session was created with shell name "psql".
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
var foundPsql bool
for _, s := range sessions {
if s.ShellName == "psql" && s.Username == "postgres" {
foundPsql = true
break
}
}
if !foundPsql {
t.Error("expected a session with shell_name='psql' for user 'postgres'")
}
})
// Test threshold acceptance: after enough failed dials, a subsequent
// dial with the same credentials should succeed via threshold or
// remembered credential.

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute

View File

@@ -8,8 +8,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {

View File

@@ -9,7 +9,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute

View File

@@ -9,8 +9,8 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.

View File

@@ -8,7 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute

View File

@@ -9,8 +9,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute

View File

@@ -6,7 +6,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// EventRecorder buffers I/O events in memory and periodically flushes them to

View File

@@ -6,7 +6,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestEventRecorderFlush(t *testing.T) {

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute

View File

@@ -8,8 +8,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {

View File

@@ -0,0 +1,123 @@
package psql
import (
"fmt"
"strings"
"time"
)
// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
output string
exit bool
}
// dispatchBackslash handles psql backslash meta-commands.
func dispatchBackslash(cmd, dbName string) commandResult {
// Normalize: trim spaces after the backslash command word.
parts := strings.Fields(cmd)
if len(parts) == 0 {
return commandResult{output: "Invalid command \\. Try \\? for help."}
}
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
args := parts[1:]
switch verb {
case `\q`:
return commandResult{exit: true}
case `\dt`:
return commandResult{output: listTables()}
case `\d`:
if len(args) == 0 {
return commandResult{output: listTables()}
}
return commandResult{output: describeTable(args[0])}
case `\l`:
return commandResult{output: listDatabases()}
case `\du`:
return commandResult{output: listRoles()}
case `\conninfo`:
return commandResult{output: connInfo(dbName)}
case `\?`:
return commandResult{output: backslashHelp()}
case `\h`:
return commandResult{output: sqlHelp()}
default:
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
}
}
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
// Strip trailing semicolon and whitespace for matching.
trimmed := strings.TrimRight(sql, "; \t")
trimmed = strings.TrimSpace(trimmed)
upper := strings.ToUpper(trimmed)
switch {
case upper == "SELECT VERSION()":
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
return commandResult{output: formatSingleValue("version", ver)}
case upper == "SELECT CURRENT_DATABASE()":
return commandResult{output: formatSingleValue("current_database", dbName)}
case upper == "SELECT CURRENT_USER":
return commandResult{output: formatSingleValue("current_user", "postgres")}
case upper == "SELECT NOW()":
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
return commandResult{output: formatSingleValue("now", now)}
case upper == "SELECT 1":
return commandResult{output: formatSingleValue("?column?", "1")}
case strings.HasPrefix(upper, "INSERT"):
return commandResult{output: "INSERT 0 1"}
case strings.HasPrefix(upper, "UPDATE"):
return commandResult{output: "UPDATE 1"}
case strings.HasPrefix(upper, "DELETE"):
return commandResult{output: "DELETE 1"}
case strings.HasPrefix(upper, "CREATE TABLE"):
return commandResult{output: "CREATE TABLE"}
case strings.HasPrefix(upper, "CREATE DATABASE"):
return commandResult{output: "CREATE DATABASE"}
case strings.HasPrefix(upper, "DROP TABLE"):
return commandResult{output: "DROP TABLE"}
case strings.HasPrefix(upper, "ALTER TABLE"):
return commandResult{output: "ALTER TABLE"}
case upper == "BEGIN":
return commandResult{output: "BEGIN"}
case upper == "COMMIT":
return commandResult{output: "COMMIT"}
case upper == "ROLLBACK":
return commandResult{output: "ROLLBACK"}
case upper == "SHOW SERVER_VERSION":
return commandResult{output: formatSingleValue("server_version", pgVersion)}
case upper == "SHOW SEARCH_PATH":
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
case strings.HasPrefix(upper, "SET "):
return commandResult{output: "SET"}
default:
// Extract the first token for the error message.
firstToken := strings.Fields(trimmed)
token := trimmed
if len(firstToken) > 0 {
token = firstToken[0]
}
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
}
}
// formatSingleValue formats a single-row, single-column psql result.
func formatSingleValue(colName, value string) string {
width := max(len(colName), len(value))
var b strings.Builder
// Header
fmt.Fprintf(&b, " %-*s \n", width, colName)
// Separator
b.WriteString(strings.Repeat("-", width+2))
b.WriteString("\n")
// Value
fmt.Fprintf(&b, " %-*s\n", width, value)
// Row count
b.WriteString("(1 row)")
return b.String()
}

View File

@@ -0,0 +1,155 @@
package psql
import "fmt"
func startupBanner(version string) string {
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
}
func listTables() string {
return ` List of relations
Schema | Name | Type | Owner
--------+---------------+-------+----------
public | audit_log | table | postgres
public | credentials | table | postgres
public | sessions | table | postgres
public | users | table | postgres
(4 rows)`
}
func listDatabases() string {
return ` List of databases
Name | Owner | Encoding | Collate | Ctype | Access privileges
-----------+----------+----------+-------------+-------------+-----------------------
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
(4 rows)`
}
func listRoles() string {
return ` List of roles
Role name | Attributes | Member of
-----------+------------------------------------------------------------+-----------
app_user | | {}
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
readonly | Cannot login | {}`
}
func describeTable(name string) string {
switch name {
case "users":
return ` Table "public.users"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+-----------------------------------
id | integer | | not null | nextval('users_id_seq'::regclass)
username | character varying(255) | | not null |
email | character varying(255) | | not null |
password | character varying(255) | | not null |
created_at | timestamp without time zone | | | now()
updated_at | timestamp without time zone | | | now()
Indexes:
"users_pkey" PRIMARY KEY, btree (id)
"users_email_key" UNIQUE, btree (email)
"users_username_key" UNIQUE, btree (username)`
case "sessions":
return ` Table "public.sessions"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+--------------------------------------
id | integer | | not null | nextval('sessions_id_seq'::regclass)
user_id | integer | | |
token | character varying(255) | | not null |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
expires_at | timestamp without time zone | | not null |
Indexes:
"sessions_pkey" PRIMARY KEY, btree (id)
"sessions_token_key" UNIQUE, btree (token)
Foreign-key constraints:
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "credentials":
return ` Table "public.credentials"
Column | Type | Collation | Nullable | Default
-----------+-----------------------------+-----------+----------+-----------------------------------------
id | integer | | not null | nextval('credentials_id_seq'::regclass)
user_id | integer | | |
type | character varying(50) | | not null |
value | text | | not null |
created_at| timestamp without time zone | | | now()
Indexes:
"credentials_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "audit_log":
return ` Table "public.audit_log"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+---------------------------------------
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
user_id | integer | | |
action | character varying(100) | | not null |
details | text | | |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
Indexes:
"audit_log_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
default:
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
}
}
func connInfo(dbName string) string {
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
}
func backslashHelp() string {
return `General
\copyright show PostgreSQL usage and distribution terms
\crosstabview [COLUMNS] execute query and display result in crosstab
\errverbose show most recent error message at maximum verbosity
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
\gdesc describe result of query, without executing it
\gexec execute query, then execute each value in its result
\gset [PREFIX] execute query and store result in psql variables
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
\q quit psql
\watch [SEC] execute query every SEC seconds
Informational
(options: S = show system objects, + = additional detail)
\d[S+] list tables, views, and sequences
\d[S+] NAME describe table, view, sequence, or index
\da[S] [PATTERN] list aggregates
\dA[+] [PATTERN] list access methods
\dt[S+] [PATTERN] list tables
\du[S+] [PATTERN] list roles
\l[+] [PATTERN] list databases`
}
func sqlHelp() string {
return `Available help:
ABORT CREATE LANGUAGE
ALTER AGGREGATE CREATE MATERIALIZED VIEW
ALTER COLLATION CREATE OPERATOR
ALTER CONVERSION CREATE POLICY
ALTER DATABASE CREATE PROCEDURE
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
ALTER DOMAIN CREATE ROLE
ALTER EVENT TRIGGER CREATE RULE
ALTER EXTENSION CREATE SCHEMA
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
ALTER FOREIGN TABLE CREATE SERVER
ALTER FUNCTION CREATE STATISTICS
ALTER GROUP CREATE SUBSCRIPTION
ALTER INDEX CREATE TABLE
ALTER LANGUAGE CREATE TABLESPACE
BEGIN DELETE
COMMIT DROP TABLE
CREATE DATABASE INSERT
CREATE INDEX ROLLBACK
SELECT UPDATE`
}

137
internal/shell/psql/psql.go Normal file
View File

@@ -0,0 +1,137 @@
package psql
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// PsqlShell emulates a PostgreSQL psql interactive terminal.
type PsqlShell struct{}
// NewPsqlShell returns a new PsqlShell instance.
func NewPsqlShell() *PsqlShell {
return &PsqlShell{}
}
func (p *PsqlShell) Name() string { return "psql" }
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
dbName := configString(sess.ShellConfig, "db_name", "postgres")
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
// Print startup banner.
fmt.Fprint(rw, startupBanner(pgVersion))
var sqlBuf []string // accumulates multi-line SQL
for {
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
// Empty line in non-buffering state: just re-prompt.
if trimmed == "" && len(sqlBuf) == 0 {
continue
}
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
if strings.HasPrefix(trimmed, `\`) {
sqlBuf = nil // discard any partial SQL
result := dispatchBackslash(trimmed, dbName)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
continue
}
// Accumulate SQL lines.
sqlBuf = append(sqlBuf, line)
// Check if the statement is terminated by a semicolon.
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
continue
}
// Full statement ready — join and dispatch.
fullSQL := strings.Join(sqlBuf, " ")
sqlBuf = nil
result := dispatchSQL(fullSQL, dbName, pgVersion)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
}
}
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
func buildPrompt(dbName string, continuation bool) string {
if continuation {
return dbName + "-# "
}
return dbName + "=# "
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,330 @@
package psql
import (
"strings"
"testing"
)
// --- Prompt tests ---
func TestBuildPromptNormal(t *testing.T) {
got := buildPrompt("postgres", false)
if got != "postgres=# " {
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
}
}
func TestBuildPromptContinuation(t *testing.T) {
got := buildPrompt("postgres", true)
if got != "postgres-# " {
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
}
}
func TestBuildPromptCustomDB(t *testing.T) {
got := buildPrompt("mydb", false)
if got != "mydb=# " {
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
}
}
// --- Backslash command dispatch tests ---
func TestBackslashQuit(t *testing.T) {
result := dispatchBackslash(`\q`, "postgres")
if !result.exit {
t.Error("\\q should set exit=true")
}
}
func TestBackslashListTables(t *testing.T) {
result := dispatchBackslash(`\dt`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\dt should list tables including 'users'")
}
if !strings.Contains(result.output, "sessions") {
t.Error("\\dt should list tables including 'sessions'")
}
}
func TestBackslashDescribeTable(t *testing.T) {
result := dispatchBackslash(`\d users`, "postgres")
if !strings.Contains(result.output, "username") {
t.Error("\\d users should describe users table with 'username' column")
}
if !strings.Contains(result.output, "PRIMARY KEY") {
t.Error("\\d users should include index info")
}
}
func TestBackslashDescribeUnknownTable(t *testing.T) {
result := dispatchBackslash(`\d nonexistent`, "postgres")
if !strings.Contains(result.output, "Did not find") {
t.Error("\\d nonexistent should return not found message")
}
}
func TestBackslashListDatabases(t *testing.T) {
result := dispatchBackslash(`\l`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\l should list databases including 'postgres'")
}
if !strings.Contains(result.output, "template0") {
t.Error("\\l should list databases including 'template0'")
}
}
func TestBackslashListRoles(t *testing.T) {
result := dispatchBackslash(`\du`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\du should list roles including 'postgres'")
}
if !strings.Contains(result.output, "Superuser") {
t.Error("\\du should show Superuser attribute for postgres")
}
}
func TestBackslashConnInfo(t *testing.T) {
result := dispatchBackslash(`\conninfo`, "mydb")
if !strings.Contains(result.output, "mydb") {
t.Error("\\conninfo should include database name")
}
if !strings.Contains(result.output, "5432") {
t.Error("\\conninfo should include port")
}
}
func TestBackslashHelp(t *testing.T) {
result := dispatchBackslash(`\?`, "postgres")
if !strings.Contains(result.output, `\q`) {
t.Error("\\? should include \\q in help output")
}
}
func TestBackslashSQLHelp(t *testing.T) {
result := dispatchBackslash(`\h`, "postgres")
if !strings.Contains(result.output, "SELECT") {
t.Error("\\h should include SQL commands like SELECT")
}
}
func TestBackslashUnknown(t *testing.T) {
result := dispatchBackslash(`\xyz`, "postgres")
if !strings.Contains(result.output, "Invalid command") {
t.Error("unknown backslash command should return error")
}
}
// --- SQL dispatch tests ---
func TestSQLSelectVersion(t *testing.T) {
result := dispatchSQL("SELECT version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SELECT version() should contain pg version")
}
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT version() should show row count")
}
}
func TestSQLSelectCurrentDatabase(t *testing.T) {
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("SELECT current_database() should return db name")
}
}
func TestSQLSelectCurrentUser(t *testing.T) {
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
if !strings.Contains(result.output, "postgres") {
t.Error("SELECT current_user should return postgres")
}
}
func TestSQLSelectNow(t *testing.T) {
result := dispatchSQL("SELECT now();", "postgres", "15.4")
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT now() should show row count")
}
}
func TestSQLSelectOne(t *testing.T) {
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
if !strings.Contains(result.output, "1") {
t.Error("SELECT 1 should return 1")
}
}
func TestSQLInsert(t *testing.T) {
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
if result.output != "INSERT 0 1" {
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
}
}
func TestSQLUpdate(t *testing.T) {
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
if result.output != "UPDATE 1" {
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
}
}
func TestSQLDelete(t *testing.T) {
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
if result.output != "DELETE 1" {
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
}
}
func TestSQLCreateTable(t *testing.T) {
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
if result.output != "CREATE TABLE" {
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
}
}
func TestSQLCreateDatabase(t *testing.T) {
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
if result.output != "CREATE DATABASE" {
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
}
}
func TestSQLDropTable(t *testing.T) {
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
if result.output != "DROP TABLE" {
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
}
}
func TestSQLAlterTable(t *testing.T) {
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
if result.output != "ALTER TABLE" {
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
}
}
func TestSQLBeginCommitRollback(t *testing.T) {
tests := []struct {
sql string
want string
}{
{"BEGIN;", "BEGIN"},
{"COMMIT;", "COMMIT"},
{"ROLLBACK;", "ROLLBACK"},
}
for _, tt := range tests {
result := dispatchSQL(tt.sql, "postgres", "15.4")
if result.output != tt.want {
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
}
}
}
func TestSQLShowServerVersion(t *testing.T) {
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SHOW server_version should contain version")
}
}
func TestSQLShowSearchPath(t *testing.T) {
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
if !strings.Contains(result.output, "public") {
t.Error("SHOW search_path should contain public")
}
}
func TestSQLSet(t *testing.T) {
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
if result.output != "SET" {
t.Errorf("SET output = %q, want %q", result.output, "SET")
}
}
func TestSQLUnrecognized(t *testing.T) {
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
if !strings.Contains(result.output, "ERROR") {
t.Error("unrecognized SQL should return error")
}
if !strings.Contains(result.output, "FOOBAR") {
t.Error("error should reference the offending token")
}
}
// --- Case insensitivity ---
func TestSQLCaseInsensitive(t *testing.T) {
result := dispatchSQL("select version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("select version() (lowercase) should work")
}
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("mixed case SELECT should work")
}
}
// --- Startup banner ---
func TestStartupBanner(t *testing.T) {
banner := startupBanner("15.4")
if !strings.Contains(banner, "psql (15.4)") {
t.Errorf("banner should contain version, got: %s", banner)
}
if !strings.Contains(banner, "help") {
t.Error("banner should mention help")
}
}
// --- configString ---
func TestConfigString(t *testing.T) {
cfg := map[string]any{"db_name": "mydb"}
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
t.Errorf("configString() = %q, want %q", got, "mydb")
}
if got := configString(cfg, "missing", "default"); got != "default" {
t.Errorf("configString() for missing = %q, want %q", got, "default")
}
if got := configString(nil, "key", "default"); got != "default" {
t.Errorf("configString(nil) = %q, want %q", got, "default")
}
}
// --- Shell metadata ---
func TestShellNameAndDescription(t *testing.T) {
s := NewPsqlShell()
if s.Name() != "psql" {
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
}
if s.Description() == "" {
t.Error("Description() should not be empty")
}
}
// --- formatSingleValue ---
func TestFormatSingleValue(t *testing.T) {
out := formatSingleValue("?column?", "1")
if !strings.Contains(out, "?column?") {
t.Error("should contain column name")
}
if !strings.Contains(out, "1") {
t.Error("should contain value")
}
if !strings.Contains(out, "(1 row)") {
t.Error("should contain row count")
}
}
// --- \d with no args ---
func TestBackslashDescribeNoArgs(t *testing.T) {
result := dispatchBackslash(`\d`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\d with no args should list tables")
}
}

View File

@@ -0,0 +1,463 @@
package roomba
import (
"context"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
type RoombaShell struct{}
// NewRoombaShell returns a new RoombaShell instance.
func NewRoombaShell() *RoombaShell {
return &RoombaShell{}
}
func (r *RoombaShell) Name() string { return "roomba" }
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
state := newRoombaState()
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
fmt.Fprint(rw, banner)
for {
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "logout\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("roomba")
}
if result.exit {
return nil
}
}
}
func bootBanner() string {
return `
____ _ ___ ____
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
iRobot Roomba j7+ | RoombaOS v4.3.7
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
Battery: 73% | WiFi: Connected (SmartHome-5G)
Type 'help' for available commands.
`
}
type room struct {
name string
areaSqFt int
lastCleaned time.Time
}
type scheduleEntry struct {
day string
time string
}
type historyEntry struct {
timestamp time.Time
room string
duration string
note string
}
type roombaState struct {
battery int
dustbin int
status string
rooms []room
schedule []scheduleEntry
cleanHistory []historyEntry
}
type commandResult struct {
output string
exit bool
}
func newRoombaState() *roombaState {
now := time.Now()
return &roombaState{
battery: 73,
dustbin: 61,
status: "Docked",
rooms: []room{
{"Kitchen", 180, now.Add(-2 * time.Hour)},
{"Living Room", 320, now.Add(-5 * time.Hour)},
{"Bedroom", 200, now.Add(-26 * time.Hour)},
{"Hallway", 60, now.Add(-5 * time.Hour)},
{"Bathroom", 75, now.Add(-50 * time.Hour)},
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
},
schedule: []scheduleEntry{
{"Monday", "09:00"},
{"Wednesday", "09:00"},
{"Friday", "09:00"},
{"Saturday", "14:00"},
},
cleanHistory: []historyEntry{
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
},
}
}
func (s *roombaState) dispatch(input string) commandResult {
parts := strings.Fields(input)
if len(parts) == 0 {
return commandResult{}
}
cmd := strings.ToLower(parts[0])
args := parts[1:]
switch cmd {
case "help":
return s.cmdHelp()
case "status":
return s.cmdStatus()
case "clean":
return s.cmdClean(args)
case "dock":
return s.cmdDock()
case "map":
return s.cmdMap()
case "schedule":
return s.cmdSchedule(args)
case "history":
return s.cmdHistory()
case "diagnostics":
return s.cmdDiagnostics()
case "alerts":
return s.cmdAlerts()
case "reboot":
return s.cmdReboot()
case "exit", "logout":
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
default:
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
}
}
func (s *roombaState) cmdHelp() commandResult {
help := `Available commands:
help - Show this help message
status - Show robot status
clean - Start full cleaning job
clean room <name> - Clean a specific room
dock - Return to dock
map - Show floor plan and room list
schedule - List cleaning schedule
schedule add <day> <time> - Add scheduled cleaning
schedule remove <day> - Remove scheduled cleaning
history - Show recent cleaning history
diagnostics - Run system diagnostics
alerts - Show active alerts
reboot - Reboot RoombaOS
exit / logout - Disconnect`
return commandResult{output: help}
}
func (s *roombaState) cmdStatus() commandResult {
var b strings.Builder
b.WriteString("=== RoombaOS System Status ===\n")
b.WriteString("Model: iRobot Roomba j7+\n")
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
b.WriteString("Side brush: OK (142 hrs)\n")
b.WriteString("Main brush: OK (98 hrs)\n")
b.WriteString("\n")
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
b.WriteString("Signal: -38 dBm\n")
b.WriteString("Alexa: Linked\n")
b.WriteString("Google Home: Linked\n")
b.WriteString("iRobot Home App: Connected\n")
b.WriteString("\n")
b.WriteString("Firmware: v4.3.7-stable\n")
b.WriteString("LIDAR: Operational\n")
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
return commandResult{output: b.String()}
}
func (s *roombaState) cmdClean(args []string) commandResult {
if s.status == "Cleaning" {
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
}
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
roomName := strings.Join(args[1:], " ")
for _, r := range s.rooms {
if strings.EqualFold(r.name, roomName) {
s.status = "Cleaning"
return commandResult{output: fmt.Sprintf(
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
)}
}
}
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
}
if len(args) > 0 {
return commandResult{output: "Usage: clean [room <name>]"}
}
s.status = "Cleaning"
var totalArea int
for _, r := range s.rooms {
totalArea += r.areaSqFt
}
return commandResult{output: fmt.Sprintf(
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
totalArea, len(s.rooms), totalArea/8,
)}
}
func (s *roombaState) cmdDock() commandResult {
if s.status == "Docked" {
return commandResult{output: "Already docked."}
}
if s.status == "Returning to dock" {
return commandResult{output: "Already returning to dock."}
}
s.status = "Returning to dock"
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
}
func (s *roombaState) cmdMap() commandResult {
var b strings.Builder
b.WriteString("=== Floor Plan ===\n\n")
b.WriteString(" +------------+----------+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Kitchen | Bathroom |\n")
b.WriteString(" | 180sqft | 75sqft |\n")
b.WriteString(" | | |\n")
b.WriteString(" +------+-----+----+-----+\n")
b.WriteString(" | | | |\n")
b.WriteString(" | Hall | Living | Cat |\n")
b.WriteString(" | 60sf | Room | Rm |\n")
b.WriteString(" | | 320sqft |110sf|\n")
b.WriteString(" +------+ +-----+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Bed +----------+\n")
b.WriteString(" | room | [DOCK]\n")
b.WriteString(" |200sf |\n")
b.WriteString(" +------+\n")
b.WriteString("\nRoom Details:\n")
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
for _, r := range s.rooms {
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
}
return commandResult{output: b.String()}
}
func (s *roombaState) cmdSchedule(args []string) commandResult {
if len(args) == 0 {
return s.scheduleList()
}
sub := strings.ToLower(args[0])
switch sub {
case "add":
if len(args) < 3 {
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
}
return s.scheduleAdd(args[1], args[2])
case "remove":
if len(args) < 2 {
return commandResult{output: "Usage: schedule remove <day>"}
}
return s.scheduleRemove(args[1])
default:
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
}
}
func (s *roombaState) scheduleList() commandResult {
if len(s.schedule) == 0 {
return commandResult{output: "No cleaning schedule configured."}
}
var b strings.Builder
b.WriteString("=== Cleaning Schedule ===\n")
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
for _, e := range s.schedule {
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
}
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
return commandResult{output: b.String()}
}
func (s *roombaState) scheduleAdd(day, t string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
if !slices.Contains(validDays, day) {
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
}
for _, e := range s.schedule {
if strings.EqualFold(e.day, day) {
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
}
}
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
}
func (s *roombaState) scheduleRemove(day string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
for i, e := range s.schedule {
if strings.EqualFold(e.day, day) {
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
}
}
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
}
func (s *roombaState) cmdHistory() commandResult {
if len(s.cleanHistory) == 0 {
return commandResult{output: "No cleaning history."}
}
var b strings.Builder
b.WriteString("=== Cleaning History ===\n")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
for _, h := range s.cleanHistory {
ts := h.timestamp.Format("2006-01-02 15:04")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
}
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdDiagnostics() commandResult {
diag := `Running RoombaOS diagnostics...
[1/8] Cliff sensors........... OK
[2/8] Bumper sensor........... OK
[3/8] Side brush motor........ OK (142 hrs until replacement)
[4/8] Main brush motor........ OK (98 hrs until replacement)
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
[7/8] Dustbin sensor.......... OK
[8/8] WiFi module............. OK (signal: -38 dBm)
ALL SYSTEMS NOMINAL`
return commandResult{output: diag}
}
func (s *roombaState) cmdAlerts() commandResult {
var alerts []string
if s.dustbin >= 60 {
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
}
alerts = append(alerts,
"WARNING: Side brush replacement due in 12 hours",
"INFO: Unidentified sticky substance detected in Kitchen",
"INFO: Cat frequently blocking cleaning path in Cat's Room",
"INFO: Firmware update available: v4.4.0-beta",
"INFO: Filter replacement recommended in 14 days",
)
var b strings.Builder
b.WriteString("=== Active Alerts ===\n")
for _, a := range alerts {
b.WriteString(a + "\n")
}
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdReboot() commandResult {
reboot := `RoombaOS is rebooting...
Stopping navigation engine..... done
Saving room map data........... done
Flushing cleaning logs......... done
Disconnecting from WiFi........ done
Rebooting now. Goodbye!`
return commandResult{output: reboot, exit: true}
}
func capitalizeFirst(s string) string {
if s == "" {
return s
}
return strings.ToUpper(s[:1]) + s[1:]
}
func formatDuration(d time.Duration) string {
hours := int(d.Hours())
minutes := int(d.Minutes()) % 60
if hours >= 24 {
days := hours / 24
hours %= 24
return fmt.Sprintf("%dd %dh", days, hours)
}
if hours > 0 {
return fmt.Sprintf("%dh %dm", hours, minutes)
}
return fmt.Sprintf("%dm", minutes)
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"io"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// Shell is the interface that all honeypot shell implementations must satisfy.

View File

@@ -0,0 +1,101 @@
package tetris
import "github.com/charmbracelet/lipgloss"
// pieceType identifies a tetromino (06).
type pieceType int
const (
pieceI pieceType = iota
pieceO
pieceT
pieceS
pieceZ
pieceJ
pieceL
)
const numPieceTypes = 7
// Standard Tetris colors.
var pieceColors = [numPieceTypes]lipgloss.Color{
lipgloss.Color("#00FFFF"), // I — cyan
lipgloss.Color("#FFFF00"), // O — yellow
lipgloss.Color("#AA00FF"), // T — purple
lipgloss.Color("#00FF00"), // S — green
lipgloss.Color("#FF0000"), // Z — red
lipgloss.Color("#0000FF"), // J — blue
lipgloss.Color("#FF8800"), // L — orange
}
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
// relative to the piece origin.
type rotation [4][2]int
var pieces = [numPieceTypes][4]rotation{
// I
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
},
// O
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
},
// T
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// S
{
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// Z
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
},
// J
{
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
},
// L
{
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
},
}
// spawnCol returns the starting column for a piece, centering it on the board.
func spawnCol(pt pieceType, rot int) int {
shape := pieces[pt][rot]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
width := maxC - minC + 1
return (boardCols - width) / 2
}

View File

@@ -0,0 +1,210 @@
package tetris
import "math/rand/v2"
const (
boardRows = 20
boardCols = 10
)
// cell represents a single board cell. Zero value is empty.
type cell struct {
filled bool
piece pieceType // which piece type filled this cell (for color)
}
// gameState holds all mutable state for a Tetris game.
type gameState struct {
board [boardRows][boardCols]cell
current pieceType
currentRot int
currentRow int
currentCol int
next pieceType
score int
level int
lines int
gameOver bool
}
// newGame creates a new game state, optionally starting at a given level.
func newGame(startLevel int) *gameState {
g := &gameState{
level: startLevel,
next: pieceType(rand.IntN(numPieceTypes)),
}
g.spawnPiece()
return g
}
// spawnPiece pulls the next piece and generates a new next.
func (g *gameState) spawnPiece() {
g.current = g.next
g.next = pieceType(rand.IntN(numPieceTypes))
g.currentRot = 0
g.currentRow = 0
g.currentCol = spawnCol(g.current, 0)
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
g.gameOver = true
}
}
// canPlace checks whether the piece fits at the given position.
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
return false
}
if g.board[r][c].filled {
return false
}
}
return true
}
// moveLeft moves the current piece left if possible.
func (g *gameState) moveLeft() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
g.currentCol--
return true
}
return false
}
// moveRight moves the current piece right if possible.
func (g *gameState) moveRight() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
g.currentCol++
return true
}
return false
}
// moveDown moves the current piece down one row. Returns false if it cannot.
func (g *gameState) moveDown() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
g.currentRow++
return true
}
return false
}
// rotate rotates the current piece clockwise with wall kick attempts.
func (g *gameState) rotate() bool {
newRot := (g.currentRot + 1) % 4
// Try in-place first.
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
g.currentRot = newRot
return true
}
// Wall kick: try +-1 column offset.
for _, offset := range []int{-1, 1} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
// I piece: try +-2.
if g.current == pieceI {
for _, offset := range []int{-2, 2} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
}
return false
}
// ghostRow returns the row where the piece would land.
func (g *gameState) ghostRow() int {
row := g.currentRow
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
row++
}
return row
}
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
func (g *gameState) hardDrop() int {
ghost := g.ghostRow()
dropped := ghost - g.currentRow
g.currentRow = ghost
return dropped
}
// lockPiece writes the current piece into the board.
func (g *gameState) lockPiece() {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
g.board[r][c] = cell{filled: true, piece: g.current}
}
}
}
// clearLines removes completed rows and returns how many were cleared.
func (g *gameState) clearLines() int {
cleared := 0
for r := boardRows - 1; r >= 0; r-- {
full := true
for c := range boardCols {
if !g.board[r][c].filled {
full = false
break
}
}
if full {
cleared++
// Shift everything above down.
for rr := r; rr > 0; rr-- {
g.board[rr] = g.board[rr-1]
}
g.board[0] = [boardCols]cell{}
r++ // re-check this row since we shifted
}
}
return cleared
}
// NES-style scoring multipliers per lines cleared.
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
// addScore updates score, lines, and level after clearing rows.
func (g *gameState) addScore(linesCleared int) {
if linesCleared > 0 && linesCleared <= 4 {
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
}
g.lines += linesCleared
// Level up every 10 lines.
newLevel := g.lines / 10
if newLevel > g.level {
g.level = newLevel
}
}
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
// Returns the number of lines cleared.
func (g *gameState) afterLock() int {
g.lockPiece()
cleared := g.clearLines()
g.addScore(cleared)
g.spawnPiece()
return cleared
}
// tickInterval returns the gravity interval in milliseconds for the current level.
func tickInterval(level int) int {
return max(800-level*60, 100)
}

View File

@@ -0,0 +1,331 @@
package tetris
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int
const (
screenTitle screen = iota
screenGame
screenGameOver
)
type tickMsg time.Time
type lockMsg time.Time
const lockDelay = 500 * time.Millisecond
type model struct {
sess *shell.SessionContext
difficulty string
screen screen
game *gameState
quitting bool
height int
keypresses int
locking bool // true when piece has landed and lock delay is active
}
func newModel(sess *shell.SessionContext, difficulty string) *model {
return &model{
sess: sess,
difficulty: difficulty,
screen: screenTitle,
}
}
func (m *model) Init() tea.Cmd {
return nil
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.quitting {
return m, tea.Quit
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.height = msg.Height
return m, nil
case tea.KeyMsg:
m.keypresses++
if msg.Type == tea.KeyCtrlC {
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
tea.Quit,
)
}
}
switch m.screen {
case screenTitle:
return m.updateTitle(msg)
case screenGame:
return m.updateGame(msg)
case screenGameOver:
return m.updateGameOver(msg)
}
return m, nil
}
func (m *model) View() string {
var content string
switch m.screen {
case screenTitle:
content = m.titleView()
case screenGame:
content = gameView(m.game)
case screenGameOver:
content = m.gameOverView()
}
return gameFrame(content, m.height)
}
// --- Title screen ---
func (m *model) titleView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" Press any key to start"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
if _, ok := msg.(tea.KeyMsg); ok {
m.screen = screenGame
var startLevel int
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
}
return m, nil
}
// --- Game screen ---
func (m *model) scheduleTick() tea.Cmd {
ms := tickInterval(m.game.level)
if m.difficulty == "easy" {
ms = max(1000-m.game.level*60, 150)
}
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m *model) scheduleLock() tea.Cmd {
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
return lockMsg(t)
})
}
// performLock locks the piece, clears lines, and returns commands for logging
// and scheduling the next tick. Returns nil if game over (goToGameOver is
// included in the returned batch).
func (m *model) performLock() tea.Cmd {
m.locking = false
cleared := m.game.afterLock()
if m.game.gameOver {
return tea.Batch(
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
m.goToGameOver(),
)
}
var cmds []tea.Cmd
cmds = append(cmds, m.scheduleTick())
if cleared > 0 {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
prevLevel := (m.game.lines - cleared) / 10
if m.game.level > prevLevel {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
}
}
return tea.Batch(cmds...)
}
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case lockMsg:
if m.game.gameOver || !m.locking {
return m, nil
}
// Lock delay expired — lock the piece now.
return m, m.performLock()
case tickMsg:
if m.game.gameOver || m.locking {
return m, nil
}
if !m.game.moveDown() {
// Piece landed — start lock delay instead of locking immediately.
m.locking = true
return m, m.scheduleLock()
}
return m, m.scheduleTick()
case tea.KeyMsg:
if m.game.gameOver {
return m, nil
}
switch msg.String() {
case "left":
m.game.moveLeft()
// If piece can now drop further, cancel lock delay.
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "right":
m.game.moveRight()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "down":
if m.game.moveDown() {
m.game.score++ // soft drop bonus
if m.locking {
m.locking = false
}
}
case "up", "z":
m.game.rotate()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case " ":
m.locking = false
dropped := m.game.hardDrop()
m.game.score += dropped * 2
return m, m.performLock()
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
return m, nil
}
return m, nil
}
// --- Game over screen ---
func (m *model) goToGameOver() tea.Cmd {
m.screen = screenGameOver
return tea.ClearScreen
}
func (m *model) gameOverView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" GAME OVER"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" R - Play again"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" Q - Quit"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
if keyMsg, ok := msg.(tea.KeyMsg); ok {
switch keyMsg.String() {
case "r":
startLevel := 0
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
m.screen = screenGame
m.keypresses = 0
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
}
return m, nil
}
// Helper methods for safe access when game may be nil.
func (m *model) gameScore() int {
if m.game != nil {
return m.game.score
}
return 0
}
func (m *model) gameLevel() int {
if m.game != nil {
return m.game.level
}
return 0
}
func (m *model) gameLines() int {
if m.game != nil {
return m.game.lines
}
return 0
}
// logAction returns a tea.Cmd that logs an action to the session store.
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
return func() tea.Msg {
if sess.Store != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
}
if sess.OnCommand != nil {
sess.OnCommand("tetris")
}
return nil
}
}

View File

@@ -0,0 +1,286 @@
package tetris
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
const termWidth = 80
var (
colorWhite = lipgloss.Color("#FFFFFF")
colorDim = lipgloss.Color("#555555")
colorBlack = lipgloss.Color("#000000")
colorGhost = lipgloss.Color("#333333")
)
var (
baseStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack)
dimStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
titleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("#00FFFF")).
Background(colorBlack).
Bold(true)
sidebarLabelStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
sidebarValueStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack).
Bold(true)
)
// cellStyle returns a style for a filled cell of a given piece type.
func cellStyle(pt pieceType) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(pieceColors[pt]).
Background(colorBlack)
}
// ghostStyle returns a dimmed style for the ghost piece.
func ghostCellStyle() lipgloss.Style {
return lipgloss.NewStyle().
Foreground(colorGhost).
Background(colorBlack)
}
// renderBoard renders the board, current piece, and ghost piece as a string.
func renderBoard(g *gameState) string {
// Build a display grid that includes the current piece and ghost.
type displayCell struct {
filled bool
ghost bool
piece pieceType
}
var grid [boardRows][boardCols]displayCell
// Copy locked cells.
for r := range boardRows {
for c := range boardCols {
if g.board[r][c].filled {
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
}
}
}
// Ghost piece.
ghostR := g.ghostRow()
if ghostR != g.currentRow {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := ghostR+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
grid[r][c] = displayCell{ghost: true, piece: g.current}
}
}
}
// Current piece.
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
grid[r][c] = displayCell{filled: true, piece: g.current}
}
}
// Render grid.
var b strings.Builder
borderStyle := dimStyle
for _, row := range grid {
b.WriteString(borderStyle.Render("|"))
for _, dc := range row {
switch {
case dc.filled:
b.WriteString(cellStyle(dc.piece).Render("[]"))
case dc.ghost:
b.WriteString(ghostCellStyle().Render("::"))
default:
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(borderStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
return b.String()
}
// renderNextPiece renders the "next piece" preview box.
func renderNextPiece(pt pieceType) string {
shape := pieces[pt][0]
// Determine bounding box.
minR, maxR := shape[0][0], shape[0][0]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[0] < minR {
minR = off[0]
}
if off[0] > maxR {
maxR = off[0]
}
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
rows := maxR - minR + 1
cols := maxC - minC + 1
// Build a small grid.
grid := make([][]bool, rows)
for i := range grid {
grid[i] = make([]bool, cols)
}
for _, off := range shape {
grid[off[0]-minR][off[1]-minC] = true
}
var b strings.Builder
boxWidth := 8 // chars for the box interior
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
b.WriteString("\n")
for r := range rows {
b.WriteString(dimStyle.Render("|"))
// Center the piece in the box.
pieceWidth := cols * 2
leftPad := (boxWidth - pieceWidth) / 2
rightPad := boxWidth - pieceWidth - leftPad
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
for c := range cols {
if grid[r][c] {
b.WriteString(cellStyle(pt).Render("[]"))
} else {
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
// Fill remaining rows in the box (max 4 rows for I piece).
for r := rows; r < 2; r++ {
b.WriteString(dimStyle.Render("|"))
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
return b.String()
}
// formatScore formats a score with comma separators.
func formatScore(n int) string {
s := fmt.Sprintf("%d", n)
if len(s) <= 3 {
return s
}
var parts []string
for len(s) > 3 {
parts = append([]string{s[len(s)-3:]}, parts...)
s = s[:len(s)-3]
}
parts = append([]string{s}, parts...)
return strings.Join(parts, ",")
}
// gameView combines the board and sidebar into the game screen.
func gameView(g *gameState) string {
boardStr := renderBoard(g)
boardLines := strings.Split(boardStr, "\n")
nextStr := renderNextPiece(g.next)
nextLines := strings.Split(nextStr, "\n")
// Build sidebar lines.
var sidebar []string
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
sidebar = append(sidebar, nextLines...)
sidebar = append(sidebar, "")
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
sidebar = append(sidebar, "")
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
// Combine board and sidebar side by side.
var b strings.Builder
maxLines := max(len(boardLines), len(sidebar))
for i := range maxLines {
boardLine := ""
if i < len(boardLines) {
boardLine = boardLines[i]
}
sidebarLine := ""
if i < len(sidebar) {
sidebarLine = sidebar[i]
}
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
b.WriteString(boardLine)
b.WriteString(sidebarLine)
b.WriteString("\n")
}
return b.String()
}
// padLine pads a single line to termWidth.
func padLine(line string) string {
w := lipgloss.Width(line)
if w >= termWidth {
return line
}
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
}
// padLines pads every line in a multi-line string to termWidth.
func padLines(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = padLine(line)
}
return strings.Join(lines, "\n")
}
// gameFrame wraps content with padding to fill the terminal.
func gameFrame(content string, height int) string {
var b strings.Builder
b.WriteString(content)
// Pad with blank lines to fill terminal height.
if height > 0 {
contentLines := strings.Count(content, "\n") + 1
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
for i := contentLines; i < height; i++ {
b.WriteString(blankLine)
b.WriteString("\n")
}
}
return padLines(b.String())
}

View File

@@ -0,0 +1,66 @@
package tetris
import (
"context"
"io"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// TetrisShell is a Tetris game TUI for the honeypot.
type TetrisShell struct{}
// NewTetrisShell returns a new TetrisShell instance.
func NewTetrisShell() *TetrisShell {
return &TetrisShell{}
}
func (t *TetrisShell) Name() string { return "tetris" }
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
m := newModel(sess, difficulty)
p := tea.NewProgram(m,
tea.WithInput(rw),
tea.WithOutput(rw),
tea.WithAltScreen(),
)
done := make(chan error, 1)
go func() {
_, err := p.Run()
done <- err
}()
select {
case err := <-done:
return err
case <-ctx.Done():
p.Quit()
<-done
return nil
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,582 @@
package tetris
import (
"context"
"strings"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "player",
Store: store,
}
m := newModel(sess, "normal")
return m, store
}
// sendKey sends a single key message to the model and returns the command.
func sendKey(m *model, key string) tea.Cmd {
var msg tea.KeyMsg
switch key {
case "enter":
msg = tea.KeyMsg{Type: tea.KeyEnter}
case "up":
msg = tea.KeyMsg{Type: tea.KeyUp}
case "down":
msg = tea.KeyMsg{Type: tea.KeyDown}
case "left":
msg = tea.KeyMsg{Type: tea.KeyLeft}
case "right":
msg = tea.KeyMsg{Type: tea.KeyRight}
case "space":
msg = tea.KeyMsg{Type: tea.KeySpace}
case "ctrl+c":
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
default:
if len(key) == 1 {
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
}
}
_, cmd := m.Update(msg)
return cmd
}
// sendTick sends a tick message to the model.
func sendTick(m *model) tea.Cmd {
_, cmd := m.Update(tickMsg(time.Now()))
return cmd
}
// execCmds recursively executes tea.Cmd functions (including batches).
func execCmds(cmd tea.Cmd) {
if cmd == nil {
return
}
msg := cmd()
if batch, ok := msg.(tea.BatchMsg); ok {
for _, c := range batch {
execCmds(c)
}
}
}
func TestTetrisShellName(t *testing.T) {
sh := NewTetrisShell()
if sh.Name() != "tetris" {
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{
"difficulty": "hard",
}
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
t.Errorf("configString() = %q, want %q", got, "hard")
}
if got := configString(cfg, "missing", "normal"); got != "normal" {
t.Errorf("configString() = %q, want %q", got, "normal")
}
if got := configString(nil, "difficulty", "normal"); got != "normal" {
t.Errorf("configString(nil) = %q, want %q", got, "normal")
}
}
func TestTitleScreenRenders(t *testing.T) {
m, _ := newTestModel(t)
view := m.View()
if !strings.Contains(view, "████") {
t.Error("title screen should show TETRIS logo")
}
if !strings.Contains(view, "Press any key") {
t.Error("title screen should show 'Press any key'")
}
}
func TestTitleToGame(t *testing.T) {
m, _ := newTestModel(t)
if m.screen != screenTitle {
t.Fatalf("expected screenTitle, got %d", m.screen)
}
sendKey(m, "enter")
if m.screen != screenGame {
t.Errorf("expected screenGame after keypress, got %d", m.screen)
}
if m.game == nil {
t.Fatal("game should be initialized")
}
}
func TestGameRenders(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
view := m.View()
if !strings.Contains(view, "|") {
t.Error("game view should contain board borders")
}
if !strings.Contains(view, "SCORE") {
t.Error("game view should show SCORE")
}
if !strings.Contains(view, "LEVEL") {
t.Error("game view should show LEVEL")
}
if !strings.Contains(view, "LINES") {
t.Error("game view should show LINES")
}
if !strings.Contains(view, "NEXT") {
t.Error("game view should show NEXT")
}
}
// --- Pure game logic tests ---
func TestNewGame(t *testing.T) {
g := newGame(0)
if g.gameOver {
t.Error("new game should not be game over")
}
if g.score != 0 {
t.Errorf("initial score = %d, want 0", g.score)
}
if g.level != 0 {
t.Errorf("initial level = %d, want 0", g.level)
}
if g.lines != 0 {
t.Errorf("initial lines = %d, want 0", g.lines)
}
}
func TestNewGameHardLevel(t *testing.T) {
g := newGame(5)
if g.level != 5 {
t.Errorf("hard start level = %d, want 5", g.level)
}
}
func TestMoveLeft(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveLeft()
if g.currentCol != startCol-1 {
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
}
}
func TestMoveRight(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveRight()
if g.currentCol != startCol+1 {
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
}
}
func TestMoveDown(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
moved := g.moveDown()
if !moved {
t.Error("moveDown should succeed from starting position")
}
if g.currentRow != startRow+1 {
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
}
}
func TestCannotMoveLeftBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way left.
for range boardCols {
g.moveLeft()
}
col := g.currentCol
g.moveLeft() // should not move further
if g.currentCol != col {
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
}
}
func TestCannotMoveRightBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way right.
for range boardCols {
g.moveRight()
}
col := g.currentCol
g.moveRight() // should not move further
if g.currentCol != col {
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
}
}
func TestRotate(t *testing.T) {
g := newGame(0)
startRot := g.currentRot
g.rotate()
// Rotation should change (possibly with wall kick).
if g.currentRot == startRot {
// Rotation might legitimately fail in some edge cases, so just check
// that the game state is valid.
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
t.Error("piece should be in a valid position after rotate attempt")
}
}
}
func TestHardDrop(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
dropped := g.hardDrop()
if dropped == 0 {
t.Error("hard drop should move piece down at least some rows from top")
}
if g.currentRow <= startRow {
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
}
}
func TestGhostRow(t *testing.T) {
g := newGame(0)
ghost := g.ghostRow()
if ghost < g.currentRow {
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
}
// Ghost should be at a position where moving down one more is impossible.
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
t.Error("ghost row should be the lowest valid position")
}
}
func TestLockPiece(t *testing.T) {
g := newGame(0)
g.hardDrop()
pt := g.current
row, col, rot := g.currentRow, g.currentCol, g.currentRot
g.lockPiece()
// Verify that the piece's cells are now filled.
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if !g.board[r][c].filled {
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
}
}
}
func TestClearLines(t *testing.T) {
g := newGame(0)
// Fill the bottom row completely.
for c := range boardCols {
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
}
cleared := g.clearLines()
if cleared != 1 {
t.Errorf("clearLines() = %d, want 1", cleared)
}
// Bottom row should now be empty (shifted from above).
for c := range boardCols {
if g.board[boardRows-1][c].filled {
t.Errorf("bottom row col %d should be empty after clearing", c)
}
}
}
func TestClearMultipleLines(t *testing.T) {
g := newGame(0)
// Fill the bottom 4 rows.
for r := boardRows - 4; r < boardRows; r++ {
for c := range boardCols {
g.board[r][c] = cell{filled: true, piece: pieceI}
}
}
cleared := g.clearLines()
if cleared != 4 {
t.Errorf("clearLines() = %d, want 4", cleared)
}
}
func TestScoring(t *testing.T) {
tests := []struct {
lines int
level int
want int
}{
{1, 0, 40},
{2, 0, 100},
{3, 0, 300},
{4, 0, 1200},
{1, 1, 80},
{4, 2, 3600},
}
for _, tt := range tests {
g := newGame(tt.level)
g.addScore(tt.lines)
if g.score != tt.want {
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
}
}
}
func TestLevelUp(t *testing.T) {
g := newGame(0)
g.lines = 9
g.addScore(1) // This should push lines to 10, triggering level 1.
if g.level != 1 {
t.Errorf("level = %d, want 1 after 10 lines", g.level)
}
}
func TestTickInterval(t *testing.T) {
if got := tickInterval(0); got != 800 {
t.Errorf("tickInterval(0) = %d, want 800", got)
}
if got := tickInterval(5); got != 500 {
t.Errorf("tickInterval(5) = %d, want 500", got)
}
// Floor at 100ms.
if got := tickInterval(20); got != 100 {
t.Errorf("tickInterval(20) = %d, want 100", got)
}
}
func TestFormatScore(t *testing.T) {
tests := []struct {
n int
want string
}{
{0, "0"},
{100, "100"},
{1250, "1,250"},
{1000000, "1,000,000"},
}
for _, tt := range tests {
if got := formatScore(tt.n); got != tt.want {
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
}
}
}
func TestGameOverScreen(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Force game over.
m.game.gameOver = true
m.screen = screenGameOver
view := m.View()
if !strings.Contains(view, "GAME OVER") {
t.Error("game over screen should show GAME OVER")
}
if !strings.Contains(view, "Score") {
t.Error("game over screen should show score")
}
}
func TestRestartFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "r")
if m.screen != screenGame {
t.Errorf("expected screenGame after restart, got %d", m.screen)
}
if m.game.gameOver {
t.Error("game should not be over after restart")
}
}
func TestQuitFromGame(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q")
}
}
func TestQuitFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q in game over")
}
}
func TestSoftDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
scoreBefore := m.game.score
sendKey(m, "down")
if m.game.score != scoreBefore+1 {
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
}
}
func TestHardDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Hard drop gives 2 points per row dropped.
sendKey(m, "space")
if m.game.score < 2 {
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
}
}
func TestTickMovesDown(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
rowBefore := m.game.currentRow
sendTick(m)
// Piece should either move down by 1, or lock and spawn a new piece at top.
movedDown := m.game.currentRow == rowBefore+1
respawned := m.game.currentRow < rowBefore
if !movedDown && !respawned && !m.game.gameOver {
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
}
}
func TestSessionLogs(t *testing.T) {
m, store := newTestModel(t)
// Press key to start game — returns a logAction cmd.
cmd := sendKey(m, "enter")
if cmd != nil {
execCmds(cmd)
}
time.Sleep(50 * time.Millisecond)
found := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "GAME START") {
found = true
break
}
}
if !found {
t.Error("expected GAME START in session logs")
}
}
func TestKeypressCounter(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "left")
sendKey(m, "right")
sendKey(m, "down")
if m.keypresses != 4 { // enter + 3 game keys
t.Errorf("keypresses = %d, want 4", m.keypresses)
}
}
func TestLockDelay(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Drop piece to the bottom via ticks until it can't move down.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
if !m.locking {
t.Fatal("piece should be in locking state after hitting bottom")
}
// During lock delay, we should still be able to move left/right.
colBefore := m.game.currentCol
sendKey(m, "left")
if m.game.currentCol >= colBefore {
// Might not have moved if against wall, try right.
sendKey(m, "right")
}
// Sending a lockMsg should finalize the piece.
m.Update(lockMsg(time.Now()))
// After lock, a new piece should have spawned (row near top).
if m.game.currentRow > 1 && !m.game.gameOver {
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
}
}
func TestLockDelayCancelledByDrop(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Build a ledge: fill rows 18-19 but leave column 0 empty.
for r := boardRows - 2; r < boardRows; r++ {
for c := 1; c < boardCols; c++ {
m.game.board[r][c] = cell{filled: true, piece: pieceI}
}
}
// Move piece to column 0 area and drop it onto the ledge.
for range boardCols {
m.game.moveLeft()
}
// Tick down until locking.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
// If piece is on the ledge and we slide it to col 0 (open column),
// the lock delay should cancel since it can fall further.
// This test just validates the locking flag logic works.
if m.locking {
// Try moving — if piece can drop further, locking should cancel.
sendKey(m, "left")
// Whether locking cancels depends on the board state; just verify no crash.
}
}
func TestSpawnCol(t *testing.T) {
// All pieces should spawn roughly centered.
for pt := range pieceType(numPieceTypes) {
col := spawnCol(pt, 0)
if col < 0 || col > boardCols-1 {
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
}
// Verify piece fits at spawn position.
shape := pieces[pt][0]
for _, off := range shape {
c := col + off[1]
if c < 0 || c >= boardCols {
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
}
}
}
}

View File

@@ -0,0 +1,217 @@
package storage
import (
"context"
"time"
"github.com/prometheus/client_golang/prometheus"
)
// InstrumentedStore wraps a Store and records query duration and errors
// as Prometheus metrics for each method call.
type InstrumentedStore struct {
store Store
queryDuration *prometheus.HistogramVec
queryErrors *prometheus.CounterVec
}
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
return &InstrumentedStore{
store: store,
queryDuration: queryDuration,
queryErrors: queryErrors,
}
}
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
v, err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return v, err
}
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return err
}
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
return observeErr(s, "RecordLoginAttempt", func() error {
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
})
}
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
return observe(s, "CreateSession", func() (string, error) {
return s.store.CreateSession(ctx, ip, username, shellName, country)
})
}
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
return observeErr(s, "EndSession", func() error {
return s.store.EndSession(ctx, sessionID, disconnectedAt)
})
}
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
return observeErr(s, "UpdateHumanScore", func() error {
return s.store.UpdateHumanScore(ctx, sessionID, score)
})
}
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
return observeErr(s, "SetExecCommand", func() error {
return s.store.SetExecCommand(ctx, sessionID, command)
})
}
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
return observeErr(s, "AppendSessionLog", func() error {
return s.store.AppendSessionLog(ctx, sessionID, input, output)
})
}
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
return s.store.DeleteRecordsBefore(ctx, cutoff)
})
}
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
return s.store.GetDashboardStats(ctx)
})
}
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
return s.store.GetTopUsernames(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
return s.store.GetTopPasswords(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
return s.store.GetTopIPs(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
return s.store.GetTopCountries(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
return s.store.GetTopExecCommands(ctx, limit)
})
}
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
return observe(s, "GetRecentSessions", func() ([]Session, error) {
return s.store.GetRecentSessions(ctx, limit, activeOnly)
})
}
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
})
}
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
return observe(s, "GetSession", func() (*Session, error) {
return s.store.GetSession(ctx, sessionID)
})
}
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
return s.store.GetSessionLogs(ctx, sessionID)
})
}
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
return observeErr(s, "AppendSessionEvents", func() error {
return s.store.AppendSessionEvents(ctx, events)
})
}
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
return s.store.GetSessionEvents(ctx, sessionID)
})
}
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
return observe(s, "CloseActiveSessions", func() (int64, error) {
return s.store.CloseActiveSessions(ctx, disconnectedAt)
})
}
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
return s.store.GetAttemptsOverTime(ctx, days, since, until)
})
}
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
return s.store.GetHourlyPattern(ctx, since, until)
})
}
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
return s.store.GetCountryStats(ctx)
})
}
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
return s.store.GetFilteredDashboardStats(ctx, f)
})
}
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
return s.store.GetFilteredTopUsernames(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
return s.store.GetFilteredTopPasswords(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
return s.store.GetFilteredTopIPs(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
return s.store.GetFilteredTopCountries(ctx, limit, f)
})
}
func (s *InstrumentedStore) Close() error {
return s.store.Close()
}

View File

@@ -0,0 +1,163 @@
package storage
import (
"context"
"errors"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
)
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_query_errors_total",
Help: "test",
}, []string{"method"})
store := NewMemoryStore()
return NewInstrumentedStore(store, dur, errs), dur, errs
}
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
m := &dto.Metric{}
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
return m.GetHistogram().GetSampleCount()
}
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
m := &dto.Metric{}
c.WithLabelValues(method).Write(m)
return m.GetCounter().GetValue()
}
func TestInstrumentedStoreDelegation(t *testing.T) {
s, dur, _ := newTestInstrumented()
ctx := context.Background()
// RecordLoginAttempt should delegate and record duration.
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// CreateSession should delegate and return a valid ID.
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if id == "" {
t.Fatal("CreateSession returned empty ID")
}
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// GetDashboardStats should delegate.
stats, err := s.GetDashboardStats(ctx)
if err != nil {
t.Fatalf("GetDashboardStats: %v", err)
}
if stats == nil {
t.Fatal("GetDashboardStats returned nil")
}
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreErrorCounting(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_ec_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_ec_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
// Error should be counted.
err := s.EndSession(ctx, "nonexistent", time.Now())
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getHistogramCount(dur, "EndSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
if c := getCounterValue(errs, "EndSession"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
// Successful call should not increment error counter.
s2, _, errs2 := newTestInstrumented()
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
t.Fatalf("expected error count 0, got %f", c)
}
}
// errorStore is a Store that returns errors for all methods.
type errorStore struct {
MemoryStore
}
var errFake = errors.New("fake error")
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
return errFake
}
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
return errFake
}
func TestInstrumentedStoreObserveErr(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test2_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test2_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreClose(t *testing.T) {
s, _, _ := newTestInstrumented()
if err := s.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
}

View File

@@ -336,10 +336,26 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
m.mu.Lock()
defer m.mu.Unlock()
// Count events per session.
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
}
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.collectSessions(limit, activeOnly, f), nil
}
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
// Compute event counts and input bytes per session.
eventCounts := make(map[string]int)
inputBytes := make(map[string]int64)
for _, e := range m.SessionEvents {
eventCounts[e.SessionID]++
if e.Direction == 0 {
inputBytes[e.SessionID] += int64(len(e.Data))
}
}
var sessions []Session
@@ -347,17 +363,54 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
if activeOnly && s.DisconnectedAt != nil {
continue
}
if !matchesSessionFilter(s, f) {
continue
}
sess := *s
sess.EventCount = eventCounts[s.ID]
sess.InputBytes = inputBytes[s.ID]
sessions = append(sessions, sess)
}
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
})
if f.SortBy == "input_bytes" {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].InputBytes > sessions[j].InputBytes
})
} else {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
})
}
if limit > 0 && len(sessions) > limit {
sessions = sessions[:limit]
}
return sessions, nil
return sessions
}
// matchesSessionFilter returns true if the session matches the given filter.
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
return false
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
return false
}
if f.IP != "" && s.IP != f.IP {
return false
}
if f.Country != "" && s.Country != f.Country {
return false
}
if f.Username != "" && s.Username != f.Username {
return false
}
if f.HumanScoreAboveZero {
if s.HumanScore == nil || *s.HumanScore <= 0 {
return false
}
}
return true
}
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
@@ -399,6 +452,258 @@ func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time
return count, nil
}
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
m.mu.Lock()
defer m.mu.Unlock()
var cutoff time.Time
if since != nil {
cutoff = *since
} else {
cutoff = time.Now().UTC().AddDate(0, 0, -days)
}
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.LastSeen.Before(cutoff) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
day := a.LastSeen.Format("2006-01-02")
counts[day] += int64(a.Count)
}
points := make([]TimeSeriesPoint, 0, len(counts))
for day, count := range counts {
t, _ := time.Parse("2006-01-02", day)
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
}
sort.Slice(points, func(i, j int) bool {
return points[i].Timestamp.Before(points[j].Timestamp)
})
return points, nil
}
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
hourCounts := make(map[int]int64)
for _, a := range m.LoginAttempts {
if since != nil && a.LastSeen.Before(*since) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
hour := a.LastSeen.Hour()
hourCounts[hour] += int64(a.Count)
}
counts := make([]HourlyCount, 0, len(hourCounts))
for h, c := range hourCounts {
counts = append(counts, HourlyCount{Hour: h, Count: c})
}
sort.Slice(counts, func(i, j int) bool {
return counts[i].Hour < counts[j].Hour
})
return counts, nil
}
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.Country == "" {
continue
}
counts[a.Country] += int64(a.Count)
}
result := make([]CountryCount, 0, len(counts))
for country, count := range counts {
result = append(result, CountryCount{Country: country, Count: count})
}
sort.Slice(result, func(i, j int) bool {
return result[i].Count > result[j].Count
})
return result, nil
}
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
if f.Since != nil && a.LastSeen.Before(*f.Since) {
return false
}
if f.Until != nil && a.LastSeen.After(*f.Until) {
return false
}
if f.IP != "" && a.IP != f.IP {
return false
}
if f.Country != "" && a.Country != f.Country {
return false
}
if f.Username != "" && a.Username != f.Username {
return false
}
return true
}
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
m.mu.Lock()
defer m.mu.Unlock()
stats := &DashboardStats{}
ips := make(map[string]struct{})
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
stats.TotalAttempts += int64(a.Count)
ips[a.IP] = struct{}{}
}
stats.UniqueIPs = int64(len(ips))
for _, s := range m.Sessions {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
continue
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
continue
}
if f.IP != "" && s.IP != f.IP {
continue
}
if f.Country != "" && s.Country != f.Country {
continue
}
stats.TotalSessions++
if s.DisconnectedAt == nil {
stats.ActiveSessions++
}
}
return stats, nil
}
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("username", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("password", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
type ipInfo struct {
count int64
country string
}
agg := make(map[string]*ipInfo)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
info, ok := agg[a.IP]
if !ok {
info = &ipInfo{}
agg[a.IP] = info
}
info.count += int64(a.Count)
if a.Country != "" {
info.country = a.Country
}
}
entries := make([]TopEntry, 0, len(agg))
for ip, info := range agg {
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if a.Country == "" {
continue
}
if !matchesFilter(a, f) {
continue
}
counts[a.Country] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
var key string
switch field {
case "username":
key = a.Username
case "password":
key = a.Password
case "ip":
key = a.IP
}
counts[key] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries
}
func (m *MemoryStore) Close() error {
return nil
}

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);

View File

@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err)
}
if version != 4 {
t.Errorf("version = %d, want 4", version)
if version != 5 {
t.Errorf("version = %d, want 5", version)
}
// Verify tables exist by inserting into them.
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err)
}
if version != 4 {
t.Errorf("version = %d after double migrate, want 4", version)
if version != 5 {
t.Errorf("version = %d after double migrate, want 5", version)
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
@@ -382,40 +383,104 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) (
}
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
if activeOnly {
query += ` WHERE s.disconnected_at IS NULL`
}
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, limit)
return s.scanSessions(ctx, query, limit)
}
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
var clauses []string
var args []any
if activeOnly {
clauses = append(clauses, "s.disconnected_at IS NULL")
}
if f.Since != nil {
clauses = append(clauses, "s.connected_at >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "s.connected_at <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "s.ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "s.country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "s.username = ?")
args = append(args, f.Username)
}
if f.HumanScoreAboveZero {
clauses = append(clauses, "s.human_score > 0")
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
var validSessionSorts = map[string]string{
"connected_at": "s.connected_at DESC",
"input_bytes": "input_bytes DESC",
}
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
where, args := buildSessionWhereClause(f, activeOnly)
args = append(args, limit)
orderBy := validSessionSorts["connected_at"]
if mapped, ok := validSessionSorts[f.SortBy]; ok {
orderBy = mapped
}
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
return s.scanSessions(ctx, query, args...)
}
// scanSessions executes a session query and scans the results.
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying recent sessions: %w", err)
return nil, fmt.Errorf("querying sessions: %w", err)
}
defer func() { _ = rows.Close() }()
var sessions []Session
for rows.Next() {
var s Session
var sess Session
var connectedAt string
var disconnectedAt sql.NullString
var humanScore sql.NullFloat64
var execCommand sql.NullString
if err := rows.Scan(&s.ID, &s.IP, &s.Country, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &s.EventCount); err != nil {
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
return nil, fmt.Errorf("scanning session: %w", err)
}
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
if disconnectedAt.Valid {
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
s.DisconnectedAt = &t
sess.DisconnectedAt = &t
}
if humanScore.Valid {
s.HumanScore = &humanScore.Float64
sess.HumanScore = &humanScore.Float64
}
if execCommand.Valid {
s.ExecCommand = &execCommand.String
sess.ExecCommand = &execCommand.String
}
sessions = append(sessions, s)
sessions = append(sessions, sess)
}
return sessions, rows.Err()
}
@@ -454,6 +519,265 @@ func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt ti
return res.RowsAffected()
}
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
} else {
query += ` AND last_seen >= ?`
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY d ORDER BY d`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying attempts over time: %w", err)
}
defer func() { _ = rows.Close() }()
var points []TimeSeriesPoint
for rows.Next() {
var dateStr string
var p TimeSeriesPoint
if err := rows.Scan(&dateStr, &p.Count); err != nil {
return nil, fmt.Errorf("scanning time series point: %w", err)
}
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
points = append(points, p)
}
return points, rows.Err()
}
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY h ORDER BY h`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying hourly pattern: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []HourlyCount
for rows.Next() {
var c HourlyCount
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
return nil, fmt.Errorf("scanning hourly count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC`)
if err != nil {
return nil, fmt.Errorf("querying country stats: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []CountryCount
for rows.Next() {
var c CountryCount
if err := rows.Scan(&c.Country, &c.Count); err != nil {
return nil, fmt.Errorf("scanning country count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
var clauses []string
var args []any
if f.Since != nil {
clauses = append(clauses, "last_seen >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "last_seen <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "username = ?")
args = append(args, f.Username)
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
where, args := buildAttemptWhereClause(f)
stats := &DashboardStats{}
err := s.db.QueryRowContext(ctx,
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
if err != nil {
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
}
// Sessions don't have username/password, so only filter by time, IP, country.
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
var sessArgs []any
if f.Since != nil {
sessQuery += ` AND connected_at >= ?`
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
sessQuery += ` AND connected_at <= ?`
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
sessQuery += ` AND ip = ?`
sessArgs = append(sessArgs, f.IP)
}
if f.Country != "" {
sessQuery += ` AND country = ?`
sessArgs = append(sessArgs, f.Country)
}
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
}
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
}
return stats, nil
}
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "username", limit, f)
}
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "password", limit, f)
}
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
countryClause := "country != ''"
if where == "" {
where = " WHERE " + countryClause
} else {
where += " AND " + countryClause
}
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top countries: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
switch column {
case "username", "password":
// valid columns
default:
return nil, fmt.Errorf("invalid column: %s", column)
}
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
query := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts%s
GROUP BY %s
ORDER BY total DESC
LIMIT ?`, column, where, column)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) Close() error {
return s.db.Close()
}

View File

@@ -29,6 +29,7 @@ type Session struct {
HumanScore *float64
ExecCommand *string
EventCount int
InputBytes int64
}
// SessionLog represents a single log entry for a session.
@@ -56,6 +57,35 @@ type DashboardStats struct {
ActiveSessions int64
}
// TimeSeriesPoint represents a single data point in a time series.
type TimeSeriesPoint struct {
Timestamp time.Time
Count int64
}
// HourlyCount represents the total attempts for a given hour of day.
type HourlyCount struct {
Hour int // 0-23
Count int64
}
// CountryCount represents the total attempts from a given country.
type CountryCount struct {
Country string
Count int64
}
// DashboardFilter contains optional filters for dashboard queries.
type DashboardFilter struct {
Since *time.Time
Until *time.Time
IP string
Country string
Username string
HumanScoreAboveZero bool
SortBy string
}
// TopEntry represents a value and its count for top-N queries.
type TopEntry struct {
Value string
@@ -110,6 +140,10 @@ type Store interface {
// If activeOnly is true, only sessions with no disconnected_at are returned.
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
// GetFilteredSessions returns sessions matching the given filter, ordered
// by the filter's SortBy field (default: connected_at DESC).
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
// GetSession returns a single session by ID.
GetSession(ctx context.Context, sessionID string) (*Session, error)
@@ -127,6 +161,30 @@ type Store interface {
// sessions left over from a previous unclean shutdown.
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
// GetAttemptsOverTime returns daily attempt counts for the last N days.
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
// GetCountryStats returns total attempts per country, ordered by count DESC.
GetCountryStats(ctx context.Context) ([]CountryCount, error)
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
// GetFilteredTopUsernames returns top usernames with optional filters applied.
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopPasswords returns top passwords with optional filters applied.
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopIPs returns top IPs with optional filters applied.
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopCountries returns top countries with optional filters applied.
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// Close releases any resources held by the store.
Close() error
}

View File

@@ -424,6 +424,226 @@ func TestSetExecCommand(t *testing.T) {
})
}
func seedChartData(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
// Record attempts with country data from different IPs.
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 2 {
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
}
func TestGetAttemptsOverTime(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
if len(points) != 0 {
t.Errorf("expected empty, got %v", points)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
// All data was inserted today, so should be one point.
if len(points) != 1 {
t.Fatalf("len = %d, want 1", len(points))
}
// 5 + 3 + 2 = 10 total.
if points[0].Count != 10 {
t.Errorf("count = %d, want 10", points[0].Count)
}
})
})
}
func TestGetHourlyPattern(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
// All data was inserted at the same hour.
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Count != 10 {
t.Errorf("count = %d, want 10", counts[0].Count)
}
})
})
}
func TestGetCountryStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 2 {
t.Fatalf("len = %d, want 2", len(counts))
}
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
if counts[0].Country != "CN" || counts[0].Count != 7 {
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
}
if counts[1].Country != "RU" || counts[1].Count != 3 {
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
}
})
t.Run("excludes empty country", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
t.Fatalf("seeding: %v", err)
}
counts, err := store.GetCountryStats(ctx)
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Country != "US" {
t.Errorf("country = %q, want US", counts[0].Country)
}
})
})
}
func TestGetFilteredDashboardStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("no filter", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 10 {
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
}
})
t.Run("filter by country", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
// CN: 5 + 2 = 7
if stats.TotalAttempts != 7 {
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
}
})
t.Run("filter by IP", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 5 {
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
}
})
t.Run("filter by username", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 3 {
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
}
})
})
}
func TestGetFilteredTopUsernames(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
store := newStore(t)
seedChartData(t, store)
// Filter by country CN should only show root.
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredTopUsernames: %v", err)
}
if len(entries) != 1 {
t.Fatalf("len = %d, want 1", len(entries))
}
if entries[0].Value != "root" || entries[0].Count != 7 {
t.Errorf("entries[0] = %+v, want root/7", entries[0])
}
})
}
func TestGetRecentSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
@@ -480,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) {
})
})
}
func TestInputBytes(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("counts only input direction", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
events := []SessionEvent{
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
if sessions[0].InputBytes != 7 {
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
}
})
t.Run("zero when no events", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].InputBytes != 0 {
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
}
})
})
}
func TestGetFilteredSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("filter by human score", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Create two sessions, one with human score > 0.
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
t.Run("sort by input bytes", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Session with more input (created first).
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Session with less input (created after id1, so would be first by connected_at).
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
time.Sleep(1100 * time.Millisecond)
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Default sort (connected_at DESC) should show id2 first.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id2 {
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
}
// Sort by input_bytes should show id1 first (more input).
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
}
})
t.Run("combined filters", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Different country, also has score.
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Same country CN but no score.
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
// Filter: CN + human score > 0 -> only id1.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
Country: "CN",
HumanScoreAboveZero: true,
})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
})
}

View File

@@ -1,13 +1,23 @@
package web
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// dbContext returns a context detached from the HTTP request lifecycle with a
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
// queries when the browser aborts the previous XHR.
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
}
type dashboardData struct {
Stats *storage.DashboardStats
TopUsernames []storage.TopEntry
@@ -20,7 +30,8 @@ type dashboardData struct {
}
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
@@ -96,7 +107,10 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
stats, err := s.store.GetDashboardStats(r.Context())
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
s.logger.Error("failed to get dashboard stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -110,7 +124,10 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
ctx, cancel := dbContext(r)
defer cancel()
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
if err != nil {
s.logger.Error("failed to get active sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -123,6 +140,24 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
}
}
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
if err != nil {
s.logger.Error("failed to get filtered sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
s.logger.Error("failed to render recent sessions fragment", "err", err)
}
}
type sessionDetailData struct {
Session *storage.Session
Logs []storage.SessionLog
@@ -130,7 +165,8 @@ type sessionDetailData struct {
}
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
session, err := s.store.GetSession(ctx, sessionID)
@@ -180,8 +216,201 @@ type apiEventsResponse struct {
Events []apiEvent `json:"events"`
}
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
func parseDateParam(r *http.Request, name string) *time.Time {
v := r.URL.Query().Get(name)
if v == "" {
return nil
}
t, err := time.Parse("2006-01-02", v)
if err != nil {
return nil
}
// For "until" dates, set to end of day.
if name == "until" {
t = t.Add(24*time.Hour - time.Second)
}
return &t
}
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
return storage.DashboardFilter{
Since: parseDateParam(r, "since"),
Until: parseDateParam(r, "until"),
IP: r.URL.Query().Get("ip"),
Country: r.URL.Query().Get("country"),
Username: r.URL.Query().Get("username"),
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
SortBy: r.URL.Query().Get("sort"),
}
}
type apiTimeSeriesPoint struct {
Date string `json:"date"`
Count int64 `json:"count"`
}
type apiAttemptsOverTimeResponse struct {
Points []apiTimeSeriesPoint `json:"points"`
}
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
days := 30
if v := r.URL.Query().Get("days"); v != "" {
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
days = d
}
}
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
ctx, cancel := dbContext(r)
defer cancel()
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
if err != nil {
s.logger.Error("failed to get attempts over time", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
for i, p := range points {
resp.Points[i] = apiTimeSeriesPoint{
Date: p.Timestamp.Format("2006-01-02"),
Count: p.Count,
}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode attempts over time", "err", err)
}
}
type apiHourlyCount struct {
Hour int `json:"hour"`
Count int64 `json:"count"`
}
type apiHourlyPatternResponse struct {
Hours []apiHourlyCount `json:"hours"`
}
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
counts, err := s.store.GetHourlyPattern(ctx, since, until)
if err != nil {
s.logger.Error("failed to get hourly pattern", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
for i, c := range counts {
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode hourly pattern", "err", err)
}
}
type apiCountryCount struct {
Country string `json:"country"`
Count int64 `json:"count"`
}
type apiCountryStatsResponse struct {
Countries []apiCountryCount `json:"countries"`
}
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
counts, err := s.store.GetCountryStats(ctx)
if err != nil {
s.logger.Error("failed to get country stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
for i, c := range counts {
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode country stats", "err", err)
}
}
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
if err != nil {
s.logger.Error("failed to get filtered stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top usernames", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top passwords", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top IPs", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top countries", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
data := dashboardData{
Stats: stats,
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
TopCountries: topCountries,
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
s.logger.Error("failed to render dashboard content fragment", "err", err)
}
}
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
events, err := s.store.GetSessionEvents(ctx, sessionID)

14
internal/web/static/chart.min.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,275 @@
(function() {
'use strict';
// Chart.js theme for Pico dark mode
Chart.defaults.color = '#b0b0b8';
Chart.defaults.borderColor = '#3a3a4a';
var attemptsChart = null;
var hourlyChart = null;
function getFilterParams() {
var form = document.getElementById('filter-form');
if (!form) return '';
var params = new URLSearchParams();
var since = form.elements['since'].value;
var until = form.elements['until'].value;
if (since) params.set('since', since);
if (until) params.set('until', until);
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
return params.toString();
}
function initAttemptsChart() {
var canvas = document.getElementById('chart-attempts');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
var labels = data.points.map(function(p) { return p.date; });
var values = data.points.map(function(p) { return p.count; });
if (attemptsChart) {
attemptsChart.data.labels = labels;
attemptsChart.data.datasets[0].data = values;
attemptsChart.update();
return;
}
attemptsChart = new Chart(ctx, {
type: 'line',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
borderColor: '#6366f1',
backgroundColor: 'rgba(99, 102, 241, 0.1)',
fill: true,
tension: 0.3,
pointRadius: 2
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initHourlyChart() {
var canvas = document.getElementById('chart-hourly');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
// Fill all 24 hours, defaulting to 0
var hourMap = {};
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
var labels = [];
var values = [];
for (var i = 0; i < 24; i++) {
labels.push(i + ':00');
values.push(hourMap[i] || 0);
}
if (hourlyChart) {
hourlyChart.data.labels = labels;
hourlyChart.data.datasets[0].data = values;
hourlyChart.update();
return;
}
hourlyChart = new Chart(ctx, {
type: 'bar',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
backgroundColor: 'rgba(99, 102, 241, 0.6)',
borderColor: '#6366f1',
borderWidth: 1
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initWorldMap() {
var container = document.getElementById('world-map');
if (!container) return;
fetch('/static/world.svg')
.then(function(r) { return r.text(); })
.then(function(svgText) {
container.innerHTML = svgText;
fetch('/api/charts/country-stats')
.then(function(r) { return r.json(); })
.then(function(data) {
colorMap(container, data.countries);
});
});
}
function colorMap(container, countries) {
if (!countries || countries.length === 0) return;
var maxCount = countries[0].count; // already sorted DESC
var logMax = Math.log(maxCount + 1);
// Build lookup
var lookup = {};
countries.forEach(function(c) {
lookup[c.country.toLowerCase()] = c.count;
});
// Create tooltip element
var tooltip = document.createElement('div');
tooltip.id = 'map-tooltip';
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
document.body.appendChild(tooltip);
var svg = container.querySelector('svg');
if (!svg) return;
// Remove SVG title to prevent browser native tooltip
var svgTitle = svg.querySelector('title');
if (svgTitle) svgTitle.remove();
// Select both <path id="xx"> and <g id="xx"> country elements
var elements = svg.querySelectorAll('path[id], g[id]');
elements.forEach(function(el) {
var id = el.id.toLowerCase();
if (id.charAt(0) === '_') return; // skip non-country paths
var count = lookup[id];
if (count) {
var intensity = Math.log(count + 1) / logMax;
var r = Math.round(30 + intensity * 69); // 30 -> 99
var g = Math.round(30 + intensity * 72); // 30 -> 102
var b = Math.round(62 + intensity * 179); // 62 -> 241
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
// For <g> elements, color child paths; for <path>, color directly
if (el.tagName.toLowerCase() === 'g') {
el.querySelectorAll('path').forEach(function(p) {
p.style.fill = color;
});
} else {
el.style.fill = color;
}
}
el.addEventListener('mouseenter', function(e) {
var cc = id.toUpperCase();
var n = lookup[id] || 0;
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
tooltip.style.display = 'block';
});
el.addEventListener('mousemove', function(e) {
tooltip.style.left = (e.clientX + 12) + 'px';
tooltip.style.top = (e.clientY - 10) + 'px';
});
el.addEventListener('mouseleave', function() {
tooltip.style.display = 'none';
});
el.addEventListener('click', function() {
var input = document.querySelector('#filter-form input[name="country"]');
if (input) {
input.value = id.toUpperCase();
applyFilters();
}
});
el.style.cursor = 'pointer';
});
}
function applyFilters() {
// Re-fetch charts with filter params
initAttemptsChart();
initHourlyChart();
// Re-fetch dashboard content via htmx
var form = document.getElementById('filter-form');
if (!form) return;
var params = new URLSearchParams();
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
var val = form.elements[name].value;
if (val) params.set(name, val);
});
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
var target = document.getElementById('dashboard-content');
if (target) {
var url = '/fragments/dashboard-content?' + params.toString();
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
}
// Server-side filter for recent sessions table
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
}
window.clearFilters = function() {
var form = document.getElementById('filter-form');
if (form) {
form.reset();
applyFilters();
}
};
window.applyFilters = applyFilters;
// Initialize on DOM ready
document.addEventListener('DOMContentLoaded', function() {
initAttemptsChart();
initHourlyChart();
initWorldMap();
var form = document.getElementById('filter-form');
if (form) {
form.addEventListener('submit', function(e) {
e.preventDefault();
applyFilters();
});
}
});
})();

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 55 KiB

View File

@@ -56,6 +56,20 @@ func templateFuncMap() template.FuncMap {
}
return s
},
"formatBytes": func(b int64) string {
const (
kb = 1024
mb = 1024 * kb
)
switch {
case b >= mb:
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
case b >= kb:
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
default:
return fmt.Sprintf("%d B", b)
}
},
}
}
@@ -67,6 +81,7 @@ func loadTemplates() (*templateSet, error) {
"templates/dashboard.html",
"templates/fragments/stats.html",
"templates/fragments/active_sessions.html",
"templates/fragments/recent_sessions.html",
)
if err != nil {
return nil, fmt.Errorf("parsing dashboard templates: %w", err)

View File

@@ -3,6 +3,86 @@
{{template "stats" .Stats}}
</section>
<details>
<summary>Filters</summary>
<form id="filter-form">
<div class="grid">
<label>Since <input type="date" name="since"></label>
<label>Until <input type="date" name="until"></label>
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
<label>Username <input type="text" name="username" placeholder="root"></label>
</div>
<div class="grid">
<label><input type="checkbox" name="human_score" value="1"> Human score &gt; 0</label>
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
</div>
<button type="submit">Apply</button>
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
</form>
</details>
<section>
<h3>Attack Trends</h3>
<div class="grid">
<article>
<header>Attempts Over Time</header>
<canvas id="chart-attempts"></canvas>
</article>
<article>
<header>Hourly Pattern (UTC)</header>
<canvas id="chart-hourly"></canvas>
</article>
</div>
</section>
<section>
<h3>Attack Origins</h3>
<article>
<div id="world-map"></div>
</article>
</section>
<div id="dashboard-content">
{{template "dashboard_content" .}}
</div>
<section>
<h3>Active Sessions</h3>
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
{{template "active_sessions" .ActiveSessions}}
</div>
</section>
<section>
<h3>Recent Sessions</h3>
<table id="recent-sessions-table">
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
<th>Disconnected</th>
</tr>
</thead>
<tbody>
{{template "recent_sessions" .RecentSessions}}
</tbody>
</table>
</section>
{{end}}
{{define "scripts"}}
<script src="/static/chart.min.js"></script>
<script src="/static/dashboard.js"></script>
{{end}}
{{define "dashboard_content"}}
<section>
<h3>Top Credentials & IPs</h3>
<div class="top-grid">
@@ -83,45 +163,4 @@
</article>
</div>
</section>
<section>
<h3>Active Sessions</h3>
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
{{template "active_sessions" .ActiveSessions}}
</div>
</section>
<section>
<h3>Recent Sessions</h3>
<table>
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Type</th>
<th>Score</th>
<th>Connected</th>
<th>Disconnected</th>
</tr>
</thead>
<tbody>
{{range .RecentSessions}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatTime .ConnectedAt}}</td>
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
</tr>
{{else}}
<tr><td colspan="8">No sessions</td></tr>
{{end}}
</tbody>
</table>
</section>
{{end}}

View File

@@ -8,6 +8,7 @@
<th>Username</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
</tr>
</thead>
@@ -20,10 +21,11 @@
<td>{{.Username}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
</tr>
{{else}}
<tr><td colspan="7">No active sessions</td></tr>
<tr><td colspan="8">No active sessions</td></tr>
{{end}}
</tbody>
</table>

View File

@@ -0,0 +1,17 @@
{{define "recent_sessions"}}
{{range .}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
</tr>
{{else}}
<tr><td colspan="9">No sessions</td></tr>
{{end}}
{{end}}

View File

@@ -29,9 +29,16 @@
}
.top-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
gap: 1rem;
}
.top-grid article {
overflow: hidden;
min-width: 0;
}
#world-map svg { width: 100%; height: auto; }
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
nav h1 {
margin: 0;
}
@@ -52,5 +59,6 @@
<main class="container">
{{block "content" .}}{{end}}
</main>
{{block "scripts" .}}{{end}}
</body>
</html>

View File

@@ -7,7 +7,7 @@ import (
"net/http"
"strings"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
//go:embed static/*
@@ -40,9 +40,14 @@ func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Han
s.mux.Handle("GET /static/", http.FileServerFS(staticFS))
s.mux.HandleFunc("GET /sessions/{id}", s.handleSessionDetail)
s.mux.HandleFunc("GET /api/sessions/{id}/events", s.handleAPISessionEvents)
s.mux.HandleFunc("GET /api/charts/attempts-over-time", s.handleAPIAttemptsOverTime)
s.mux.HandleFunc("GET /api/charts/hourly-pattern", s.handleAPIHourlyPattern)
s.mux.HandleFunc("GET /api/charts/country-stats", s.handleAPICountryStats)
s.mux.HandleFunc("GET /", s.handleDashboard)
s.mux.HandleFunc("GET /fragments/stats", s.handleFragmentStats)
s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions)
s.mux.HandleFunc("GET /fragments/dashboard-content", s.handleFragmentDashboardContent)
s.mux.HandleFunc("GET /fragments/recent-sessions", s.handleFragmentRecentSessions)
if metricsHandler != nil {
h := metricsHandler

View File

@@ -10,8 +10,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func newTestServer(t *testing.T) *Server {
@@ -54,6 +54,30 @@ func newSeededTestServer(t *testing.T) *Server {
return srv
}
func TestDbContextNotCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(ctx)
dbCtx, dbCancel := dbContext(req)
defer dbCancel()
// Cancel the original request context.
cancel()
// The DB context should still be usable.
select {
case <-dbCtx.Done():
t.Fatal("dbContext should not be canceled when request context is canceled")
default:
}
// Verify the DB context has a deadline (from the timeout).
if _, ok := dbCtx.Deadline(); !ok {
t.Error("dbContext should have a deadline")
}
}
func TestDashboardHandler(t *testing.T) {
t.Run("empty store", func(t *testing.T) {
srv := newTestServer(t)
@@ -395,6 +419,135 @@ func TestDashboardExecCommands(t *testing.T) {
}
}
func TestAPIAttemptsOverTime(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/api/charts/attempts-over-time", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
ct := w.Header().Get("Content-Type")
if !strings.Contains(ct, "application/json") {
t.Errorf("Content-Type = %q, want application/json", ct)
}
var resp apiAttemptsOverTimeResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
// Seeded data inserted today -> at least 1 point.
if len(resp.Points) == 0 {
t.Error("expected at least one data point")
}
}
func TestAPIHourlyPattern(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/api/charts/hourly-pattern", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
var resp apiHourlyPatternResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(resp.Hours) == 0 {
t.Error("expected at least one hourly data point")
}
}
func TestAPICountryStats(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding: %v", err)
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/charts/country-stats", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
var resp apiCountryStatsResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(resp.Countries) != 2 {
t.Fatalf("len = %d, want 2", len(resp.Countries))
}
}
func TestFragmentDashboardContent(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
if strings.Contains(body, "<!DOCTYPE html>") {
t.Error("dashboard content fragment should not contain full HTML document")
}
if !strings.Contains(body, "Top Usernames") {
t.Error("dashboard content fragment should contain 'Top Usernames'")
}
}
func TestFragmentDashboardContentWithFilter(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding: %v", err)
}
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content?country=CN", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
// When filtered by CN, should show root but not admin.
if !strings.Contains(body, "root") {
t.Error("response should contain 'root' when filtered by CN")
}
}
func TestStaticAssets(t *testing.T) {
srv := newTestServer(t)
@@ -404,6 +557,9 @@ func TestStaticAssets(t *testing.T) {
}{
{"/static/pico.min.css", "text/css"},
{"/static/htmx.min.js", "text/javascript"},
{"/static/chart.min.js", "text/javascript"},
{"/static/dashboard.js", "text/javascript"},
{"/static/world.svg", "image/svg+xml"},
}
for _, tt := range tests {

View File

@@ -34,6 +34,16 @@ password = "admin"
# password = "cisco"
# shell = "cisco"
# [[auth.static_credentials]]
# username = "irobot"
# password = "roomba"
# shell = "roomba"
# [[auth.static_credentials]]
# username = "player"
# password = "tetris"
# shell = "tetris"
[storage]
db_path = "oubliette.db"
retention_days = 90
@@ -50,6 +60,12 @@ hostname = "ubuntu-server"
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
# fake_user = "" # override username in prompt; empty = use authenticated user
# Map usernames to specific shells (regardless of how auth succeeded).
# Credential-specific shell overrides take priority over username routes.
# [shell.username_routes]
# postgres = "psql"
# admin = "bash"
# Per-shell configuration (optional).
# [shell.banking]
# bank_name = "SECUREBANK"
@@ -65,6 +81,16 @@ hostname = "ubuntu-server"
# ios_version = "15.0(2)SE11"
# enable_password = "" # empty = accept after 1 failed attempt
# [shell.psql]
# db_name = "postgres"
# pg_version = "15.4"
# [shell.roomba]
# No configuration options currently.
# [shell.tetris]
# difficulty = "normal" # "easy" (slower start), "normal" (standard), "hard" (start at level 5)
# [detection]
# enabled = true
# threshold = 0.6 # 0.01.0, sessions above this trigger notifications