Compare commits
22 Commits
090dbec390
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
1b28f10ca8
|
|||
|
664e79fce6
|
|||
|
c74313c195
|
|||
|
9783ae5865
|
|||
|
62de222488
|
|||
| c9d143d84b | |||
|
d18a904ed5
|
|||
|
cb7be28f42
|
|||
|
0908b43724
|
|||
|
52310f588d
|
|||
|
b52216bd2f
|
|||
|
2bc83a17dd
|
|||
|
faf6e2abd7
|
|||
|
0a4eac188a
|
|||
|
7c90c9ed4a
|
|||
|
8a631af0d2
|
|||
|
40fda3420c
|
|||
|
c4801e3309
|
|||
|
4f10a8a422
|
|||
|
0b44d1c83f
|
|||
|
0133d956a5
|
|||
|
3c20e854aa
|
60
PLAN.md
60
PLAN.md
@@ -171,7 +171,20 @@ Goal: Add the entertaining shell implementations.
|
||||
### 3.5 Banking TUI Shell ✅
|
||||
- 80s-style green-on-black bank terminal
|
||||
|
||||
### 3.6 Other Shell Ideas (Future)
|
||||
### 3.6 PostgreSQL psql Shell ✅
|
||||
- Simulates psql interactive terminal with `db_name` and `pg_version` config
|
||||
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
|
||||
- SQL statement handling with multi-line buffering (semicolon-terminated)
|
||||
- Canned responses for common queries (SELECT version(), current_database(), etc.)
|
||||
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
|
||||
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
|
||||
|
||||
### 3.7 Roomba Shell ✅
|
||||
- iRobot Roomba j7+ vacuum robot interface
|
||||
- Status, cleaning, scheduling, diagnostics, floor map
|
||||
- Humorous history entries (cat encounters, sock tangles, sticky substances)
|
||||
|
||||
### 3.8 Other Shell Ideas (Future)
|
||||
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
|
||||
- **ELIZA therapist:** every response is a therapy question
|
||||
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
|
||||
@@ -183,11 +196,11 @@ Goal: Add the entertaining shell implementations.
|
||||
|
||||
Goal: Make the web UI great and add operational niceties.
|
||||
|
||||
### 4.1 Enhanced Web UI
|
||||
- GeoIP lookups and world map visualization of attack sources
|
||||
- Charts: attempts over time, hourly patterns, credential trends
|
||||
- Session detail view with full command log
|
||||
- Filtering and search
|
||||
### 4.1 Enhanced Web UI ✅
|
||||
- GeoIP lookups and world map visualization of attack sources ✅
|
||||
- Charts: attempts over time, hourly patterns, credential trends ✅
|
||||
- Session detail view with full command log ✅
|
||||
- Filtering and search ✅
|
||||
|
||||
### 4.2 Operational ✅
|
||||
- Prometheus metrics endpoint ✅
|
||||
@@ -200,3 +213,38 @@ Goal: Make the web UI great and add operational niceties.
|
||||
- Embed a lightweight GeoIP database or use an API ✅
|
||||
- Store country/city with each attempt ✅
|
||||
- Aggregate stats by country ✅
|
||||
|
||||
### 4.4 Capture SSH Exec Commands ✅
|
||||
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
|
||||
|
||||
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
|
||||
- Parse the command string from the exec payload ✅
|
||||
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
|
||||
- Store the command on the session record before closing the channel ✅
|
||||
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
|
||||
- Surface exec commands in the web UI (session detail view) ✅
|
||||
|
||||
#### 4.4.1 Fake Exec Output
|
||||
Return plausible fake output for exec commands to encourage bots to interact further.
|
||||
|
||||
**Approach: regex-based output assembly.** Bots typically send a single long command that chains recon commands and then echoes a summary (e.g. `echo "UNAME:$uname"`). Rather than interpreting arbitrary shell pipelines, we scan the command string for known patterns and assemble fake output.
|
||||
|
||||
Implementation:
|
||||
- A map of common command/variable patterns to fake output strings, e.g.:
|
||||
- `uname -a` / `uname -s -v -n -m` → `"Linux ubuntu-server 5.15.0-91-generic #101-Ubuntu SMP Tue Jan 2 15:13:10 UTC 2024 x86_64"`
|
||||
- `uname -m` / `arch` → `"x86_64"`
|
||||
- `cat /proc/uptime` → `"86432.71 172801.55"`
|
||||
- `nproc` / `grep -c "^processor" /proc/cpuinfo` → `"2"`
|
||||
- `cat /proc/cpuinfo` → fake cpuinfo block
|
||||
- `lspci` → empty (no GPU — discourages cryptominer targeting)
|
||||
- `id` → `"uid=0(root) gid=0(root) groups=0(root)"`
|
||||
- `cat /etc/passwd` → minimal fake passwd file
|
||||
- `last` → fake login entries
|
||||
- `cat --help`, `ls --help` → canned GNU coreutils help text
|
||||
- Scan the exec command for `echo "KEY:$var"` patterns; for each key, look up the corresponding fake value from the variable assignment earlier in the command
|
||||
- If we recognise echo patterns, assemble and return the expected output
|
||||
- If we don't recognise the command at all, return empty output with exit 0 (current behaviour)
|
||||
- Values should draw from the existing shell config where possible (hostname, fake_user) for consistency
|
||||
- New package `internal/execfake` or a file in `internal/server/` — keep it simple
|
||||
|
||||
Gather more real-world bot examples before implementing to ensure good coverage of common recon patterns.
|
||||
|
||||
@@ -34,7 +34,8 @@ Key settings:
|
||||
- `auth.accept_after` — accept login after N failures per IP (default `10`)
|
||||
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
|
||||
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
|
||||
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation)
|
||||
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
|
||||
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
|
||||
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
||||
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
||||
- `storage.retention_interval` — how often to run retention (default `1h`)
|
||||
@@ -43,6 +44,7 @@ Key settings:
|
||||
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
|
||||
- `web.enabled` — enable the web dashboard (default `false`)
|
||||
- `web.listen_addr` — web dashboard listen address (default `:8080`)
|
||||
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
|
||||
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
|
||||
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
|
||||
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)
|
||||
@@ -69,6 +71,9 @@ Test with:
|
||||
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
|
||||
```
|
||||
|
||||
SSH exec commands (`ssh user@host <command>`) are also captured and stored on the session record.
|
||||
|
||||
|
||||
### NixOS Module
|
||||
|
||||
Add the flake as an input and enable the service:
|
||||
|
||||
@@ -13,14 +13,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/server"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"git.t-juice.club/torjus/oubliette/internal/web"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/server"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/web"
|
||||
)
|
||||
|
||||
const Version = "0.10.0"
|
||||
const Version = "0.18.0"
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
@@ -76,12 +76,13 @@ func run() error {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
m := metrics.New(Version)
|
||||
m.RegisterStoreCollector(store)
|
||||
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
|
||||
m.RegisterStoreCollector(instrumentedStore)
|
||||
|
||||
srv, err := server.New(*cfg, store, logger, m)
|
||||
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
srv, err := server.New(*cfg, instrumentedStore, logger, m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create server: %w", err)
|
||||
}
|
||||
@@ -95,7 +96,7 @@ func run() error {
|
||||
metricsHandler = m.Handler()
|
||||
}
|
||||
|
||||
webHandler, err := web.NewServer(store, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create web server: %w", err)
|
||||
}
|
||||
|
||||
4
go.mod
4
go.mod
@@ -1,4 +1,4 @@
|
||||
module git.t-juice.club/torjus/oubliette
|
||||
module code.t-juice.club/torjus/oubliette
|
||||
|
||||
go 1.25.5
|
||||
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oschwald/maxminddb-golang v1.13.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.45.0
|
||||
)
|
||||
@@ -33,7 +34,6 @@ require (
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
|
||||
|
||||
@@ -28,10 +28,11 @@ type WebConfig struct {
|
||||
}
|
||||
|
||||
type ShellConfig struct {
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
UsernameRoutes map[string]string `toml:"username_routes"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
@@ -165,9 +166,10 @@ func applyDefaults(cfg *Config) {
|
||||
|
||||
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
|
||||
var knownShellKeys = map[string]bool{
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"username_routes": true,
|
||||
}
|
||||
|
||||
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
|
||||
|
||||
@@ -313,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadUsernameRoutes(t *testing.T) {
|
||||
content := `
|
||||
[shell]
|
||||
hostname = "myhost"
|
||||
|
||||
[shell.username_routes]
|
||||
postgres = "psql"
|
||||
admin = "bash"
|
||||
|
||||
[shell.bash]
|
||||
custom_key = "value"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes == nil {
|
||||
t.Fatal("UsernameRoutes should not be nil")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
|
||||
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
|
||||
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
|
||||
}
|
||||
// username_routes should NOT appear in the Shells map.
|
||||
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
|
||||
t.Error("username_routes should not appear in Shells map")
|
||||
}
|
||||
// bash should still appear in Shells map.
|
||||
if _, ok := cfg.Shell.Shells["bash"]; !ok {
|
||||
t.Error("Shells[\"bash\"] should still be present")
|
||||
}
|
||||
}
|
||||
|
||||
func writeTemp(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), "config.toml")
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -23,7 +23,10 @@ type Metrics struct {
|
||||
SessionsTotal *prometheus.CounterVec
|
||||
SessionsActive prometheus.Gauge
|
||||
SessionDuration prometheus.Histogram
|
||||
ExecCommandsTotal prometheus.Counter
|
||||
BuildInfo *prometheus.GaugeVec
|
||||
StorageQueryDuration *prometheus.HistogramVec
|
||||
StorageQueryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// New creates a new Metrics instance with all collectors registered.
|
||||
@@ -70,10 +73,23 @@ func New(version string) *Metrics {
|
||||
Help: "Session duration in seconds.",
|
||||
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600},
|
||||
}),
|
||||
ExecCommandsTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "oubliette_exec_commands_total",
|
||||
Help: "Total SSH exec commands received.",
|
||||
}),
|
||||
BuildInfo: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "oubliette_build_info",
|
||||
Help: "Build information. Always 1.",
|
||||
}, []string{"version"}),
|
||||
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "oubliette_storage_query_duration_seconds",
|
||||
Help: "Duration of storage query calls in seconds.",
|
||||
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
}, []string{"method"}),
|
||||
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_storage_query_errors_total",
|
||||
Help: "Total storage query errors.",
|
||||
}, []string{"method"}),
|
||||
}
|
||||
|
||||
reg.MustRegister(
|
||||
@@ -88,7 +104,10 @@ func New(version string) *Metrics {
|
||||
m.SessionsTotal,
|
||||
m.SessionsActive,
|
||||
m.SessionDuration,
|
||||
m.ExecCommandsTotal,
|
||||
m.BuildInfo,
|
||||
m.StorageQueryDuration,
|
||||
m.StorageQueryErrors,
|
||||
)
|
||||
|
||||
m.BuildInfo.WithLabelValues(version).Set(1)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
// Event types.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func testSession() SessionInfo {
|
||||
|
||||
@@ -12,19 +12,22 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/detection"
|
||||
"git.t-juice.club/torjus/oubliette/internal/geoip"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/notify"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/banking"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/detection"
|
||||
"code.t-juice.club/torjus/oubliette/internal/geoip"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/notify"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@@ -58,6 +61,15 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics
|
||||
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering cisco shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering psql shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering roomba shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering tetris shell: %w", err)
|
||||
}
|
||||
|
||||
geo, err := geoip.New()
|
||||
if err != nil {
|
||||
@@ -185,6 +197,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
|
||||
}
|
||||
}
|
||||
// Second priority: username-based route.
|
||||
if selectedShell == nil {
|
||||
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
|
||||
sh, found := s.shellRegistry.Get(shellName)
|
||||
if found {
|
||||
selectedShell = sh
|
||||
} else {
|
||||
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
|
||||
}
|
||||
}
|
||||
}
|
||||
// Lowest priority: random selection.
|
||||
if selectedShell == nil {
|
||||
var err error
|
||||
selectedShell, err = s.shellRegistry.Select()
|
||||
@@ -231,14 +255,24 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
|
||||
defer s.notifier.CleanupSession(sessionID)
|
||||
|
||||
// Handle session requests (pty-req, shell, etc.)
|
||||
// Handle session requests (pty-req, shell, exec, etc.)
|
||||
execCh := make(chan string, 1)
|
||||
go func() {
|
||||
defer close(execCh)
|
||||
for req := range requests {
|
||||
switch req.Type {
|
||||
case "pty-req", "shell":
|
||||
if req.WantReply {
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
case "exec":
|
||||
if req.WantReply {
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
var payload struct{ Command string }
|
||||
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
|
||||
execCh <- payload.Command
|
||||
}
|
||||
default:
|
||||
if req.WantReply {
|
||||
req.Reply(false, nil)
|
||||
@@ -247,6 +281,29 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
}
|
||||
}()
|
||||
|
||||
// Check for exec request before proceeding to interactive shell.
|
||||
select {
|
||||
case cmd, ok := <-execCh:
|
||||
if ok && cmd != "" {
|
||||
s.logger.Info("exec command received",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"session_id", sessionID,
|
||||
"command", cmd,
|
||||
)
|
||||
if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
|
||||
s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
|
||||
}
|
||||
s.metrics.ExecCommandsTotal.Inc()
|
||||
// Send exit-status 0 and close channel.
|
||||
exitPayload := make([]byte, 4) // uint32(0)
|
||||
_, _ = channel.SendRequest("exit-status", false, exitPayload)
|
||||
return
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
// No exec request within timeout — proceed with interactive shell.
|
||||
}
|
||||
|
||||
// Build session context.
|
||||
var shellCfg map[string]any
|
||||
if s.cfg.Shell.Shells != nil {
|
||||
|
||||
@@ -11,9 +11,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@@ -252,6 +253,137 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Test exec command capture.
|
||||
t.Run("exec_command", func(t *testing.T) {
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "root",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// Run a command via exec (no PTY, no shell).
|
||||
if err := session.Run("uname -a"); err != nil {
|
||||
// Run returns an error because the server closes the channel,
|
||||
// but that's expected.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Give the server a moment to store the command.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify the exec command was captured.
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundExec bool
|
||||
for _, s := range sessions {
|
||||
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
|
||||
foundExec = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundExec {
|
||||
t.Error("expected a session with exec_command='uname -a'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test username route: add username_routes so that "postgres" gets psql shell.
|
||||
t.Run("username_route", func(t *testing.T) {
|
||||
// Reconfigure with username routes.
|
||||
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
|
||||
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
|
||||
|
||||
// Need to get the "postgres" user in via static creds or threshold.
|
||||
// Use static creds for simplicity.
|
||||
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
|
||||
config.Credential{Username: "postgres", Password: "postgres"},
|
||||
)
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
defer func() {
|
||||
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
}()
|
||||
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "postgres",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("request pty: %v", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stdin pipe: %v", err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
session.Stdout = &output
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("shell: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the psql banner.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Send \q to quit.
|
||||
stdin.Write([]byte(`\q` + "\r"))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
session.Wait()
|
||||
|
||||
out := output.String()
|
||||
if !strings.Contains(out, "psql") {
|
||||
t.Errorf("output should contain psql banner, got: %s", out)
|
||||
}
|
||||
|
||||
// Verify session was created with shell name "psql".
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundPsql bool
|
||||
for _, s := range sessions {
|
||||
if s.ShellName == "psql" && s.Username == "postgres" {
|
||||
foundPsql = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundPsql {
|
||||
t.Error("expected a session with shell_name='psql' for user 'postgres'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test threshold acceptance: after enough failed dials, a subsequent
|
||||
// dial with the same credentials should succeed via threshold or
|
||||
// remembered credential.
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// EventRecorder buffers I/O events in memory and periodically flushes them to
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestEventRecorderFlush(t *testing.T) {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
|
||||
123
internal/shell/psql/commands.go
Normal file
123
internal/shell/psql/commands.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// commandResult holds the output of a command and whether the session should end.
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
// dispatchBackslash handles psql backslash meta-commands.
|
||||
func dispatchBackslash(cmd, dbName string) commandResult {
|
||||
// Normalize: trim spaces after the backslash command word.
|
||||
parts := strings.Fields(cmd)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{output: "Invalid command \\. Try \\? for help."}
|
||||
}
|
||||
|
||||
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
|
||||
args := parts[1:]
|
||||
|
||||
switch verb {
|
||||
case `\q`:
|
||||
return commandResult{exit: true}
|
||||
case `\dt`:
|
||||
return commandResult{output: listTables()}
|
||||
case `\d`:
|
||||
if len(args) == 0 {
|
||||
return commandResult{output: listTables()}
|
||||
}
|
||||
return commandResult{output: describeTable(args[0])}
|
||||
case `\l`:
|
||||
return commandResult{output: listDatabases()}
|
||||
case `\du`:
|
||||
return commandResult{output: listRoles()}
|
||||
case `\conninfo`:
|
||||
return commandResult{output: connInfo(dbName)}
|
||||
case `\?`:
|
||||
return commandResult{output: backslashHelp()}
|
||||
case `\h`:
|
||||
return commandResult{output: sqlHelp()}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
|
||||
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
|
||||
// Strip trailing semicolon and whitespace for matching.
|
||||
trimmed := strings.TrimRight(sql, "; \t")
|
||||
trimmed = strings.TrimSpace(trimmed)
|
||||
upper := strings.ToUpper(trimmed)
|
||||
|
||||
switch {
|
||||
case upper == "SELECT VERSION()":
|
||||
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
|
||||
return commandResult{output: formatSingleValue("version", ver)}
|
||||
case upper == "SELECT CURRENT_DATABASE()":
|
||||
return commandResult{output: formatSingleValue("current_database", dbName)}
|
||||
case upper == "SELECT CURRENT_USER":
|
||||
return commandResult{output: formatSingleValue("current_user", "postgres")}
|
||||
case upper == "SELECT NOW()":
|
||||
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
|
||||
return commandResult{output: formatSingleValue("now", now)}
|
||||
case upper == "SELECT 1":
|
||||
return commandResult{output: formatSingleValue("?column?", "1")}
|
||||
case strings.HasPrefix(upper, "INSERT"):
|
||||
return commandResult{output: "INSERT 0 1"}
|
||||
case strings.HasPrefix(upper, "UPDATE"):
|
||||
return commandResult{output: "UPDATE 1"}
|
||||
case strings.HasPrefix(upper, "DELETE"):
|
||||
return commandResult{output: "DELETE 1"}
|
||||
case strings.HasPrefix(upper, "CREATE TABLE"):
|
||||
return commandResult{output: "CREATE TABLE"}
|
||||
case strings.HasPrefix(upper, "CREATE DATABASE"):
|
||||
return commandResult{output: "CREATE DATABASE"}
|
||||
case strings.HasPrefix(upper, "DROP TABLE"):
|
||||
return commandResult{output: "DROP TABLE"}
|
||||
case strings.HasPrefix(upper, "ALTER TABLE"):
|
||||
return commandResult{output: "ALTER TABLE"}
|
||||
case upper == "BEGIN":
|
||||
return commandResult{output: "BEGIN"}
|
||||
case upper == "COMMIT":
|
||||
return commandResult{output: "COMMIT"}
|
||||
case upper == "ROLLBACK":
|
||||
return commandResult{output: "ROLLBACK"}
|
||||
case upper == "SHOW SERVER_VERSION":
|
||||
return commandResult{output: formatSingleValue("server_version", pgVersion)}
|
||||
case upper == "SHOW SEARCH_PATH":
|
||||
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
|
||||
case strings.HasPrefix(upper, "SET "):
|
||||
return commandResult{output: "SET"}
|
||||
default:
|
||||
// Extract the first token for the error message.
|
||||
firstToken := strings.Fields(trimmed)
|
||||
token := trimmed
|
||||
if len(firstToken) > 0 {
|
||||
token = firstToken[0]
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
|
||||
}
|
||||
}
|
||||
|
||||
// formatSingleValue formats a single-row, single-column psql result.
|
||||
func formatSingleValue(colName, value string) string {
|
||||
width := max(len(colName), len(value))
|
||||
|
||||
var b strings.Builder
|
||||
// Header
|
||||
fmt.Fprintf(&b, " %-*s \n", width, colName)
|
||||
// Separator
|
||||
b.WriteString(strings.Repeat("-", width+2))
|
||||
b.WriteString("\n")
|
||||
// Value
|
||||
fmt.Fprintf(&b, " %-*s\n", width, value)
|
||||
// Row count
|
||||
b.WriteString("(1 row)")
|
||||
return b.String()
|
||||
}
|
||||
155
internal/shell/psql/output.go
Normal file
155
internal/shell/psql/output.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package psql
|
||||
|
||||
import "fmt"
|
||||
|
||||
func startupBanner(version string) string {
|
||||
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
|
||||
}
|
||||
|
||||
func listTables() string {
|
||||
return ` List of relations
|
||||
Schema | Name | Type | Owner
|
||||
--------+---------------+-------+----------
|
||||
public | audit_log | table | postgres
|
||||
public | credentials | table | postgres
|
||||
public | sessions | table | postgres
|
||||
public | users | table | postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listDatabases() string {
|
||||
return ` List of databases
|
||||
Name | Owner | Encoding | Collate | Ctype | Access privileges
|
||||
-----------+----------+----------+-------------+-------------+-----------------------
|
||||
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listRoles() string {
|
||||
return ` List of roles
|
||||
Role name | Attributes | Member of
|
||||
-----------+------------------------------------------------------------+-----------
|
||||
app_user | | {}
|
||||
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
|
||||
readonly | Cannot login | {}`
|
||||
}
|
||||
|
||||
func describeTable(name string) string {
|
||||
switch name {
|
||||
case "users":
|
||||
return ` Table "public.users"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+-----------------------------------
|
||||
id | integer | | not null | nextval('users_id_seq'::regclass)
|
||||
username | character varying(255) | | not null |
|
||||
email | character varying(255) | | not null |
|
||||
password | character varying(255) | | not null |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
updated_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"users_pkey" PRIMARY KEY, btree (id)
|
||||
"users_email_key" UNIQUE, btree (email)
|
||||
"users_username_key" UNIQUE, btree (username)`
|
||||
case "sessions":
|
||||
return ` Table "public.sessions"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+--------------------------------------
|
||||
id | integer | | not null | nextval('sessions_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
token | character varying(255) | | not null |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
expires_at | timestamp without time zone | | not null |
|
||||
Indexes:
|
||||
"sessions_pkey" PRIMARY KEY, btree (id)
|
||||
"sessions_token_key" UNIQUE, btree (token)
|
||||
Foreign-key constraints:
|
||||
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "credentials":
|
||||
return ` Table "public.credentials"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
-----------+-----------------------------+-----------+----------+-----------------------------------------
|
||||
id | integer | | not null | nextval('credentials_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
type | character varying(50) | | not null |
|
||||
value | text | | not null |
|
||||
created_at| timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"credentials_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "audit_log":
|
||||
return ` Table "public.audit_log"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+---------------------------------------
|
||||
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
action | character varying(100) | | not null |
|
||||
details | text | | |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"audit_log_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
default:
|
||||
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
|
||||
}
|
||||
}
|
||||
|
||||
func connInfo(dbName string) string {
|
||||
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
|
||||
}
|
||||
|
||||
func backslashHelp() string {
|
||||
return `General
|
||||
\copyright show PostgreSQL usage and distribution terms
|
||||
\crosstabview [COLUMNS] execute query and display result in crosstab
|
||||
\errverbose show most recent error message at maximum verbosity
|
||||
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
|
||||
\gdesc describe result of query, without executing it
|
||||
\gexec execute query, then execute each value in its result
|
||||
\gset [PREFIX] execute query and store result in psql variables
|
||||
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
|
||||
\q quit psql
|
||||
\watch [SEC] execute query every SEC seconds
|
||||
|
||||
Informational
|
||||
(options: S = show system objects, + = additional detail)
|
||||
\d[S+] list tables, views, and sequences
|
||||
\d[S+] NAME describe table, view, sequence, or index
|
||||
\da[S] [PATTERN] list aggregates
|
||||
\dA[+] [PATTERN] list access methods
|
||||
\dt[S+] [PATTERN] list tables
|
||||
\du[S+] [PATTERN] list roles
|
||||
\l[+] [PATTERN] list databases`
|
||||
}
|
||||
|
||||
func sqlHelp() string {
|
||||
return `Available help:
|
||||
ABORT CREATE LANGUAGE
|
||||
ALTER AGGREGATE CREATE MATERIALIZED VIEW
|
||||
ALTER COLLATION CREATE OPERATOR
|
||||
ALTER CONVERSION CREATE POLICY
|
||||
ALTER DATABASE CREATE PROCEDURE
|
||||
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
|
||||
ALTER DOMAIN CREATE ROLE
|
||||
ALTER EVENT TRIGGER CREATE RULE
|
||||
ALTER EXTENSION CREATE SCHEMA
|
||||
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
|
||||
ALTER FOREIGN TABLE CREATE SERVER
|
||||
ALTER FUNCTION CREATE STATISTICS
|
||||
ALTER GROUP CREATE SUBSCRIPTION
|
||||
ALTER INDEX CREATE TABLE
|
||||
ALTER LANGUAGE CREATE TABLESPACE
|
||||
BEGIN DELETE
|
||||
COMMIT DROP TABLE
|
||||
CREATE DATABASE INSERT
|
||||
CREATE INDEX ROLLBACK
|
||||
SELECT UPDATE`
|
||||
}
|
||||
137
internal/shell/psql/psql.go
Normal file
137
internal/shell/psql/psql.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// PsqlShell emulates a PostgreSQL psql interactive terminal.
|
||||
type PsqlShell struct{}
|
||||
|
||||
// NewPsqlShell returns a new PsqlShell instance.
|
||||
func NewPsqlShell() *PsqlShell {
|
||||
return &PsqlShell{}
|
||||
}
|
||||
|
||||
func (p *PsqlShell) Name() string { return "psql" }
|
||||
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
|
||||
|
||||
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dbName := configString(sess.ShellConfig, "db_name", "postgres")
|
||||
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
|
||||
|
||||
// Print startup banner.
|
||||
fmt.Fprint(rw, startupBanner(pgVersion))
|
||||
|
||||
var sqlBuf []string // accumulates multi-line SQL
|
||||
|
||||
for {
|
||||
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
|
||||
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Empty line in non-buffering state: just re-prompt.
|
||||
if trimmed == "" && len(sqlBuf) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
|
||||
if strings.HasPrefix(trimmed, `\`) {
|
||||
sqlBuf = nil // discard any partial SQL
|
||||
|
||||
result := dispatchBackslash(trimmed, dbName)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate SQL lines.
|
||||
sqlBuf = append(sqlBuf, line)
|
||||
|
||||
// Check if the statement is terminated by a semicolon.
|
||||
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Full statement ready — join and dispatch.
|
||||
fullSQL := strings.Join(sqlBuf, " ")
|
||||
sqlBuf = nil
|
||||
|
||||
result := dispatchSQL(fullSQL, dbName, pgVersion)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
|
||||
func buildPrompt(dbName string, continuation bool) string {
|
||||
if continuation {
|
||||
return dbName + "-# "
|
||||
}
|
||||
return dbName + "=# "
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
330
internal/shell/psql/psql_test.go
Normal file
330
internal/shell/psql/psql_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Prompt tests ---
|
||||
|
||||
func TestBuildPromptNormal(t *testing.T) {
|
||||
got := buildPrompt("postgres", false)
|
||||
if got != "postgres=# " {
|
||||
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptContinuation(t *testing.T) {
|
||||
got := buildPrompt("postgres", true)
|
||||
if got != "postgres-# " {
|
||||
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptCustomDB(t *testing.T) {
|
||||
got := buildPrompt("mydb", false)
|
||||
if got != "mydb=# " {
|
||||
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Backslash command dispatch tests ---
|
||||
|
||||
func TestBackslashQuit(t *testing.T) {
|
||||
result := dispatchBackslash(`\q`, "postgres")
|
||||
if !result.exit {
|
||||
t.Error("\\q should set exit=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListTables(t *testing.T) {
|
||||
result := dispatchBackslash(`\dt`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\dt should list tables including 'users'")
|
||||
}
|
||||
if !strings.Contains(result.output, "sessions") {
|
||||
t.Error("\\dt should list tables including 'sessions'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d users`, "postgres")
|
||||
if !strings.Contains(result.output, "username") {
|
||||
t.Error("\\d users should describe users table with 'username' column")
|
||||
}
|
||||
if !strings.Contains(result.output, "PRIMARY KEY") {
|
||||
t.Error("\\d users should include index info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeUnknownTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d nonexistent`, "postgres")
|
||||
if !strings.Contains(result.output, "Did not find") {
|
||||
t.Error("\\d nonexistent should return not found message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListDatabases(t *testing.T) {
|
||||
result := dispatchBackslash(`\l`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\l should list databases including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "template0") {
|
||||
t.Error("\\l should list databases including 'template0'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListRoles(t *testing.T) {
|
||||
result := dispatchBackslash(`\du`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\du should list roles including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "Superuser") {
|
||||
t.Error("\\du should show Superuser attribute for postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashConnInfo(t *testing.T) {
|
||||
result := dispatchBackslash(`\conninfo`, "mydb")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("\\conninfo should include database name")
|
||||
}
|
||||
if !strings.Contains(result.output, "5432") {
|
||||
t.Error("\\conninfo should include port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\?`, "postgres")
|
||||
if !strings.Contains(result.output, `\q`) {
|
||||
t.Error("\\? should include \\q in help output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashSQLHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\h`, "postgres")
|
||||
if !strings.Contains(result.output, "SELECT") {
|
||||
t.Error("\\h should include SQL commands like SELECT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashUnknown(t *testing.T) {
|
||||
result := dispatchBackslash(`\xyz`, "postgres")
|
||||
if !strings.Contains(result.output, "Invalid command") {
|
||||
t.Error("unknown backslash command should return error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SQL dispatch tests ---
|
||||
|
||||
func TestSQLSelectVersion(t *testing.T) {
|
||||
result := dispatchSQL("SELECT version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SELECT version() should contain pg version")
|
||||
}
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT version() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentDatabase(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("SELECT current_database() should return db name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentUser(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("SELECT current_user should return postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectNow(t *testing.T) {
|
||||
result := dispatchSQL("SELECT now();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT now() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectOne(t *testing.T) {
|
||||
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "1") {
|
||||
t.Error("SELECT 1 should return 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLInsert(t *testing.T) {
|
||||
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
|
||||
if result.output != "INSERT 0 1" {
|
||||
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUpdate(t *testing.T) {
|
||||
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
|
||||
if result.output != "UPDATE 1" {
|
||||
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDelete(t *testing.T) {
|
||||
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
|
||||
if result.output != "DELETE 1" {
|
||||
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateTable(t *testing.T) {
|
||||
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
|
||||
if result.output != "CREATE TABLE" {
|
||||
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateDatabase(t *testing.T) {
|
||||
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
|
||||
if result.output != "CREATE DATABASE" {
|
||||
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDropTable(t *testing.T) {
|
||||
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
|
||||
if result.output != "DROP TABLE" {
|
||||
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLAlterTable(t *testing.T) {
|
||||
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
|
||||
if result.output != "ALTER TABLE" {
|
||||
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLBeginCommitRollback(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{"BEGIN;", "BEGIN"},
|
||||
{"COMMIT;", "COMMIT"},
|
||||
{"ROLLBACK;", "ROLLBACK"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := dispatchSQL(tt.sql, "postgres", "15.4")
|
||||
if result.output != tt.want {
|
||||
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowServerVersion(t *testing.T) {
|
||||
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SHOW server_version should contain version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowSearchPath(t *testing.T) {
|
||||
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "public") {
|
||||
t.Error("SHOW search_path should contain public")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSet(t *testing.T) {
|
||||
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
|
||||
if result.output != "SET" {
|
||||
t.Errorf("SET output = %q, want %q", result.output, "SET")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUnrecognized(t *testing.T) {
|
||||
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "ERROR") {
|
||||
t.Error("unrecognized SQL should return error")
|
||||
}
|
||||
if !strings.Contains(result.output, "FOOBAR") {
|
||||
t.Error("error should reference the offending token")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Case insensitivity ---
|
||||
|
||||
func TestSQLCaseInsensitive(t *testing.T) {
|
||||
result := dispatchSQL("select version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("select version() (lowercase) should work")
|
||||
}
|
||||
|
||||
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("mixed case SELECT should work")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Startup banner ---
|
||||
|
||||
func TestStartupBanner(t *testing.T) {
|
||||
banner := startupBanner("15.4")
|
||||
if !strings.Contains(banner, "psql (15.4)") {
|
||||
t.Errorf("banner should contain version, got: %s", banner)
|
||||
}
|
||||
if !strings.Contains(banner, "help") {
|
||||
t.Error("banner should mention help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- configString ---
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"db_name": "mydb"}
|
||||
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
|
||||
t.Errorf("configString() = %q, want %q", got, "mydb")
|
||||
}
|
||||
if got := configString(cfg, "missing", "default"); got != "default" {
|
||||
t.Errorf("configString() for missing = %q, want %q", got, "default")
|
||||
}
|
||||
if got := configString(nil, "key", "default"); got != "default" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shell metadata ---
|
||||
|
||||
func TestShellNameAndDescription(t *testing.T) {
|
||||
s := NewPsqlShell()
|
||||
if s.Name() != "psql" {
|
||||
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
|
||||
}
|
||||
if s.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// --- formatSingleValue ---
|
||||
|
||||
func TestFormatSingleValue(t *testing.T) {
|
||||
out := formatSingleValue("?column?", "1")
|
||||
if !strings.Contains(out, "?column?") {
|
||||
t.Error("should contain column name")
|
||||
}
|
||||
if !strings.Contains(out, "1") {
|
||||
t.Error("should contain value")
|
||||
}
|
||||
if !strings.Contains(out, "(1 row)") {
|
||||
t.Error("should contain row count")
|
||||
}
|
||||
}
|
||||
|
||||
// --- \d with no args ---
|
||||
|
||||
func TestBackslashDescribeNoArgs(t *testing.T) {
|
||||
result := dispatchBackslash(`\d`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\d with no args should list tables")
|
||||
}
|
||||
}
|
||||
463
internal/shell/roomba/roomba.go
Normal file
463
internal/shell/roomba/roomba.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package roomba
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
|
||||
type RoombaShell struct{}
|
||||
|
||||
// NewRoombaShell returns a new RoombaShell instance.
|
||||
func NewRoombaShell() *RoombaShell {
|
||||
return &RoombaShell{}
|
||||
}
|
||||
|
||||
func (r *RoombaShell) Name() string { return "roomba" }
|
||||
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
|
||||
|
||||
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
state := newRoombaState()
|
||||
|
||||
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
|
||||
fmt.Fprint(rw, banner)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "logout\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := state.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("roomba")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bootBanner() string {
|
||||
return `
|
||||
____ _ ___ ____
|
||||
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
|
||||
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
|
||||
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|
||||
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
|
||||
|
||||
iRobot Roomba j7+ | RoombaOS v4.3.7
|
||||
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
|
||||
Battery: 73% | WiFi: Connected (SmartHome-5G)
|
||||
|
||||
Type 'help' for available commands.
|
||||
|
||||
`
|
||||
}
|
||||
|
||||
type room struct {
|
||||
name string
|
||||
areaSqFt int
|
||||
lastCleaned time.Time
|
||||
}
|
||||
|
||||
type scheduleEntry struct {
|
||||
day string
|
||||
time string
|
||||
}
|
||||
|
||||
type historyEntry struct {
|
||||
timestamp time.Time
|
||||
room string
|
||||
duration string
|
||||
note string
|
||||
}
|
||||
|
||||
type roombaState struct {
|
||||
battery int
|
||||
dustbin int
|
||||
status string
|
||||
rooms []room
|
||||
schedule []scheduleEntry
|
||||
cleanHistory []historyEntry
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func newRoombaState() *roombaState {
|
||||
now := time.Now()
|
||||
return &roombaState{
|
||||
battery: 73,
|
||||
dustbin: 61,
|
||||
status: "Docked",
|
||||
rooms: []room{
|
||||
{"Kitchen", 180, now.Add(-2 * time.Hour)},
|
||||
{"Living Room", 320, now.Add(-5 * time.Hour)},
|
||||
{"Bedroom", 200, now.Add(-26 * time.Hour)},
|
||||
{"Hallway", 60, now.Add(-5 * time.Hour)},
|
||||
{"Bathroom", 75, now.Add(-50 * time.Hour)},
|
||||
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
|
||||
},
|
||||
schedule: []scheduleEntry{
|
||||
{"Monday", "09:00"},
|
||||
{"Wednesday", "09:00"},
|
||||
{"Friday", "09:00"},
|
||||
{"Saturday", "14:00"},
|
||||
},
|
||||
cleanHistory: []historyEntry{
|
||||
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
|
||||
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
|
||||
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
|
||||
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
|
||||
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
|
||||
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) dispatch(input string) commandResult {
|
||||
parts := strings.Fields(input)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
cmd := strings.ToLower(parts[0])
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "help":
|
||||
return s.cmdHelp()
|
||||
case "status":
|
||||
return s.cmdStatus()
|
||||
case "clean":
|
||||
return s.cmdClean(args)
|
||||
case "dock":
|
||||
return s.cmdDock()
|
||||
case "map":
|
||||
return s.cmdMap()
|
||||
case "schedule":
|
||||
return s.cmdSchedule(args)
|
||||
case "history":
|
||||
return s.cmdHistory()
|
||||
case "diagnostics":
|
||||
return s.cmdDiagnostics()
|
||||
case "alerts":
|
||||
return s.cmdAlerts()
|
||||
case "reboot":
|
||||
return s.cmdReboot()
|
||||
case "exit", "logout":
|
||||
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHelp() commandResult {
|
||||
help := `Available commands:
|
||||
help - Show this help message
|
||||
status - Show robot status
|
||||
clean - Start full cleaning job
|
||||
clean room <name> - Clean a specific room
|
||||
dock - Return to dock
|
||||
map - Show floor plan and room list
|
||||
schedule - List cleaning schedule
|
||||
schedule add <day> <time> - Add scheduled cleaning
|
||||
schedule remove <day> - Remove scheduled cleaning
|
||||
history - Show recent cleaning history
|
||||
diagnostics - Run system diagnostics
|
||||
alerts - Show active alerts
|
||||
reboot - Reboot RoombaOS
|
||||
exit / logout - Disconnect`
|
||||
return commandResult{output: help}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdStatus() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== RoombaOS System Status ===\n")
|
||||
b.WriteString("Model: iRobot Roomba j7+\n")
|
||||
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
|
||||
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
|
||||
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
|
||||
b.WriteString("Side brush: OK (142 hrs)\n")
|
||||
b.WriteString("Main brush: OK (98 hrs)\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
|
||||
b.WriteString("Signal: -38 dBm\n")
|
||||
b.WriteString("Alexa: Linked\n")
|
||||
b.WriteString("Google Home: Linked\n")
|
||||
b.WriteString("iRobot Home App: Connected\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("Firmware: v4.3.7-stable\n")
|
||||
b.WriteString("LIDAR: Operational\n")
|
||||
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdClean(args []string) commandResult {
|
||||
if s.status == "Cleaning" {
|
||||
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
|
||||
}
|
||||
|
||||
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
|
||||
roomName := strings.Join(args[1:], " ")
|
||||
for _, r := range s.rooms {
|
||||
if strings.EqualFold(r.name, roomName) {
|
||||
s.status = "Cleaning"
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
|
||||
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
|
||||
)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
return commandResult{output: "Usage: clean [room <name>]"}
|
||||
}
|
||||
|
||||
s.status = "Cleaning"
|
||||
var totalArea int
|
||||
for _, r := range s.rooms {
|
||||
totalArea += r.areaSqFt
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
|
||||
totalArea, len(s.rooms), totalArea/8,
|
||||
)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDock() commandResult {
|
||||
if s.status == "Docked" {
|
||||
return commandResult{output: "Already docked."}
|
||||
}
|
||||
if s.status == "Returning to dock" {
|
||||
return commandResult{output: "Already returning to dock."}
|
||||
}
|
||||
s.status = "Returning to dock"
|
||||
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdMap() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Floor Plan ===\n\n")
|
||||
b.WriteString(" +------------+----------+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Kitchen | Bathroom |\n")
|
||||
b.WriteString(" | 180sqft | 75sqft |\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" +------+-----+----+-----+\n")
|
||||
b.WriteString(" | | | |\n")
|
||||
b.WriteString(" | Hall | Living | Cat |\n")
|
||||
b.WriteString(" | 60sf | Room | Rm |\n")
|
||||
b.WriteString(" | | 320sqft |110sf|\n")
|
||||
b.WriteString(" +------+ +-----+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Bed +----------+\n")
|
||||
b.WriteString(" | room | [DOCK]\n")
|
||||
b.WriteString(" |200sf |\n")
|
||||
b.WriteString(" +------+\n")
|
||||
b.WriteString("\nRoom Details:\n")
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
|
||||
for _, r := range s.rooms {
|
||||
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
|
||||
}
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdSchedule(args []string) commandResult {
|
||||
if len(args) == 0 {
|
||||
return s.scheduleList()
|
||||
}
|
||||
|
||||
sub := strings.ToLower(args[0])
|
||||
switch sub {
|
||||
case "add":
|
||||
if len(args) < 3 {
|
||||
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
|
||||
}
|
||||
return s.scheduleAdd(args[1], args[2])
|
||||
case "remove":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "Usage: schedule remove <day>"}
|
||||
}
|
||||
return s.scheduleRemove(args[1])
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleList() commandResult {
|
||||
if len(s.schedule) == 0 {
|
||||
return commandResult{output: "No cleaning schedule configured."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning Schedule ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
|
||||
for _, e := range s.schedule {
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleAdd(day, t string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
|
||||
if !slices.Contains(validDays, day) {
|
||||
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
|
||||
}
|
||||
|
||||
for _, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
|
||||
}
|
||||
}
|
||||
|
||||
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
|
||||
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleRemove(day string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
for i, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
|
||||
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHistory() commandResult {
|
||||
if len(s.cleanHistory) == 0 {
|
||||
return commandResult{output: "No cleaning history."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning History ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
|
||||
for _, h := range s.cleanHistory {
|
||||
ts := h.timestamp.Format("2006-01-02 15:04")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDiagnostics() commandResult {
|
||||
diag := `Running RoombaOS diagnostics...
|
||||
|
||||
[1/8] Cliff sensors........... OK
|
||||
[2/8] Bumper sensor........... OK
|
||||
[3/8] Side brush motor........ OK (142 hrs until replacement)
|
||||
[4/8] Main brush motor........ OK (98 hrs until replacement)
|
||||
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
|
||||
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
|
||||
[7/8] Dustbin sensor.......... OK
|
||||
[8/8] WiFi module............. OK (signal: -38 dBm)
|
||||
|
||||
ALL SYSTEMS NOMINAL`
|
||||
return commandResult{output: diag}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdAlerts() commandResult {
|
||||
var alerts []string
|
||||
if s.dustbin >= 60 {
|
||||
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
|
||||
}
|
||||
alerts = append(alerts,
|
||||
"WARNING: Side brush replacement due in 12 hours",
|
||||
"INFO: Unidentified sticky substance detected in Kitchen",
|
||||
"INFO: Cat frequently blocking cleaning path in Cat's Room",
|
||||
"INFO: Firmware update available: v4.4.0-beta",
|
||||
"INFO: Filter replacement recommended in 14 days",
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Active Alerts ===\n")
|
||||
for _, a := range alerts {
|
||||
b.WriteString(a + "\n")
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdReboot() commandResult {
|
||||
reboot := `RoombaOS is rebooting...
|
||||
|
||||
Stopping navigation engine..... done
|
||||
Saving room map data........... done
|
||||
Flushing cleaning logs......... done
|
||||
Disconnecting from WiFi........ done
|
||||
|
||||
Rebooting now. Goodbye!`
|
||||
return commandResult{output: reboot, exit: true}
|
||||
}
|
||||
|
||||
func capitalizeFirst(s string) string {
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ToUpper(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
hours := int(d.Hours())
|
||||
minutes := int(d.Minutes()) % 60
|
||||
if hours >= 24 {
|
||||
days := hours / 24
|
||||
hours %= 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%dm", minutes)
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// Shell is the interface that all honeypot shell implementations must satisfy.
|
||||
|
||||
101
internal/shell/tetris/data.go
Normal file
101
internal/shell/tetris/data.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package tetris
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
// pieceType identifies a tetromino (0–6).
|
||||
type pieceType int
|
||||
|
||||
const (
|
||||
pieceI pieceType = iota
|
||||
pieceO
|
||||
pieceT
|
||||
pieceS
|
||||
pieceZ
|
||||
pieceJ
|
||||
pieceL
|
||||
)
|
||||
|
||||
const numPieceTypes = 7
|
||||
|
||||
// Standard Tetris colors.
|
||||
var pieceColors = [numPieceTypes]lipgloss.Color{
|
||||
lipgloss.Color("#00FFFF"), // I — cyan
|
||||
lipgloss.Color("#FFFF00"), // O — yellow
|
||||
lipgloss.Color("#AA00FF"), // T — purple
|
||||
lipgloss.Color("#00FF00"), // S — green
|
||||
lipgloss.Color("#FF0000"), // Z — red
|
||||
lipgloss.Color("#0000FF"), // J — blue
|
||||
lipgloss.Color("#FF8800"), // L — orange
|
||||
}
|
||||
|
||||
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
|
||||
// relative to the piece origin.
|
||||
type rotation [4][2]int
|
||||
|
||||
var pieces = [numPieceTypes][4]rotation{
|
||||
// I
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
},
|
||||
// O
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
},
|
||||
// T
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// S
|
||||
{
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// Z
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
},
|
||||
// J
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
},
|
||||
// L
|
||||
{
|
||||
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
// spawnCol returns the starting column for a piece, centering it on the board.
|
||||
func spawnCol(pt pieceType, rot int) int {
|
||||
shape := pieces[pt][rot]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
width := maxC - minC + 1
|
||||
return (boardCols - width) / 2
|
||||
}
|
||||
210
internal/shell/tetris/game.go
Normal file
210
internal/shell/tetris/game.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tetris
|
||||
|
||||
import "math/rand/v2"
|
||||
|
||||
const (
|
||||
boardRows = 20
|
||||
boardCols = 10
|
||||
)
|
||||
|
||||
// cell represents a single board cell. Zero value is empty.
|
||||
type cell struct {
|
||||
filled bool
|
||||
piece pieceType // which piece type filled this cell (for color)
|
||||
}
|
||||
|
||||
// gameState holds all mutable state for a Tetris game.
|
||||
type gameState struct {
|
||||
board [boardRows][boardCols]cell
|
||||
current pieceType
|
||||
currentRot int
|
||||
currentRow int
|
||||
currentCol int
|
||||
next pieceType
|
||||
score int
|
||||
level int
|
||||
lines int
|
||||
gameOver bool
|
||||
}
|
||||
|
||||
// newGame creates a new game state, optionally starting at a given level.
|
||||
func newGame(startLevel int) *gameState {
|
||||
g := &gameState{
|
||||
level: startLevel,
|
||||
next: pieceType(rand.IntN(numPieceTypes)),
|
||||
}
|
||||
g.spawnPiece()
|
||||
return g
|
||||
}
|
||||
|
||||
// spawnPiece pulls the next piece and generates a new next.
|
||||
func (g *gameState) spawnPiece() {
|
||||
g.current = g.next
|
||||
g.next = pieceType(rand.IntN(numPieceTypes))
|
||||
g.currentRot = 0
|
||||
g.currentRow = 0
|
||||
g.currentCol = spawnCol(g.current, 0)
|
||||
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
g.gameOver = true
|
||||
}
|
||||
}
|
||||
|
||||
// canPlace checks whether the piece fits at the given position.
|
||||
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
|
||||
return false
|
||||
}
|
||||
if g.board[r][c].filled {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// moveLeft moves the current piece left if possible.
|
||||
func (g *gameState) moveLeft() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
|
||||
g.currentCol--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveRight moves the current piece right if possible.
|
||||
func (g *gameState) moveRight() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
|
||||
g.currentCol++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveDown moves the current piece down one row. Returns false if it cannot.
|
||||
func (g *gameState) moveDown() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
|
||||
g.currentRow++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// rotate rotates the current piece clockwise with wall kick attempts.
|
||||
func (g *gameState) rotate() bool {
|
||||
newRot := (g.currentRot + 1) % 4
|
||||
|
||||
// Try in-place first.
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
|
||||
g.currentRot = newRot
|
||||
return true
|
||||
}
|
||||
|
||||
// Wall kick: try +-1 column offset.
|
||||
for _, offset := range []int{-1, 1} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// I piece: try +-2.
|
||||
if g.current == pieceI {
|
||||
for _, offset := range []int{-2, 2} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ghostRow returns the row where the piece would land.
|
||||
func (g *gameState) ghostRow() int {
|
||||
row := g.currentRow
|
||||
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
|
||||
row++
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
|
||||
func (g *gameState) hardDrop() int {
|
||||
ghost := g.ghostRow()
|
||||
dropped := ghost - g.currentRow
|
||||
g.currentRow = ghost
|
||||
return dropped
|
||||
}
|
||||
|
||||
// lockPiece writes the current piece into the board.
|
||||
func (g *gameState) lockPiece() {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearLines removes completed rows and returns how many were cleared.
|
||||
func (g *gameState) clearLines() int {
|
||||
cleared := 0
|
||||
for r := boardRows - 1; r >= 0; r-- {
|
||||
full := true
|
||||
for c := range boardCols {
|
||||
if !g.board[r][c].filled {
|
||||
full = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if full {
|
||||
cleared++
|
||||
// Shift everything above down.
|
||||
for rr := r; rr > 0; rr-- {
|
||||
g.board[rr] = g.board[rr-1]
|
||||
}
|
||||
g.board[0] = [boardCols]cell{}
|
||||
r++ // re-check this row since we shifted
|
||||
}
|
||||
}
|
||||
return cleared
|
||||
}
|
||||
|
||||
// NES-style scoring multipliers per lines cleared.
|
||||
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
|
||||
|
||||
// addScore updates score, lines, and level after clearing rows.
|
||||
func (g *gameState) addScore(linesCleared int) {
|
||||
if linesCleared > 0 && linesCleared <= 4 {
|
||||
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
|
||||
}
|
||||
g.lines += linesCleared
|
||||
|
||||
// Level up every 10 lines.
|
||||
newLevel := g.lines / 10
|
||||
if newLevel > g.level {
|
||||
g.level = newLevel
|
||||
}
|
||||
}
|
||||
|
||||
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
|
||||
// Returns the number of lines cleared.
|
||||
func (g *gameState) afterLock() int {
|
||||
g.lockPiece()
|
||||
cleared := g.clearLines()
|
||||
g.addScore(cleared)
|
||||
g.spawnPiece()
|
||||
return cleared
|
||||
}
|
||||
|
||||
// tickInterval returns the gravity interval in milliseconds for the current level.
|
||||
func tickInterval(level int) int {
|
||||
return max(800-level*60, 100)
|
||||
}
|
||||
331
internal/shell/tetris/model.go
Normal file
331
internal/shell/tetris/model.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
const (
|
||||
screenTitle screen = iota
|
||||
screenGame
|
||||
screenGameOver
|
||||
)
|
||||
|
||||
type tickMsg time.Time
|
||||
type lockMsg time.Time
|
||||
|
||||
const lockDelay = 500 * time.Millisecond
|
||||
|
||||
type model struct {
|
||||
sess *shell.SessionContext
|
||||
difficulty string
|
||||
screen screen
|
||||
game *gameState
|
||||
quitting bool
|
||||
height int
|
||||
keypresses int
|
||||
locking bool // true when piece has landed and lock delay is active
|
||||
}
|
||||
|
||||
func newModel(sess *shell.SessionContext, difficulty string) *model {
|
||||
return &model{
|
||||
sess: sess,
|
||||
difficulty: difficulty,
|
||||
screen: screenTitle,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.quitting {
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.height = msg.Height
|
||||
return m, nil
|
||||
case tea.KeyMsg:
|
||||
m.keypresses++
|
||||
if msg.Type == tea.KeyCtrlC {
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
return m.updateTitle(msg)
|
||||
case screenGame:
|
||||
return m.updateGame(msg)
|
||||
case screenGameOver:
|
||||
return m.updateGameOver(msg)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) View() string {
|
||||
var content string
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
content = m.titleView()
|
||||
case screenGame:
|
||||
content = gameView(m.game)
|
||||
case screenGameOver:
|
||||
content = m.gameOverView()
|
||||
}
|
||||
|
||||
return gameFrame(content, m.height)
|
||||
}
|
||||
|
||||
// --- Title screen ---
|
||||
|
||||
func (m *model) titleView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" Press any key to start"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if _, ok := msg.(tea.KeyMsg); ok {
|
||||
m.screen = screenGame
|
||||
var startLevel int
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game screen ---
|
||||
|
||||
func (m *model) scheduleTick() tea.Cmd {
|
||||
ms := tickInterval(m.game.level)
|
||||
if m.difficulty == "easy" {
|
||||
ms = max(1000-m.game.level*60, 150)
|
||||
}
|
||||
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *model) scheduleLock() tea.Cmd {
|
||||
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
|
||||
return lockMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
// performLock locks the piece, clears lines, and returns commands for logging
|
||||
// and scheduling the next tick. Returns nil if game over (goToGameOver is
|
||||
// included in the returned batch).
|
||||
func (m *model) performLock() tea.Cmd {
|
||||
m.locking = false
|
||||
cleared := m.game.afterLock()
|
||||
if m.game.gameOver {
|
||||
return tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
|
||||
m.goToGameOver(),
|
||||
)
|
||||
}
|
||||
var cmds []tea.Cmd
|
||||
cmds = append(cmds, m.scheduleTick())
|
||||
if cleared > 0 {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
|
||||
prevLevel := (m.game.lines - cleared) / 10
|
||||
if m.game.level > prevLevel {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
|
||||
}
|
||||
}
|
||||
return tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case lockMsg:
|
||||
if m.game.gameOver || !m.locking {
|
||||
return m, nil
|
||||
}
|
||||
// Lock delay expired — lock the piece now.
|
||||
return m, m.performLock()
|
||||
|
||||
case tickMsg:
|
||||
if m.game.gameOver || m.locking {
|
||||
return m, nil
|
||||
}
|
||||
if !m.game.moveDown() {
|
||||
// Piece landed — start lock delay instead of locking immediately.
|
||||
m.locking = true
|
||||
return m, m.scheduleLock()
|
||||
}
|
||||
return m, m.scheduleTick()
|
||||
|
||||
case tea.KeyMsg:
|
||||
if m.game.gameOver {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "left":
|
||||
m.game.moveLeft()
|
||||
// If piece can now drop further, cancel lock delay.
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "right":
|
||||
m.game.moveRight()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "down":
|
||||
if m.game.moveDown() {
|
||||
m.game.score++ // soft drop bonus
|
||||
if m.locking {
|
||||
m.locking = false
|
||||
}
|
||||
}
|
||||
case "up", "z":
|
||||
m.game.rotate()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case " ":
|
||||
m.locking = false
|
||||
dropped := m.game.hardDrop()
|
||||
m.game.score += dropped * 2
|
||||
return m, m.performLock()
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game over screen ---
|
||||
|
||||
func (m *model) goToGameOver() tea.Cmd {
|
||||
m.screen = screenGameOver
|
||||
return tea.ClearScreen
|
||||
}
|
||||
|
||||
func (m *model) gameOverView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" GAME OVER"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" R - Play again"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" Q - Quit"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok {
|
||||
switch keyMsg.String() {
|
||||
case "r":
|
||||
startLevel := 0
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
m.screen = screenGame
|
||||
m.keypresses = 0
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Helper methods for safe access when game may be nil.
|
||||
func (m *model) gameScore() int {
|
||||
if m.game != nil {
|
||||
return m.game.score
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLevel() int {
|
||||
if m.game != nil {
|
||||
return m.game.level
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLines() int {
|
||||
if m.game != nil {
|
||||
return m.game.lines
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if sess.Store != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("tetris")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
286
internal/shell/tetris/style.go
Normal file
286
internal/shell/tetris/style.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const termWidth = 80
|
||||
|
||||
var (
|
||||
colorWhite = lipgloss.Color("#FFFFFF")
|
||||
colorDim = lipgloss.Color("#555555")
|
||||
colorBlack = lipgloss.Color("#000000")
|
||||
colorGhost = lipgloss.Color("#333333")
|
||||
)
|
||||
|
||||
var (
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack)
|
||||
|
||||
dimStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#00FFFF")).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
sidebarLabelStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
sidebarValueStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// cellStyle returns a style for a filled cell of a given piece type.
|
||||
func cellStyle(pt pieceType) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(pieceColors[pt]).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// ghostStyle returns a dimmed style for the ghost piece.
|
||||
func ghostCellStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(colorGhost).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// renderBoard renders the board, current piece, and ghost piece as a string.
|
||||
func renderBoard(g *gameState) string {
|
||||
// Build a display grid that includes the current piece and ghost.
|
||||
type displayCell struct {
|
||||
filled bool
|
||||
ghost bool
|
||||
piece pieceType
|
||||
}
|
||||
var grid [boardRows][boardCols]displayCell
|
||||
|
||||
// Copy locked cells.
|
||||
for r := range boardRows {
|
||||
for c := range boardCols {
|
||||
if g.board[r][c].filled {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ghost piece.
|
||||
ghostR := g.ghostRow()
|
||||
if ghostR != g.currentRow {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := ghostR+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
|
||||
grid[r][c] = displayCell{ghost: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Current piece.
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
|
||||
// Render grid.
|
||||
var b strings.Builder
|
||||
borderStyle := dimStyle
|
||||
|
||||
for _, row := range grid {
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
for _, dc := range row {
|
||||
switch {
|
||||
case dc.filled:
|
||||
b.WriteString(cellStyle(dc.piece).Render("[]"))
|
||||
case dc.ghost:
|
||||
b.WriteString(ghostCellStyle().Render("::"))
|
||||
default:
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// renderNextPiece renders the "next piece" preview box.
|
||||
func renderNextPiece(pt pieceType) string {
|
||||
shape := pieces[pt][0]
|
||||
// Determine bounding box.
|
||||
minR, maxR := shape[0][0], shape[0][0]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[0] < minR {
|
||||
minR = off[0]
|
||||
}
|
||||
if off[0] > maxR {
|
||||
maxR = off[0]
|
||||
}
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
|
||||
rows := maxR - minR + 1
|
||||
cols := maxC - minC + 1
|
||||
|
||||
// Build a small grid.
|
||||
grid := make([][]bool, rows)
|
||||
for i := range grid {
|
||||
grid[i] = make([]bool, cols)
|
||||
}
|
||||
for _, off := range shape {
|
||||
grid[off[0]-minR][off[1]-minC] = true
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
boxWidth := 8 // chars for the box interior
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
b.WriteString("\n")
|
||||
|
||||
for r := range rows {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
// Center the piece in the box.
|
||||
pieceWidth := cols * 2
|
||||
leftPad := (boxWidth - pieceWidth) / 2
|
||||
rightPad := boxWidth - pieceWidth - leftPad
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
|
||||
for c := range cols {
|
||||
if grid[r][c] {
|
||||
b.WriteString(cellStyle(pt).Render("[]"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Fill remaining rows in the box (max 4 rows for I piece).
|
||||
for r := rows; r < 2; r++ {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// formatScore formats a score with comma separators.
|
||||
func formatScore(n int) string {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
if len(s) <= 3 {
|
||||
return s
|
||||
}
|
||||
var parts []string
|
||||
for len(s) > 3 {
|
||||
parts = append([]string{s[len(s)-3:]}, parts...)
|
||||
s = s[:len(s)-3]
|
||||
}
|
||||
parts = append([]string{s}, parts...)
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// gameView combines the board and sidebar into the game screen.
|
||||
func gameView(g *gameState) string {
|
||||
boardStr := renderBoard(g)
|
||||
boardLines := strings.Split(boardStr, "\n")
|
||||
|
||||
nextStr := renderNextPiece(g.next)
|
||||
nextLines := strings.Split(nextStr, "\n")
|
||||
|
||||
// Build sidebar lines.
|
||||
var sidebar []string
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
|
||||
sidebar = append(sidebar, nextLines...)
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
|
||||
|
||||
// Combine board and sidebar side by side.
|
||||
var b strings.Builder
|
||||
maxLines := max(len(boardLines), len(sidebar))
|
||||
|
||||
for i := range maxLines {
|
||||
boardLine := ""
|
||||
if i < len(boardLines) {
|
||||
boardLine = boardLines[i]
|
||||
}
|
||||
sidebarLine := ""
|
||||
if i < len(sidebar) {
|
||||
sidebarLine = sidebar[i]
|
||||
}
|
||||
|
||||
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
|
||||
b.WriteString(boardLine)
|
||||
b.WriteString(sidebarLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// padLine pads a single line to termWidth.
|
||||
func padLine(line string) string {
|
||||
w := lipgloss.Width(line)
|
||||
if w >= termWidth {
|
||||
return line
|
||||
}
|
||||
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||
}
|
||||
|
||||
// padLines pads every line in a multi-line string to termWidth.
|
||||
func padLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = padLine(line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// gameFrame wraps content with padding to fill the terminal.
|
||||
func gameFrame(content string, height int) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(content)
|
||||
|
||||
// Pad with blank lines to fill terminal height.
|
||||
if height > 0 {
|
||||
contentLines := strings.Count(content, "\n") + 1
|
||||
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||
for i := contentLines; i < height; i++ {
|
||||
b.WriteString(blankLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return padLines(b.String())
|
||||
}
|
||||
66
internal/shell/tetris/tetris.go
Normal file
66
internal/shell/tetris/tetris.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// TetrisShell is a Tetris game TUI for the honeypot.
|
||||
type TetrisShell struct{}
|
||||
|
||||
// NewTetrisShell returns a new TetrisShell instance.
|
||||
func NewTetrisShell() *TetrisShell {
|
||||
return &TetrisShell{}
|
||||
}
|
||||
|
||||
func (t *TetrisShell) Name() string { return "tetris" }
|
||||
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
|
||||
|
||||
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
|
||||
|
||||
m := newModel(sess, difficulty)
|
||||
p := tea.NewProgram(m,
|
||||
tea.WithInput(rw),
|
||||
tea.WithOutput(rw),
|
||||
tea.WithAltScreen(),
|
||||
)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := p.Run()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
p.Quit()
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
582
internal/shell/tetris/tetris_test.go
Normal file
582
internal/shell/tetris/tetris_test.go
Normal file
@@ -0,0 +1,582 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "player",
|
||||
Store: store,
|
||||
}
|
||||
m := newModel(sess, "normal")
|
||||
return m, store
|
||||
}
|
||||
|
||||
// sendKey sends a single key message to the model and returns the command.
|
||||
func sendKey(m *model, key string) tea.Cmd {
|
||||
var msg tea.KeyMsg
|
||||
switch key {
|
||||
case "enter":
|
||||
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||
case "up":
|
||||
msg = tea.KeyMsg{Type: tea.KeyUp}
|
||||
case "down":
|
||||
msg = tea.KeyMsg{Type: tea.KeyDown}
|
||||
case "left":
|
||||
msg = tea.KeyMsg{Type: tea.KeyLeft}
|
||||
case "right":
|
||||
msg = tea.KeyMsg{Type: tea.KeyRight}
|
||||
case "space":
|
||||
msg = tea.KeyMsg{Type: tea.KeySpace}
|
||||
case "ctrl+c":
|
||||
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||
default:
|
||||
if len(key) == 1 {
|
||||
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
|
||||
}
|
||||
}
|
||||
_, cmd := m.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// sendTick sends a tick message to the model.
|
||||
func sendTick(m *model) tea.Cmd {
|
||||
_, cmd := m.Update(tickMsg(time.Now()))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||
func execCmds(cmd tea.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
msg := cmd()
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
for _, c := range batch {
|
||||
execCmds(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTetrisShellName(t *testing.T) {
|
||||
sh := NewTetrisShell()
|
||||
if sh.Name() != "tetris" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"difficulty": "hard",
|
||||
}
|
||||
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
|
||||
t.Errorf("configString() = %q, want %q", got, "hard")
|
||||
}
|
||||
if got := configString(cfg, "missing", "normal"); got != "normal" {
|
||||
t.Errorf("configString() = %q, want %q", got, "normal")
|
||||
}
|
||||
if got := configString(nil, "difficulty", "normal"); got != "normal" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "normal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleScreenRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "████") {
|
||||
t.Error("title screen should show TETRIS logo")
|
||||
}
|
||||
if !strings.Contains(view, "Press any key") {
|
||||
t.Error("title screen should show 'Press any key'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleToGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
if m.screen != screenTitle {
|
||||
t.Fatalf("expected screenTitle, got %d", m.screen)
|
||||
}
|
||||
|
||||
sendKey(m, "enter")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after keypress, got %d", m.screen)
|
||||
}
|
||||
if m.game == nil {
|
||||
t.Fatal("game should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "|") {
|
||||
t.Error("game view should contain board borders")
|
||||
}
|
||||
if !strings.Contains(view, "SCORE") {
|
||||
t.Error("game view should show SCORE")
|
||||
}
|
||||
if !strings.Contains(view, "LEVEL") {
|
||||
t.Error("game view should show LEVEL")
|
||||
}
|
||||
if !strings.Contains(view, "LINES") {
|
||||
t.Error("game view should show LINES")
|
||||
}
|
||||
if !strings.Contains(view, "NEXT") {
|
||||
t.Error("game view should show NEXT")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Pure game logic tests ---
|
||||
|
||||
func TestNewGame(t *testing.T) {
|
||||
g := newGame(0)
|
||||
if g.gameOver {
|
||||
t.Error("new game should not be game over")
|
||||
}
|
||||
if g.score != 0 {
|
||||
t.Errorf("initial score = %d, want 0", g.score)
|
||||
}
|
||||
if g.level != 0 {
|
||||
t.Errorf("initial level = %d, want 0", g.level)
|
||||
}
|
||||
if g.lines != 0 {
|
||||
t.Errorf("initial lines = %d, want 0", g.lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGameHardLevel(t *testing.T) {
|
||||
g := newGame(5)
|
||||
if g.level != 5 {
|
||||
t.Errorf("hard start level = %d, want 5", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveLeft(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveLeft()
|
||||
if g.currentCol != startCol-1 {
|
||||
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveRight(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveRight()
|
||||
if g.currentCol != startCol+1 {
|
||||
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveDown(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
moved := g.moveDown()
|
||||
if !moved {
|
||||
t.Error("moveDown should succeed from starting position")
|
||||
}
|
||||
if g.currentRow != startRow+1 {
|
||||
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveLeftBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way left.
|
||||
for range boardCols {
|
||||
g.moveLeft()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveLeft() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveRightBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way right.
|
||||
for range boardCols {
|
||||
g.moveRight()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveRight() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotate(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRot := g.currentRot
|
||||
g.rotate()
|
||||
// Rotation should change (possibly with wall kick).
|
||||
if g.currentRot == startRot {
|
||||
// Rotation might legitimately fail in some edge cases, so just check
|
||||
// that the game state is valid.
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
t.Error("piece should be in a valid position after rotate attempt")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDrop(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
dropped := g.hardDrop()
|
||||
if dropped == 0 {
|
||||
t.Error("hard drop should move piece down at least some rows from top")
|
||||
}
|
||||
if g.currentRow <= startRow {
|
||||
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGhostRow(t *testing.T) {
|
||||
g := newGame(0)
|
||||
ghost := g.ghostRow()
|
||||
if ghost < g.currentRow {
|
||||
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
|
||||
}
|
||||
// Ghost should be at a position where moving down one more is impossible.
|
||||
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
|
||||
t.Error("ghost row should be the lowest valid position")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockPiece(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.hardDrop()
|
||||
pt := g.current
|
||||
row, col, rot := g.currentRow, g.currentCol, g.currentRot
|
||||
g.lockPiece()
|
||||
|
||||
// Verify that the piece's cells are now filled.
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if !g.board[r][c].filled {
|
||||
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom row completely.
|
||||
for c := range boardCols {
|
||||
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 1 {
|
||||
t.Errorf("clearLines() = %d, want 1", cleared)
|
||||
}
|
||||
// Bottom row should now be empty (shifted from above).
|
||||
for c := range boardCols {
|
||||
if g.board[boardRows-1][c].filled {
|
||||
t.Errorf("bottom row col %d should be empty after clearing", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearMultipleLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom 4 rows.
|
||||
for r := boardRows - 4; r < boardRows; r++ {
|
||||
for c := range boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 4 {
|
||||
t.Errorf("clearLines() = %d, want 4", cleared)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoring(t *testing.T) {
|
||||
tests := []struct {
|
||||
lines int
|
||||
level int
|
||||
want int
|
||||
}{
|
||||
{1, 0, 40},
|
||||
{2, 0, 100},
|
||||
{3, 0, 300},
|
||||
{4, 0, 1200},
|
||||
{1, 1, 80},
|
||||
{4, 2, 3600},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
g := newGame(tt.level)
|
||||
g.addScore(tt.lines)
|
||||
if g.score != tt.want {
|
||||
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLevelUp(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.lines = 9
|
||||
g.addScore(1) // This should push lines to 10, triggering level 1.
|
||||
if g.level != 1 {
|
||||
t.Errorf("level = %d, want 1 after 10 lines", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickInterval(t *testing.T) {
|
||||
if got := tickInterval(0); got != 800 {
|
||||
t.Errorf("tickInterval(0) = %d, want 800", got)
|
||||
}
|
||||
if got := tickInterval(5); got != 500 {
|
||||
t.Errorf("tickInterval(5) = %d, want 500", got)
|
||||
}
|
||||
// Floor at 100ms.
|
||||
if got := tickInterval(20); got != 100 {
|
||||
t.Errorf("tickInterval(20) = %d, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatScore(t *testing.T) {
|
||||
tests := []struct {
|
||||
n int
|
||||
want string
|
||||
}{
|
||||
{0, "0"},
|
||||
{100, "100"},
|
||||
{1250, "1,250"},
|
||||
{1000000, "1,000,000"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := formatScore(tt.n); got != tt.want {
|
||||
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameOverScreen(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Force game over.
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "GAME OVER") {
|
||||
t.Error("game over screen should show GAME OVER")
|
||||
}
|
||||
if !strings.Contains(view, "Score") {
|
||||
t.Error("game over screen should show score")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestartFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "r")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after restart, got %d", m.screen)
|
||||
}
|
||||
if m.game.gameOver {
|
||||
t.Error("game should not be over after restart")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q in game over")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
scoreBefore := m.game.score
|
||||
sendKey(m, "down")
|
||||
if m.game.score != scoreBefore+1 {
|
||||
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Hard drop gives 2 points per row dropped.
|
||||
sendKey(m, "space")
|
||||
if m.game.score < 2 {
|
||||
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickMovesDown(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
rowBefore := m.game.currentRow
|
||||
sendTick(m)
|
||||
// Piece should either move down by 1, or lock and spawn a new piece at top.
|
||||
movedDown := m.game.currentRow == rowBefore+1
|
||||
respawned := m.game.currentRow < rowBefore
|
||||
if !movedDown && !respawned && !m.game.gameOver {
|
||||
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
m, store := newTestModel(t)
|
||||
|
||||
// Press key to start game — returns a logAction cmd.
|
||||
cmd := sendKey(m, "enter")
|
||||
if cmd != nil {
|
||||
execCmds(cmd)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
found := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "GAME START") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected GAME START in session logs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeypressCounter(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "left")
|
||||
sendKey(m, "right")
|
||||
sendKey(m, "down")
|
||||
|
||||
if m.keypresses != 4 { // enter + 3 game keys
|
||||
t.Errorf("keypresses = %d, want 4", m.keypresses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelay(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Drop piece to the bottom via ticks until it can't move down.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
if !m.locking {
|
||||
t.Fatal("piece should be in locking state after hitting bottom")
|
||||
}
|
||||
|
||||
// During lock delay, we should still be able to move left/right.
|
||||
colBefore := m.game.currentCol
|
||||
sendKey(m, "left")
|
||||
if m.game.currentCol >= colBefore {
|
||||
// Might not have moved if against wall, try right.
|
||||
sendKey(m, "right")
|
||||
}
|
||||
|
||||
// Sending a lockMsg should finalize the piece.
|
||||
m.Update(lockMsg(time.Now()))
|
||||
// After lock, a new piece should have spawned (row near top).
|
||||
if m.game.currentRow > 1 && !m.game.gameOver {
|
||||
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelayCancelledByDrop(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Build a ledge: fill rows 18-19 but leave column 0 empty.
|
||||
for r := boardRows - 2; r < boardRows; r++ {
|
||||
for c := 1; c < boardCols; c++ {
|
||||
m.game.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
|
||||
// Move piece to column 0 area and drop it onto the ledge.
|
||||
for range boardCols {
|
||||
m.game.moveLeft()
|
||||
}
|
||||
// Tick down until locking.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
// If piece is on the ledge and we slide it to col 0 (open column),
|
||||
// the lock delay should cancel since it can fall further.
|
||||
// This test just validates the locking flag logic works.
|
||||
if m.locking {
|
||||
// Try moving — if piece can drop further, locking should cancel.
|
||||
sendKey(m, "left")
|
||||
// Whether locking cancels depends on the board state; just verify no crash.
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnCol(t *testing.T) {
|
||||
// All pieces should spawn roughly centered.
|
||||
for pt := range pieceType(numPieceTypes) {
|
||||
col := spawnCol(pt, 0)
|
||||
if col < 0 || col > boardCols-1 {
|
||||
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
|
||||
}
|
||||
// Verify piece fits at spawn position.
|
||||
shape := pieces[pt][0]
|
||||
for _, off := range shape {
|
||||
c := col + off[1]
|
||||
if c < 0 || c >= boardCols {
|
||||
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
217
internal/storage/instrumented.go
Normal file
217
internal/storage/instrumented.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// InstrumentedStore wraps a Store and records query duration and errors
|
||||
// as Prometheus metrics for each method call.
|
||||
type InstrumentedStore struct {
|
||||
store Store
|
||||
queryDuration *prometheus.HistogramVec
|
||||
queryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
|
||||
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
|
||||
return &InstrumentedStore{
|
||||
store: store,
|
||||
queryDuration: queryDuration,
|
||||
queryErrors: queryErrors,
|
||||
}
|
||||
}
|
||||
|
||||
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
v, err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||
return observeErr(s, "RecordLoginAttempt", func() error {
|
||||
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||
return observe(s, "CreateSession", func() (string, error) {
|
||||
return s.store.CreateSession(ctx, ip, username, shellName, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
||||
return observeErr(s, "EndSession", func() error {
|
||||
return s.store.EndSession(ctx, sessionID, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
||||
return observeErr(s, "UpdateHumanScore", func() error {
|
||||
return s.store.UpdateHumanScore(ctx, sessionID, score)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
return observeErr(s, "SetExecCommand", func() error {
|
||||
return s.store.SetExecCommand(ctx, sessionID, command)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
return observeErr(s, "AppendSessionLog", func() error {
|
||||
return s.store.AppendSessionLog(ctx, sessionID, input, output)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
|
||||
return s.store.DeleteRecordsBefore(ctx, cutoff)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetDashboardStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopUsernames(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopPasswords(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopIPs(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopCountries(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopExecCommands(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
return observe(s, "GetRecentSessions", func() ([]Session, error) {
|
||||
return s.store.GetRecentSessions(ctx, limit, activeOnly)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
|
||||
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||
return observe(s, "GetSession", func() (*Session, error) {
|
||||
return s.store.GetSession(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
|
||||
return s.store.GetSessionLogs(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||
return observeErr(s, "AppendSessionEvents", func() error {
|
||||
return s.store.AppendSessionEvents(ctx, events)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
|
||||
return s.store.GetSessionEvents(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
return observe(s, "CloseActiveSessions", func() (int64, error) {
|
||||
return s.store.CloseActiveSessions(ctx, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
|
||||
return s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
|
||||
return s.store.GetHourlyPattern(ctx, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
|
||||
return s.store.GetCountryStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetFilteredDashboardStats(ctx, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopUsernames(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopPasswords(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopIPs(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopCountries(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) Close() error {
|
||||
return s.store.Close()
|
||||
}
|
||||
163
internal/storage/instrumented_test.go
Normal file
163
internal/storage/instrumented_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
)
|
||||
|
||||
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
store := NewMemoryStore()
|
||||
return NewInstrumentedStore(store, dur, errs), dur, errs
|
||||
}
|
||||
|
||||
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
|
||||
m := &dto.Metric{}
|
||||
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
|
||||
return m.GetHistogram().GetSampleCount()
|
||||
}
|
||||
|
||||
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
|
||||
m := &dto.Metric{}
|
||||
c.WithLabelValues(method).Write(m)
|
||||
return m.GetCounter().GetValue()
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreDelegation(t *testing.T) {
|
||||
s, dur, _ := newTestInstrumented()
|
||||
ctx := context.Background()
|
||||
|
||||
// RecordLoginAttempt should delegate and record duration.
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// CreateSession should delegate and return a valid ID.
|
||||
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Fatal("CreateSession returned empty ID")
|
||||
}
|
||||
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// GetDashboardStats should delegate.
|
||||
stats, err := s.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDashboardStats: %v", err)
|
||||
}
|
||||
if stats == nil {
|
||||
t.Fatal("GetDashboardStats returned nil")
|
||||
}
|
||||
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreErrorCounting(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_ec_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_ec_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
// Error should be counted.
|
||||
err := s.EndSession(ctx, "nonexistent", time.Now())
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
if c := getCounterValue(errs, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
|
||||
// Successful call should not increment error counter.
|
||||
s2, _, errs2 := newTestInstrumented()
|
||||
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
|
||||
t.Fatalf("expected error count 0, got %f", c)
|
||||
}
|
||||
}
|
||||
|
||||
// errorStore is a Store that returns errors for all methods.
|
||||
type errorStore struct {
|
||||
MemoryStore
|
||||
}
|
||||
|
||||
var errFake = errors.New("fake error")
|
||||
|
||||
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreObserveErr(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test2_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test2_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreClose(t *testing.T) {
|
||||
s, _, _ := newTestInstrumented()
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -91,6 +91,16 @@ func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, scor
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if s, ok := m.Sessions[sessionID]; ok {
|
||||
s.ExecCommand = &command
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -326,20 +336,105 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.collectSessions(limit, activeOnly, f), nil
|
||||
}
|
||||
|
||||
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
|
||||
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
|
||||
// Compute event counts and input bytes per session.
|
||||
eventCounts := make(map[string]int)
|
||||
inputBytes := make(map[string]int64)
|
||||
for _, e := range m.SessionEvents {
|
||||
eventCounts[e.SessionID]++
|
||||
if e.Direction == 0 {
|
||||
inputBytes[e.SessionID] += int64(len(e.Data))
|
||||
}
|
||||
}
|
||||
|
||||
var sessions []Session
|
||||
for _, s := range m.Sessions {
|
||||
if activeOnly && s.DisconnectedAt != nil {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *s)
|
||||
if !matchesSessionFilter(s, f) {
|
||||
continue
|
||||
}
|
||||
sess := *s
|
||||
sess.EventCount = eventCounts[s.ID]
|
||||
sess.InputBytes = inputBytes[s.ID]
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
|
||||
if f.SortBy == "input_bytes" {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].InputBytes > sessions[j].InputBytes
|
||||
})
|
||||
} else {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
if limit > 0 && len(sessions) > limit {
|
||||
sessions = sessions[:limit]
|
||||
}
|
||||
return sessions, nil
|
||||
return sessions
|
||||
}
|
||||
|
||||
// matchesSessionFilter returns true if the session matches the given filter.
|
||||
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
|
||||
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||
return false
|
||||
}
|
||||
if f.IP != "" && s.IP != f.IP {
|
||||
return false
|
||||
}
|
||||
if f.Country != "" && s.Country != f.Country {
|
||||
return false
|
||||
}
|
||||
if f.Username != "" && s.Username != f.Username {
|
||||
return false
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
if s.HumanScore == nil || *s.HumanScore <= 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, s := range m.Sessions {
|
||||
if s.ExecCommand != nil {
|
||||
counts[*s.ExecCommand]++
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
@@ -357,6 +452,258 @@ func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var cutoff time.Time
|
||||
if since != nil {
|
||||
cutoff = *since
|
||||
} else {
|
||||
cutoff = time.Now().UTC().AddDate(0, 0, -days)
|
||||
}
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.LastSeen.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
if until != nil && a.LastSeen.After(*until) {
|
||||
continue
|
||||
}
|
||||
day := a.LastSeen.Format("2006-01-02")
|
||||
counts[day] += int64(a.Count)
|
||||
}
|
||||
|
||||
points := make([]TimeSeriesPoint, 0, len(counts))
|
||||
for day, count := range counts {
|
||||
t, _ := time.Parse("2006-01-02", day)
|
||||
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
|
||||
}
|
||||
sort.Slice(points, func(i, j int) bool {
|
||||
return points[i].Timestamp.Before(points[j].Timestamp)
|
||||
})
|
||||
return points, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
hourCounts := make(map[int]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if since != nil && a.LastSeen.Before(*since) {
|
||||
continue
|
||||
}
|
||||
if until != nil && a.LastSeen.After(*until) {
|
||||
continue
|
||||
}
|
||||
hour := a.LastSeen.Hour()
|
||||
hourCounts[hour] += int64(a.Count)
|
||||
}
|
||||
|
||||
counts := make([]HourlyCount, 0, len(hourCounts))
|
||||
for h, c := range hourCounts {
|
||||
counts = append(counts, HourlyCount{Hour: h, Count: c})
|
||||
}
|
||||
sort.Slice(counts, func(i, j int) bool {
|
||||
return counts[i].Hour < counts[j].Hour
|
||||
})
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
result := make([]CountryCount, 0, len(counts))
|
||||
for country, count := range counts {
|
||||
result = append(result, CountryCount{Country: country, Count: count})
|
||||
}
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Count > result[j].Count
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
|
||||
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
|
||||
if f.Since != nil && a.LastSeen.Before(*f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.Until != nil && a.LastSeen.After(*f.Until) {
|
||||
return false
|
||||
}
|
||||
if f.IP != "" && a.IP != f.IP {
|
||||
return false
|
||||
}
|
||||
if f.Country != "" && a.Country != f.Country {
|
||||
return false
|
||||
}
|
||||
if f.Username != "" && a.Username != f.Username {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
stats := &DashboardStats{}
|
||||
ips := make(map[string]struct{})
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
stats.TotalAttempts += int64(a.Count)
|
||||
ips[a.IP] = struct{}{}
|
||||
}
|
||||
stats.UniqueIPs = int64(len(ips))
|
||||
|
||||
for _, s := range m.Sessions {
|
||||
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||
continue
|
||||
}
|
||||
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||
continue
|
||||
}
|
||||
if f.IP != "" && s.IP != f.IP {
|
||||
continue
|
||||
}
|
||||
if f.Country != "" && s.Country != f.Country {
|
||||
continue
|
||||
}
|
||||
stats.TotalSessions++
|
||||
if s.DisconnectedAt == nil {
|
||||
stats.ActiveSessions++
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.filteredTopN("username", limit, f), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.filteredTopN("password", limit, f), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
type ipInfo struct {
|
||||
count int64
|
||||
country string
|
||||
}
|
||||
agg := make(map[string]*ipInfo)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
info, ok := agg[a.IP]
|
||||
if !ok {
|
||||
info = &ipInfo{}
|
||||
agg[a.IP] = info
|
||||
}
|
||||
info.count += int64(a.Count)
|
||||
if a.Country != "" {
|
||||
info.country = a.Country
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(agg))
|
||||
for ip, info := range agg {
|
||||
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
|
||||
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
|
||||
counts := make(map[string]int64)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
var key string
|
||||
switch field {
|
||||
case "username":
|
||||
key = a.Username
|
||||
case "password":
|
||||
key = a.Password
|
||||
case "ip":
|
||||
key = a.IP
|
||||
}
|
||||
counts[key] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE sessions ADD COLUMN exec_command TEXT;
|
||||
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
|
||||
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
|
||||
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);
|
||||
@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 3 {
|
||||
t.Errorf("version = %d, want 3", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d, want 5", version)
|
||||
}
|
||||
|
||||
// Verify tables exist by inserting into them.
|
||||
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 3 {
|
||||
t.Errorf("version = %d after double migrate, want 3", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d after double migrate, want 5", version)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -83,6 +84,16 @@ func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, sc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET exec_command = ? WHERE id = ?`,
|
||||
command, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting exec command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
@@ -100,12 +111,13 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
var connectedAt string
|
||||
var disconnectedAt sql.NullString
|
||||
var humanScore sql.NullFloat64
|
||||
var execCommand sql.NullString
|
||||
|
||||
err := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score
|
||||
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
|
||||
FROM sessions WHERE id = ?`, sessionID).Scan(
|
||||
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
|
||||
&connectedAt, &disconnectedAt, &humanScore,
|
||||
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -122,6 +134,9 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
if humanScore.Valid {
|
||||
sess.HumanScore = &humanScore.Float64
|
||||
}
|
||||
if execCommand.Valid {
|
||||
sess.ExecCommand = &execCommand.String
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
@@ -368,40 +383,132 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) (
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
query := `SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score FROM sessions`
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
|
||||
if activeOnly {
|
||||
query += ` WHERE disconnected_at IS NULL`
|
||||
query += ` WHERE s.disconnected_at IS NULL`
|
||||
}
|
||||
query += ` ORDER BY connected_at DESC LIMIT ?`
|
||||
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, limit)
|
||||
return s.scanSessions(ctx, query, limit)
|
||||
}
|
||||
|
||||
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
|
||||
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
|
||||
var clauses []string
|
||||
var args []any
|
||||
|
||||
if activeOnly {
|
||||
clauses = append(clauses, "s.disconnected_at IS NULL")
|
||||
}
|
||||
if f.Since != nil {
|
||||
clauses = append(clauses, "s.connected_at >= ?")
|
||||
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
clauses = append(clauses, "s.connected_at <= ?")
|
||||
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
clauses = append(clauses, "s.ip = ?")
|
||||
args = append(args, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
clauses = append(clauses, "s.country = ?")
|
||||
args = append(args, f.Country)
|
||||
}
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, "s.username = ?")
|
||||
args = append(args, f.Username)
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
clauses = append(clauses, "s.human_score > 0")
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
|
||||
var validSessionSorts = map[string]string{
|
||||
"connected_at": "s.connected_at DESC",
|
||||
"input_bytes": "input_bytes DESC",
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
where, args := buildSessionWhereClause(f, activeOnly)
|
||||
args = append(args, limit)
|
||||
|
||||
orderBy := validSessionSorts["connected_at"]
|
||||
if mapped, ok := validSessionSorts[f.SortBy]; ok {
|
||||
orderBy = mapped
|
||||
}
|
||||
|
||||
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
|
||||
|
||||
return s.scanSessions(ctx, query, args...)
|
||||
}
|
||||
|
||||
// scanSessions executes a session query and scans the results.
|
||||
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying recent sessions: %w", err)
|
||||
return nil, fmt.Errorf("querying sessions: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var sessions []Session
|
||||
for rows.Next() {
|
||||
var s Session
|
||||
var sess Session
|
||||
var connectedAt string
|
||||
var disconnectedAt sql.NullString
|
||||
var humanScore sql.NullFloat64
|
||||
if err := rows.Scan(&s.ID, &s.IP, &s.Country, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore); err != nil {
|
||||
var execCommand sql.NullString
|
||||
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
|
||||
return nil, fmt.Errorf("scanning session: %w", err)
|
||||
}
|
||||
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
if disconnectedAt.Valid {
|
||||
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
||||
s.DisconnectedAt = &t
|
||||
sess.DisconnectedAt = &t
|
||||
}
|
||||
if humanScore.Valid {
|
||||
s.HumanScore = &humanScore.Float64
|
||||
sess.HumanScore = &humanScore.Float64
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
if execCommand.Valid {
|
||||
sess.ExecCommand = &execCommand.String
|
||||
}
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT exec_command, COUNT(*) as total
|
||||
FROM sessions
|
||||
WHERE exec_command IS NOT NULL
|
||||
GROUP BY exec_command
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying top exec commands: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning top exec commands: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
|
||||
@@ -412,6 +519,265 @@ func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt ti
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
|
||||
var args []any
|
||||
|
||||
if since != nil {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, since.UTC().Format(time.RFC3339))
|
||||
} else {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
|
||||
}
|
||||
if until != nil {
|
||||
query += ` AND last_seen <= ?`
|
||||
args = append(args, until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
query += ` GROUP BY d ORDER BY d`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying attempts over time: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var points []TimeSeriesPoint
|
||||
for rows.Next() {
|
||||
var dateStr string
|
||||
var p TimeSeriesPoint
|
||||
if err := rows.Scan(&dateStr, &p.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning time series point: %w", err)
|
||||
}
|
||||
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
|
||||
points = append(points, p)
|
||||
}
|
||||
return points, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
|
||||
var args []any
|
||||
|
||||
if since != nil {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if until != nil {
|
||||
query += ` AND last_seen <= ?`
|
||||
args = append(args, until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
query += ` GROUP BY h ORDER BY h`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying hourly pattern: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var counts []HourlyCount
|
||||
for rows.Next() {
|
||||
var c HourlyCount
|
||||
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning hourly count: %w", err)
|
||||
}
|
||||
counts = append(counts, c)
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT country, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
WHERE country != ''
|
||||
GROUP BY country
|
||||
ORDER BY total DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying country stats: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var counts []CountryCount
|
||||
for rows.Next() {
|
||||
var c CountryCount
|
||||
if err := rows.Scan(&c.Country, &c.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning country count: %w", err)
|
||||
}
|
||||
counts = append(counts, c)
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
|
||||
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
|
||||
var clauses []string
|
||||
var args []any
|
||||
|
||||
if f.Since != nil {
|
||||
clauses = append(clauses, "last_seen >= ?")
|
||||
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
clauses = append(clauses, "last_seen <= ?")
|
||||
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
clauses = append(clauses, "ip = ?")
|
||||
args = append(args, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
clauses = append(clauses, "country = ?")
|
||||
args = append(args, f.Country)
|
||||
}
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, "username = ?")
|
||||
args = append(args, f.Username)
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
stats := &DashboardStats{}
|
||||
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
|
||||
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
|
||||
}
|
||||
|
||||
// Sessions don't have username/password, so only filter by time, IP, country.
|
||||
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
|
||||
var sessArgs []any
|
||||
if f.Since != nil {
|
||||
sessQuery += ` AND connected_at >= ?`
|
||||
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
sessQuery += ` AND connected_at <= ?`
|
||||
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
sessQuery += ` AND ip = ?`
|
||||
sessArgs = append(sessArgs, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
sessQuery += ` AND country = ?`
|
||||
sessArgs = append(sessArgs, f.Country)
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return s.queryFilteredTopN(ctx, "username", limit, f)
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return s.queryFilteredTopN(ctx, "password", limit, f)
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
args = append(args, limit)
|
||||
//nolint:gosec // where clause built from trusted constants, not user input
|
||||
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
countryClause := "country != ''"
|
||||
if where == "" {
|
||||
where = " WHERE " + countryClause
|
||||
} else {
|
||||
where += " AND " + countryClause
|
||||
}
|
||||
args = append(args, limit)
|
||||
//nolint:gosec // where clause built from trusted constants, not user input
|
||||
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top countries: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
switch column {
|
||||
case "username", "password":
|
||||
// valid columns
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid column: %s", column)
|
||||
}
|
||||
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
args = append(args, limit)
|
||||
query := fmt.Sprintf(`
|
||||
SELECT %s, SUM(count) AS total
|
||||
FROM login_attempts%s
|
||||
GROUP BY %s
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, column, where, column)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
@@ -204,6 +204,79 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTopExecCommands(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions with exec commands.
|
||||
for range 3 {
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
}
|
||||
for range 2 {
|
||||
id, err := store.CreateSession(ctx, "10.0.0.2", "admin", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "cat /etc/passwd"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
}
|
||||
// Session without exec command — should not appear.
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
entries, err := store.GetTopExecCommands(ctx, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTopExecCommands: %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(entries))
|
||||
}
|
||||
if entries[0].Value != "uname -a" || entries[0].Count != 3 {
|
||||
t.Errorf("entries[0] = %+v, want uname -a:3", entries[0])
|
||||
}
|
||||
if entries[1].Value != "cat /etc/passwd" || entries[1].Count != 2 {
|
||||
t.Errorf("entries[1] = %+v, want cat /etc/passwd:2", entries[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRecentSessionsEventCount(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
// Add some events.
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: time.Now(), Direction: 0, Data: []byte("ls\n")},
|
||||
{SessionID: id, Timestamp: time.Now(), Direction: 1, Data: []byte("file1\n")},
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("appending events: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].EventCount != 2 {
|
||||
t.Errorf("EventCount = %d, want 2", sessions[0].EventCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||
store, err := NewSQLiteStore(dbPath)
|
||||
|
||||
@@ -27,6 +27,9 @@ type Session struct {
|
||||
ConnectedAt time.Time
|
||||
DisconnectedAt *time.Time
|
||||
HumanScore *float64
|
||||
ExecCommand *string
|
||||
EventCount int
|
||||
InputBytes int64
|
||||
}
|
||||
|
||||
// SessionLog represents a single log entry for a session.
|
||||
@@ -54,6 +57,35 @@ type DashboardStats struct {
|
||||
ActiveSessions int64
|
||||
}
|
||||
|
||||
// TimeSeriesPoint represents a single data point in a time series.
|
||||
type TimeSeriesPoint struct {
|
||||
Timestamp time.Time
|
||||
Count int64
|
||||
}
|
||||
|
||||
// HourlyCount represents the total attempts for a given hour of day.
|
||||
type HourlyCount struct {
|
||||
Hour int // 0-23
|
||||
Count int64
|
||||
}
|
||||
|
||||
// CountryCount represents the total attempts from a given country.
|
||||
type CountryCount struct {
|
||||
Country string
|
||||
Count int64
|
||||
}
|
||||
|
||||
// DashboardFilter contains optional filters for dashboard queries.
|
||||
type DashboardFilter struct {
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
IP string
|
||||
Country string
|
||||
Username string
|
||||
HumanScoreAboveZero bool
|
||||
SortBy string
|
||||
}
|
||||
|
||||
// TopEntry represents a value and its count for top-N queries.
|
||||
type TopEntry struct {
|
||||
Value string
|
||||
@@ -76,6 +108,9 @@ type Store interface {
|
||||
// UpdateHumanScore sets the human detection score for a session.
|
||||
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
|
||||
|
||||
// SetExecCommand sets the exec command for a session.
|
||||
SetExecCommand(ctx context.Context, sessionID string, command string) error
|
||||
|
||||
// AppendSessionLog adds a log entry to a session.
|
||||
AppendSessionLog(ctx context.Context, sessionID, input, output string) error
|
||||
|
||||
@@ -98,10 +133,17 @@ type Store interface {
|
||||
// GetTopCountries returns the top N countries by total attempt count.
|
||||
GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetTopExecCommands returns the top N exec commands by session count.
|
||||
GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetRecentSessions returns the most recent sessions ordered by connected_at DESC.
|
||||
// If activeOnly is true, only sessions with no disconnected_at are returned.
|
||||
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
|
||||
|
||||
// GetFilteredSessions returns sessions matching the given filter, ordered
|
||||
// by the filter's SortBy field (default: connected_at DESC).
|
||||
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
|
||||
|
||||
// GetSession returns a single session by ID.
|
||||
GetSession(ctx context.Context, sessionID string) (*Session, error)
|
||||
|
||||
@@ -119,6 +161,30 @@ type Store interface {
|
||||
// sessions left over from a previous unclean shutdown.
|
||||
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
|
||||
|
||||
// GetAttemptsOverTime returns daily attempt counts for the last N days.
|
||||
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
|
||||
|
||||
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
|
||||
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
|
||||
|
||||
// GetCountryStats returns total attempts per country, ordered by count DESC.
|
||||
GetCountryStats(ctx context.Context) ([]CountryCount, error)
|
||||
|
||||
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
|
||||
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
|
||||
|
||||
// GetFilteredTopUsernames returns top usernames with optional filters applied.
|
||||
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopPasswords returns top passwords with optional filters applied.
|
||||
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopIPs returns top IPs with optional filters applied.
|
||||
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopCountries returns top countries with optional filters applied.
|
||||
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// Close releases any resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -361,6 +361,289 @@ func TestCloseActiveSessions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetExecCommand(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("set and retrieve", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Initially nil.
|
||||
s, err := store.GetSession(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s.ExecCommand != nil {
|
||||
t.Errorf("expected nil ExecCommand, got %q", *s.ExecCommand)
|
||||
}
|
||||
|
||||
// Set exec command.
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("SetExecCommand: %v", err)
|
||||
}
|
||||
|
||||
s, err = store.GetSession(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s.ExecCommand == nil {
|
||||
t.Fatal("expected non-nil ExecCommand")
|
||||
}
|
||||
if *s.ExecCommand != "uname -a" {
|
||||
t.Errorf("ExecCommand = %q, want %q", *s.ExecCommand, "uname -a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("appears in recent sessions", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "id"); err != nil {
|
||||
t.Fatalf("SetExecCommand: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ExecCommand == nil || *sessions[0].ExecCommand != "id" {
|
||||
t.Errorf("ExecCommand = %v, want \"id\"", sessions[0].ExecCommand)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func seedChartData(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
// Record attempts with country data from different IPs.
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 2 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAttemptsOverTime(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||
}
|
||||
if len(points) != 0 {
|
||||
t.Errorf("expected empty, got %v", points)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||
}
|
||||
// All data was inserted today, so should be one point.
|
||||
if len(points) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(points))
|
||||
}
|
||||
// 5 + 3 + 2 = 10 total.
|
||||
if points[0].Count != 10 {
|
||||
t.Errorf("count = %d, want 10", points[0].Count)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetHourlyPattern(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetHourlyPattern: %v", err)
|
||||
}
|
||||
if len(counts) != 0 {
|
||||
t.Errorf("expected empty, got %v", counts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetHourlyPattern: %v", err)
|
||||
}
|
||||
// All data was inserted at the same hour.
|
||||
if len(counts) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(counts))
|
||||
}
|
||||
if counts[0].Count != 10 {
|
||||
t.Errorf("count = %d, want 10", counts[0].Count)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCountryStats(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
counts, err := store.GetCountryStats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 0 {
|
||||
t.Errorf("expected empty, got %v", counts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
counts, err := store.GetCountryStats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(counts))
|
||||
}
|
||||
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
|
||||
if counts[0].Country != "CN" || counts[0].Count != 7 {
|
||||
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
|
||||
}
|
||||
if counts[1].Country != "RU" || counts[1].Count != 3 {
|
||||
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("excludes empty country", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
|
||||
counts, err := store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(counts))
|
||||
}
|
||||
if counts[0].Country != "US" {
|
||||
t.Errorf("country = %q, want US", counts[0].Country)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredDashboardStats(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("no filter", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 10 {
|
||||
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by country", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
// CN: 5 + 2 = 7
|
||||
if stats.TotalAttempts != 7 {
|
||||
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by IP", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 5 {
|
||||
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by username", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 3 {
|
||||
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredTopUsernames(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
// Filter by country CN should only show root.
|
||||
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredTopUsernames: %v", err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(entries))
|
||||
}
|
||||
if entries[0].Value != "root" || entries[0].Count != 7 {
|
||||
t.Errorf("entries[0] = %+v, want root/7", entries[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetRecentSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
@@ -417,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestInputBytes(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("counts only input direction", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
|
||||
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
|
||||
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
|
||||
if sessions[0].InputBytes != 7 {
|
||||
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero when no events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].InputBytes != 0 {
|
||||
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("filter by human score", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two sessions, one with human score > 0.
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by input bytes", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Session with more input (created first).
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
|
||||
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Session with less input (created after id1, so would be first by connected_at).
|
||||
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Default sort (connected_at DESC) should show id2 first.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id2 {
|
||||
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
|
||||
}
|
||||
|
||||
// Sort by input_bytes should show id1 first (more input).
|
||||
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("combined filters", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Different country, also has score.
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Same country CN but no score.
|
||||
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Filter: CN + human score > 0 -> only id1.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
|
||||
Country: "CN",
|
||||
HumanScoreAboveZero: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,25 +1,37 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// dbContext returns a context detached from the HTTP request lifecycle with a
|
||||
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
|
||||
// queries when the browser aborts the previous XHR.
|
||||
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
|
||||
}
|
||||
|
||||
type dashboardData struct {
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
TopPasswords []storage.TopEntry
|
||||
TopIPs []storage.TopEntry
|
||||
TopCountries []storage.TopEntry
|
||||
ActiveSessions []storage.Session
|
||||
RecentSessions []storage.Session
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
TopPasswords []storage.TopEntry
|
||||
TopIPs []storage.TopEntry
|
||||
TopCountries []storage.TopEntry
|
||||
TopExecCommands []storage.TopEntry
|
||||
ActiveSessions []storage.Session
|
||||
RecentSessions []storage.Session
|
||||
}
|
||||
|
||||
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
@@ -56,6 +68,13 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
topExecCommands, err := s.store.GetTopExecCommands(ctx, 10)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get top exec commands", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
activeSessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
@@ -71,13 +90,14 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
data := dashboardData{
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
TopCountries: topCountries,
|
||||
ActiveSessions: activeSessions,
|
||||
RecentSessions: recentSessions,
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
TopCountries: topCountries,
|
||||
TopExecCommands: topExecCommands,
|
||||
ActiveSessions: activeSessions,
|
||||
RecentSessions: recentSessions,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
@@ -87,7 +107,10 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := s.store.GetDashboardStats(r.Context())
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get dashboard stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -101,7 +124,10 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
|
||||
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -114,6 +140,24 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
f := parseDashboardFilter(r)
|
||||
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
|
||||
s.logger.Error("failed to render recent sessions fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type sessionDetailData struct {
|
||||
Session *storage.Session
|
||||
Logs []storage.SessionLog
|
||||
@@ -121,7 +165,8 @@ type sessionDetailData struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
session, err := s.store.GetSession(ctx, sessionID)
|
||||
@@ -171,8 +216,201 @@ type apiEventsResponse struct {
|
||||
Events []apiEvent `json:"events"`
|
||||
}
|
||||
|
||||
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
|
||||
func parseDateParam(r *http.Request, name string) *time.Time {
|
||||
v := r.URL.Query().Get(name)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
t, err := time.Parse("2006-01-02", v)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// For "until" dates, set to end of day.
|
||||
if name == "until" {
|
||||
t = t.Add(24*time.Hour - time.Second)
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
|
||||
return storage.DashboardFilter{
|
||||
Since: parseDateParam(r, "since"),
|
||||
Until: parseDateParam(r, "until"),
|
||||
IP: r.URL.Query().Get("ip"),
|
||||
Country: r.URL.Query().Get("country"),
|
||||
Username: r.URL.Query().Get("username"),
|
||||
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
|
||||
SortBy: r.URL.Query().Get("sort"),
|
||||
}
|
||||
}
|
||||
|
||||
type apiTimeSeriesPoint struct {
|
||||
Date string `json:"date"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiAttemptsOverTimeResponse struct {
|
||||
Points []apiTimeSeriesPoint `json:"points"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
|
||||
days := 30
|
||||
if v := r.URL.Query().Get("days"); v != "" {
|
||||
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
|
||||
days = d
|
||||
}
|
||||
}
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get attempts over time", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
|
||||
for i, p := range points {
|
||||
resp.Points[i] = apiTimeSeriesPoint{
|
||||
Date: p.Timestamp.Format("2006-01-02"),
|
||||
Count: p.Count,
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode attempts over time", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type apiHourlyCount struct {
|
||||
Hour int `json:"hour"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiHourlyPatternResponse struct {
|
||||
Hours []apiHourlyCount `json:"hours"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
counts, err := s.store.GetHourlyPattern(ctx, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get hourly pattern", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
|
||||
for i, c := range counts {
|
||||
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode hourly pattern", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type apiCountryCount struct {
|
||||
Country string `json:"country"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiCountryStatsResponse struct {
|
||||
Countries []apiCountryCount `json:"countries"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
counts, err := s.store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get country stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
|
||||
for i, c := range counts {
|
||||
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode country stats", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
f := parseDashboardFilter(r)
|
||||
|
||||
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top usernames", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top passwords", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top IPs", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top countries", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := dashboardData{
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
TopCountries: topCountries,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
|
||||
s.logger.Error("failed to render dashboard content fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||
|
||||
14
internal/web/static/chart.min.js
vendored
Normal file
14
internal/web/static/chart.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
275
internal/web/static/dashboard.js
Normal file
275
internal/web/static/dashboard.js
Normal file
@@ -0,0 +1,275 @@
|
||||
(function() {
|
||||
'use strict';
|
||||
|
||||
// Chart.js theme for Pico dark mode
|
||||
Chart.defaults.color = '#b0b0b8';
|
||||
Chart.defaults.borderColor = '#3a3a4a';
|
||||
|
||||
var attemptsChart = null;
|
||||
var hourlyChart = null;
|
||||
|
||||
function getFilterParams() {
|
||||
var form = document.getElementById('filter-form');
|
||||
if (!form) return '';
|
||||
var params = new URLSearchParams();
|
||||
var since = form.elements['since'].value;
|
||||
var until = form.elements['until'].value;
|
||||
if (since) params.set('since', since);
|
||||
if (until) params.set('until', until);
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
return params.toString();
|
||||
}
|
||||
|
||||
function initAttemptsChart() {
|
||||
var canvas = document.getElementById('chart-attempts');
|
||||
if (!canvas) return;
|
||||
var ctx = canvas.getContext('2d');
|
||||
|
||||
var qs = getFilterParams();
|
||||
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
|
||||
|
||||
fetch(url)
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
var labels = data.points.map(function(p) { return p.date; });
|
||||
var values = data.points.map(function(p) { return p.count; });
|
||||
|
||||
if (attemptsChart) {
|
||||
attemptsChart.data.labels = labels;
|
||||
attemptsChart.data.datasets[0].data = values;
|
||||
attemptsChart.update();
|
||||
return;
|
||||
}
|
||||
|
||||
attemptsChart = new Chart(ctx, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: labels,
|
||||
datasets: [{
|
||||
label: 'Attempts',
|
||||
data: values,
|
||||
borderColor: '#6366f1',
|
||||
backgroundColor: 'rgba(99, 102, 241, 0.1)',
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
pointRadius: 2
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: true,
|
||||
plugins: { legend: { display: false } },
|
||||
scales: {
|
||||
x: { grid: { display: false } },
|
||||
y: { beginAtZero: true }
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function initHourlyChart() {
|
||||
var canvas = document.getElementById('chart-hourly');
|
||||
if (!canvas) return;
|
||||
var ctx = canvas.getContext('2d');
|
||||
|
||||
var qs = getFilterParams();
|
||||
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
|
||||
|
||||
fetch(url)
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
// Fill all 24 hours, defaulting to 0
|
||||
var hourMap = {};
|
||||
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
|
||||
var labels = [];
|
||||
var values = [];
|
||||
for (var i = 0; i < 24; i++) {
|
||||
labels.push(i + ':00');
|
||||
values.push(hourMap[i] || 0);
|
||||
}
|
||||
|
||||
if (hourlyChart) {
|
||||
hourlyChart.data.labels = labels;
|
||||
hourlyChart.data.datasets[0].data = values;
|
||||
hourlyChart.update();
|
||||
return;
|
||||
}
|
||||
|
||||
hourlyChart = new Chart(ctx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: labels,
|
||||
datasets: [{
|
||||
label: 'Attempts',
|
||||
data: values,
|
||||
backgroundColor: 'rgba(99, 102, 241, 0.6)',
|
||||
borderColor: '#6366f1',
|
||||
borderWidth: 1
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: true,
|
||||
plugins: { legend: { display: false } },
|
||||
scales: {
|
||||
x: { grid: { display: false } },
|
||||
y: { beginAtZero: true }
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function initWorldMap() {
|
||||
var container = document.getElementById('world-map');
|
||||
if (!container) return;
|
||||
|
||||
fetch('/static/world.svg')
|
||||
.then(function(r) { return r.text(); })
|
||||
.then(function(svgText) {
|
||||
container.innerHTML = svgText;
|
||||
|
||||
fetch('/api/charts/country-stats')
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
colorMap(container, data.countries);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function colorMap(container, countries) {
|
||||
if (!countries || countries.length === 0) return;
|
||||
|
||||
var maxCount = countries[0].count; // already sorted DESC
|
||||
var logMax = Math.log(maxCount + 1);
|
||||
|
||||
// Build lookup
|
||||
var lookup = {};
|
||||
countries.forEach(function(c) {
|
||||
lookup[c.country.toLowerCase()] = c.count;
|
||||
});
|
||||
|
||||
// Create tooltip element
|
||||
var tooltip = document.createElement('div');
|
||||
tooltip.id = 'map-tooltip';
|
||||
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
|
||||
document.body.appendChild(tooltip);
|
||||
|
||||
var svg = container.querySelector('svg');
|
||||
if (!svg) return;
|
||||
|
||||
// Remove SVG title to prevent browser native tooltip
|
||||
var svgTitle = svg.querySelector('title');
|
||||
if (svgTitle) svgTitle.remove();
|
||||
|
||||
// Select both <path id="xx"> and <g id="xx"> country elements
|
||||
var elements = svg.querySelectorAll('path[id], g[id]');
|
||||
elements.forEach(function(el) {
|
||||
var id = el.id.toLowerCase();
|
||||
if (id.charAt(0) === '_') return; // skip non-country paths
|
||||
|
||||
var count = lookup[id];
|
||||
if (count) {
|
||||
var intensity = Math.log(count + 1) / logMax;
|
||||
var r = Math.round(30 + intensity * 69); // 30 -> 99
|
||||
var g = Math.round(30 + intensity * 72); // 30 -> 102
|
||||
var b = Math.round(62 + intensity * 179); // 62 -> 241
|
||||
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
|
||||
// For <g> elements, color child paths; for <path>, color directly
|
||||
if (el.tagName.toLowerCase() === 'g') {
|
||||
el.querySelectorAll('path').forEach(function(p) {
|
||||
p.style.fill = color;
|
||||
});
|
||||
} else {
|
||||
el.style.fill = color;
|
||||
}
|
||||
}
|
||||
|
||||
el.addEventListener('mouseenter', function(e) {
|
||||
var cc = id.toUpperCase();
|
||||
var n = lookup[id] || 0;
|
||||
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
|
||||
tooltip.style.display = 'block';
|
||||
});
|
||||
|
||||
el.addEventListener('mousemove', function(e) {
|
||||
tooltip.style.left = (e.clientX + 12) + 'px';
|
||||
tooltip.style.top = (e.clientY - 10) + 'px';
|
||||
});
|
||||
|
||||
el.addEventListener('mouseleave', function() {
|
||||
tooltip.style.display = 'none';
|
||||
});
|
||||
|
||||
el.addEventListener('click', function() {
|
||||
var input = document.querySelector('#filter-form input[name="country"]');
|
||||
if (input) {
|
||||
input.value = id.toUpperCase();
|
||||
applyFilters();
|
||||
}
|
||||
});
|
||||
|
||||
el.style.cursor = 'pointer';
|
||||
});
|
||||
}
|
||||
|
||||
function applyFilters() {
|
||||
// Re-fetch charts with filter params
|
||||
initAttemptsChart();
|
||||
initHourlyChart();
|
||||
|
||||
// Re-fetch dashboard content via htmx
|
||||
var form = document.getElementById('filter-form');
|
||||
if (!form) return;
|
||||
|
||||
var params = new URLSearchParams();
|
||||
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
|
||||
var val = form.elements[name].value;
|
||||
if (val) params.set(name, val);
|
||||
});
|
||||
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
|
||||
var target = document.getElementById('dashboard-content');
|
||||
if (target) {
|
||||
var url = '/fragments/dashboard-content?' + params.toString();
|
||||
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
// Server-side filter for recent sessions table
|
||||
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
|
||||
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
window.clearFilters = function() {
|
||||
var form = document.getElementById('filter-form');
|
||||
if (form) {
|
||||
form.reset();
|
||||
applyFilters();
|
||||
}
|
||||
};
|
||||
|
||||
window.applyFilters = applyFilters;
|
||||
|
||||
// Initialize on DOM ready
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
initAttemptsChart();
|
||||
initHourlyChart();
|
||||
initWorldMap();
|
||||
|
||||
var form = document.getElementById('filter-form');
|
||||
if (form) {
|
||||
form.addEventListener('submit', function(e) {
|
||||
e.preventDefault();
|
||||
applyFilters();
|
||||
});
|
||||
}
|
||||
});
|
||||
})();
|
||||
1
internal/web/static/world.svg
Normal file
1
internal/web/static/world.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 55 KiB |
@@ -44,6 +44,32 @@ func templateFuncMap() template.FuncMap {
|
||||
}
|
||||
return fmt.Sprintf("%.0f%%", *f*100)
|
||||
},
|
||||
"derefString": func(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
},
|
||||
"truncateCommand": func(s string) string {
|
||||
if len(s) > 50 {
|
||||
return s[:50] + "..."
|
||||
}
|
||||
return s
|
||||
},
|
||||
"formatBytes": func(b int64) string {
|
||||
const (
|
||||
kb = 1024
|
||||
mb = 1024 * kb
|
||||
)
|
||||
switch {
|
||||
case b >= mb:
|
||||
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
|
||||
case b >= kb:
|
||||
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +81,7 @@ func loadTemplates() (*templateSet, error) {
|
||||
"templates/dashboard.html",
|
||||
"templates/fragments/stats.html",
|
||||
"templates/fragments/active_sessions.html",
|
||||
"templates/fragments/recent_sessions.html",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing dashboard templates: %w", err)
|
||||
|
||||
@@ -3,6 +3,86 @@
|
||||
{{template "stats" .Stats}}
|
||||
</section>
|
||||
|
||||
<details>
|
||||
<summary>Filters</summary>
|
||||
<form id="filter-form">
|
||||
<div class="grid">
|
||||
<label>Since <input type="date" name="since"></label>
|
||||
<label>Until <input type="date" name="until"></label>
|
||||
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
|
||||
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
|
||||
<label>Username <input type="text" name="username" placeholder="root"></label>
|
||||
</div>
|
||||
<div class="grid">
|
||||
<label><input type="checkbox" name="human_score" value="1"> Human score > 0</label>
|
||||
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
|
||||
</div>
|
||||
<button type="submit">Apply</button>
|
||||
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
|
||||
</form>
|
||||
</details>
|
||||
|
||||
<section>
|
||||
<h3>Attack Trends</h3>
|
||||
<div class="grid">
|
||||
<article>
|
||||
<header>Attempts Over Time</header>
|
||||
<canvas id="chart-attempts"></canvas>
|
||||
</article>
|
||||
<article>
|
||||
<header>Hourly Pattern (UTC)</header>
|
||||
<canvas id="chart-hourly"></canvas>
|
||||
</article>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Attack Origins</h3>
|
||||
<article>
|
||||
<div id="world-map"></div>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<div id="dashboard-content">
|
||||
{{template "dashboard_content" .}}
|
||||
</div>
|
||||
|
||||
<section>
|
||||
<h3>Active Sessions</h3>
|
||||
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
|
||||
{{template "active_sessions" .ActiveSessions}}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Recent Sessions</h3>
|
||||
<table id="recent-sessions-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Country</th>
|
||||
<th>Username</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
<th>Disconnected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{template "recent_sessions" .RecentSessions}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
{{define "scripts"}}
|
||||
<script src="/static/chart.min.js"></script>
|
||||
<script src="/static/dashboard.js"></script>
|
||||
{{end}}
|
||||
|
||||
{{define "dashboard_content"}}
|
||||
<section>
|
||||
<h3>Top Credentials & IPs</h3>
|
||||
<div class="top-grid">
|
||||
@@ -66,47 +146,21 @@
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
<article>
|
||||
<header>Top Exec Commands</header>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Command</th><th>Count</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .TopExecCommands}}
|
||||
<tr><td><code>{{truncateCommand .Value}}</code></td><td>{{.Count}}</td></tr>
|
||||
{{else}}
|
||||
<tr><td colspan="2">No data</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Active Sessions</h3>
|
||||
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
|
||||
{{template "active_sessions" .ActiveSessions}}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Recent Sessions</h3>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Country</th>
|
||||
<th>Username</th>
|
||||
<th>Shell</th>
|
||||
<th>Score</th>
|
||||
<th>Connected</th>
|
||||
<th>Disconnected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .RecentSessions}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{.ShellName}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="8">No sessions</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
@@ -6,24 +6,26 @@
|
||||
<th>IP</th>
|
||||
<th>Country</th>
|
||||
<th>Username</th>
|
||||
<th>Shell</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{.ShellName}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="7">No active sessions</td></tr>
|
||||
<tr><td colspan="8">No active sessions</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
17
internal/web/templates/fragments/recent_sessions.html
Normal file
17
internal/web/templates/fragments/recent_sessions.html
Normal file
@@ -0,0 +1,17 @@
|
||||
{{define "recent_sessions"}}
|
||||
{{range .}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="9">No sessions</td></tr>
|
||||
{{end}}
|
||||
{{end}}
|
||||
@@ -29,9 +29,16 @@
|
||||
}
|
||||
.top-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
|
||||
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
.top-grid article {
|
||||
overflow: hidden;
|
||||
min-width: 0;
|
||||
}
|
||||
#world-map svg { width: 100%; height: auto; }
|
||||
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
|
||||
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
|
||||
nav h1 {
|
||||
margin: 0;
|
||||
}
|
||||
@@ -52,5 +59,6 @@
|
||||
<main class="container">
|
||||
{{block "content" .}}{{end}}
|
||||
</main>
|
||||
{{block "scripts" .}}{{end}}
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
<tr><td><strong>Country</strong></td><td>{{.Session.Country}}</td></tr>
|
||||
<tr><td><strong>Username</strong></td><td>{{.Session.Username}}</td></tr>
|
||||
<tr><td><strong>Shell</strong></td><td>{{.Session.ShellName}}</td></tr>
|
||||
{{if .Session.ExecCommand}}<tr><td><strong>Exec Command</strong></td><td><code>{{derefString .Session.ExecCommand}}</code></td></tr>{{end}}
|
||||
<tr><td><strong>Score</strong></td><td>{{formatScore .Session.HumanScore}}</td></tr>
|
||||
<tr><td><strong>Connected</strong></td><td>{{formatTime .Session.ConnectedAt}}</td></tr>
|
||||
<tr>
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
@@ -40,9 +40,14 @@ func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Han
|
||||
s.mux.Handle("GET /static/", http.FileServerFS(staticFS))
|
||||
s.mux.HandleFunc("GET /sessions/{id}", s.handleSessionDetail)
|
||||
s.mux.HandleFunc("GET /api/sessions/{id}/events", s.handleAPISessionEvents)
|
||||
s.mux.HandleFunc("GET /api/charts/attempts-over-time", s.handleAPIAttemptsOverTime)
|
||||
s.mux.HandleFunc("GET /api/charts/hourly-pattern", s.handleAPIHourlyPattern)
|
||||
s.mux.HandleFunc("GET /api/charts/country-stats", s.handleAPICountryStats)
|
||||
s.mux.HandleFunc("GET /", s.handleDashboard)
|
||||
s.mux.HandleFunc("GET /fragments/stats", s.handleFragmentStats)
|
||||
s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions)
|
||||
s.mux.HandleFunc("GET /fragments/dashboard-content", s.handleFragmentDashboardContent)
|
||||
s.mux.HandleFunc("GET /fragments/recent-sessions", s.handleFragmentRecentSessions)
|
||||
|
||||
if metricsHandler != nil {
|
||||
h := metricsHandler
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T) *Server {
|
||||
@@ -54,6 +54,30 @@ func newSeededTestServer(t *testing.T) *Server {
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestDbContextNotCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
dbCtx, dbCancel := dbContext(req)
|
||||
defer dbCancel()
|
||||
|
||||
// Cancel the original request context.
|
||||
cancel()
|
||||
|
||||
// The DB context should still be usable.
|
||||
select {
|
||||
case <-dbCtx.Done():
|
||||
t.Fatal("dbContext should not be canceled when request context is canceled")
|
||||
default:
|
||||
}
|
||||
|
||||
// Verify the DB context has a deadline (from the timeout).
|
||||
if _, ok := dbCtx.Deadline(); !ok {
|
||||
t.Error("dbContext should have a deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardHandler(t *testing.T) {
|
||||
t.Run("empty store", func(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
@@ -340,6 +364,190 @@ func TestMetricsBearerToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestTruncateCommand(t *testing.T) {
|
||||
funcMap := templateFuncMap()
|
||||
fn := funcMap["truncateCommand"].(func(string) string)
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"short", "short"},
|
||||
{"exactly fifty characters long! that is what it i.", "exactly fifty characters long! that is what it i."},
|
||||
{"this string is definitely longer than fifty characters and should be truncated", "this string is definitely longer than fifty charac..."},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := fn(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("truncateCommand(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardExecCommands(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "Top Exec Commands") {
|
||||
t.Error("response should contain 'Top Exec Commands'")
|
||||
}
|
||||
if !strings.Contains(body, "uname -a") {
|
||||
t.Error("response should contain exec command 'uname -a'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIAttemptsOverTime(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/attempts-over-time", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(ct, "application/json") {
|
||||
t.Errorf("Content-Type = %q, want application/json", ct)
|
||||
}
|
||||
|
||||
var resp apiAttemptsOverTimeResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
// Seeded data inserted today -> at least 1 point.
|
||||
if len(resp.Points) == 0 {
|
||||
t.Error("expected at least one data point")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIHourlyPattern(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/hourly-pattern", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp apiHourlyPatternResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
if len(resp.Hours) == 0 {
|
||||
t.Error("expected at least one hourly data point")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPICountryStats(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/country-stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp apiCountryStatsResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
if len(resp.Countries) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(resp.Countries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragmentDashboardContent(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if strings.Contains(body, "<!DOCTYPE html>") {
|
||||
t.Error("dashboard content fragment should not contain full HTML document")
|
||||
}
|
||||
if !strings.Contains(body, "Top Usernames") {
|
||||
t.Error("dashboard content fragment should contain 'Top Usernames'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragmentDashboardContentWithFilter(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
}
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content?country=CN", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
// When filtered by CN, should show root but not admin.
|
||||
if !strings.Contains(body, "root") {
|
||||
t.Error("response should contain 'root' when filtered by CN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticAssets(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
|
||||
@@ -349,6 +557,9 @@ func TestStaticAssets(t *testing.T) {
|
||||
}{
|
||||
{"/static/pico.min.css", "text/css"},
|
||||
{"/static/htmx.min.js", "text/javascript"},
|
||||
{"/static/chart.min.js", "text/javascript"},
|
||||
{"/static/dashboard.js", "text/javascript"},
|
||||
{"/static/world.svg", "image/svg+xml"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -34,6 +34,16 @@ password = "admin"
|
||||
# password = "cisco"
|
||||
# shell = "cisco"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "irobot"
|
||||
# password = "roomba"
|
||||
# shell = "roomba"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "player"
|
||||
# password = "tetris"
|
||||
# shell = "tetris"
|
||||
|
||||
[storage]
|
||||
db_path = "oubliette.db"
|
||||
retention_days = 90
|
||||
@@ -50,6 +60,12 @@ hostname = "ubuntu-server"
|
||||
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||
# fake_user = "" # override username in prompt; empty = use authenticated user
|
||||
|
||||
# Map usernames to specific shells (regardless of how auth succeeded).
|
||||
# Credential-specific shell overrides take priority over username routes.
|
||||
# [shell.username_routes]
|
||||
# postgres = "psql"
|
||||
# admin = "bash"
|
||||
|
||||
# Per-shell configuration (optional).
|
||||
# [shell.banking]
|
||||
# bank_name = "SECUREBANK"
|
||||
@@ -65,6 +81,16 @@ hostname = "ubuntu-server"
|
||||
# ios_version = "15.0(2)SE11"
|
||||
# enable_password = "" # empty = accept after 1 failed attempt
|
||||
|
||||
# [shell.psql]
|
||||
# db_name = "postgres"
|
||||
# pg_version = "15.4"
|
||||
|
||||
# [shell.roomba]
|
||||
# No configuration options currently.
|
||||
|
||||
# [shell.tetris]
|
||||
# difficulty = "normal" # "easy" (slower start), "normal" (standard), "hard" (start at level 5)
|
||||
|
||||
# [detection]
|
||||
# enabled = true
|
||||
# threshold = 0.6 # 0.0–1.0, sessions above this trigger notifications
|
||||
|
||||
Reference in New Issue
Block a user